import itertools
import json
import os
import pickle
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List

import git
import numpy as np
import torch
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm

from transformers import BartTokenizer

logger = getLogger(__name__)


def encode_file(
    tokenizer,
    data_path,
    max_length,
    pad_to_max_length=True,
    return_tensors="pt",
    overwrite_cache=False,
    prefix="",
    tok_name="",
):
    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
    cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
    if not overwrite_cache and cache_path.exists():
        try:
            examples = torch.load(cache_path)
            assert isinstance(examples, list)
            return examples
        except Exception:
            print(f"Failed to load from {cache_path}, retokenizing {data_path}")
    data_path = Path(data_path)

    lns = lmap(str.strip, data_path.open().readlines())
    lns = [prefix + text for text in lns]
    assert lns, f"found empty file at {data_path}"
    examples = []
    for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
        tokenized = tokenizer(
            [text],
            max_length=max_length,
            padding="max_length" if pad_to_max_length else None,
            truncation=True,
            return_tensors=return_tensors,
            **extra_kw,
        )
        if pad_to_max_length:
            assert tokenized.input_ids.shape[1] == max_length
        examples.append(tokenized)
    # Convert BatchEncoding objects to plain dicts so the cache pickles cleanly and
    # both the cached and freshly tokenized paths return the same type.
    examples = lmap(dict, examples)
    torch.save(examples, cache_path.open("wb"))
    return examples
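
# Usage sketch (hypothetical paths and model name, for illustration only):
#   tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
#   examples = encode_file(tokenizer, "data/train.source", max_length=1024, tok_name="bart")
#   examples[0]["input_ids"].shape  # -> torch.Size([1, 1024]), padded to max_length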


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict:
    """Uses sacrebleu's corpus_bleu implementation."""
    return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score}


def trim_batch(
    input_ids, pad_token_id, attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
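
# Worked example (illustrative tensors, pad_token_id assumed to be 0):
#   input_ids = torch.tensor([[5, 6, 0, 0], [7, 0, 0, 0]])
#   trim_batch(input_ids, pad_token_id=0)  # -> tensor([[5, 6], [7, 0]])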


class SummarizationDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        data_dir,
        type_path="train",
        max_source_length=1024,
        max_target_length=56,
        n_obs=None,
        overwrite_cache=False,
        prefix="",
        src_lang=None,
        tgt_lang=None,
    ):
        super().__init__()
        # Strip the trailing "tokenizer" from the class name, e.g. "barttokenizer" -> "bart".
        # (str.rstrip would strip any trailing chars in the set, mangling the name.)
        tok_name = tokenizer.__class__.__name__.lower()
        if tok_name.endswith("tokenizer"):
            tok_name = tok_name[: -len("tokenizer")]
        if hasattr(tokenizer, "set_lang") and src_lang is not None:
            tokenizer.set_lang(src_lang)  # HACK: only applies to mbart
        self.source = encode_file(
            tokenizer,
            os.path.join(data_dir, type_path + ".source"),
            max_source_length,
            overwrite_cache=overwrite_cache,
            prefix=prefix,
            tok_name=tok_name,
        )
        tgt_path = os.path.join(data_dir, type_path + ".target")
        if hasattr(tokenizer, "set_lang"):
            assert tgt_lang is not None, "--tgt_lang must be passed to build a translation dataset"
            tokenizer.set_lang(tgt_lang)  # HACK: only applies to mbart
        self.target = encode_file(
            tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
        )
        if n_obs is not None:
            self.source = self.source[:n_obs]
            self.target = self.target[:n_obs]
        self.pad_token_id = tokenizer.pad_token_id

    def __len__(self):
        return len(self.source)

    def __getitem__(self, index):
        source_ids = self.source[index]["input_ids"].squeeze()
        target_ids = self.target[index]["input_ids"].squeeze()
        src_mask = self.source[index]["attention_mask"].squeeze()
        return {"input_ids": source_ids, "attention_mask": src_mask, "decoder_input_ids": target_ids}

    @staticmethod
    def trim_seq2seq_batch(batch, pad_token_id):
        y = trim_batch(batch["decoder_input_ids"], pad_token_id)
        source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
        return source_ids, source_mask, y

    def collate_fn(self, batch) -> dict:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
        return batch

    def make_sortish_sampler(self, batch_size):
        lens = [x["input_ids"].ne(self.pad_token_id).sum() for x in self.source]
        return SortishSampler(lens, batch_size)
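
# Sketch of how this dataset is typically wired into training (paths and
# hyperparameters are hypothetical):
#   ds = SummarizationDataset(tokenizer, data_dir="data", type_path="val", n_obs=100)
#   loader = torch.utils.data.DataLoader(
#       ds, batch_size=4, collate_fn=ds.collate_fn, sampler=ds.make_sortish_sampler(4)
#   )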


class SortishSampler(Sampler):
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."

    def __init__(self, data, batch_size):
        self.data, self.bs = data, batch_size

    def key(self, i):
        return self.data[i]

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        idxs = np.random.permutation(len(self.data))
        # sort within large chunks (50 batches each) so each batch holds similarly sized sequences
        sz = self.bs * 50
        ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
        sz = self.bs
        ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
        ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
        sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=int)
        sort_idx = np.concatenate((ck_idx[0], sort_idx))
        return iter(sort_idx)
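
# Illustration (hypothetical lengths, batch_size=2): for source lengths
# [3, 9, 5, 7, 2, 8], the sampler yields indices grouped into batches of roughly
# similar length, with the largest batch first (so OOMs surface on step one) and
# the remaining batches in random order.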


def use_task_specific_params(model, task):
    """Update config with summarization specific params."""
    task_specific_params = model.config.task_specific_params

    if task_specific_params is not None:
        pars = task_specific_params.get(task, {})
        logger.info(f"using task specific params for {task}: {pars}")
        model.config.update(pars)
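
# Example (assumes a checkpoint whose config defines task_specific_params, e.g.
# facebook/bart-large-cnn, which ships beam-search defaults for summarization):
#   model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
#   use_task_specific_params(model, "summarization")  # applies num_beams, max_length, etc.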


def pickle_load(path):
    """pickle.load(path)"""
    with open(path, "rb") as f:
        return pickle.load(f)


def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)


def flatten_list(summary_ids: List[List]):
    return list(itertools.chain.from_iterable(summary_ids))


def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json"""
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))


def save_json(content, path):
    with open(path, "w") as f:
        json.dump(content, f, indent=4)


def load_json(path):
    with open(path) as f:
        return json.load(f)


def get_git_info():
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": str(repo.active_branch),
    }
    return repo_infos
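
# The saved git_log.json then looks roughly like (values are illustrative):
#   {"repo_id": "<git.Repo '/path/.git'>", "repo_sha": "abc123...", "repo_branch": "main"}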


ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]


def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()

    for reference_ln, output_ln in zip(reference_lns, output_lns):
        scores = scorer.score(reference_ln, output_ln)
        aggregator.add_scores(scores)

    result = aggregator.aggregate()
    return {k: v.mid.fmeasure for k, v in result.items()}
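
# Hedged example with made-up sentences; an exact match scores 1.0 on all keys:
#   calculate_rouge(["the cat sat on the mat"], ["the cat sat on the mat"])
#   # -> {"rouge1": 1.0, "rouge2": 1.0, "rougeL": 1.0}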


def freeze_params(model: nn.Module):
    for par in model.parameters():
        par.requires_grad = False


def grad_status(model: nn.Module) -> Iterable:
    return (par.requires_grad for par in model.parameters())


def any_requires_grad(model: nn.Module) -> bool:
    return any(grad_status(model))


def assert_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    n_require_grad = sum(lmap(int, model_grads))
    npars = len(model_grads)
    assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"


def assert_not_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"
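
# Typical use of the freezing helpers when fine-tuning only part of a model
# (attribute paths are illustrative and depend on the architecture):
#   freeze_params(model.model.encoder)
#   assert_all_frozen(model.model.encoder)
#   assert_not_all_frozen(model)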