# utils.py
1
2
import itertools
import json
3
import linecache
4
import os
5
import pickle
6
import warnings
7
from logging import getLogger
8
from pathlib import Path
9
from typing import Callable, Dict, Iterable, List
10

11
12
import git
import numpy as np
13
import torch
14
from rouge_score import rouge_scorer, scoring
15
from sacrebleu import corpus_bleu
16
17
from torch import nn
from torch.utils.data import Dataset, Sampler
18

19
20
from transformers import BartTokenizer

21

22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """Label-smoothed negative log-likelihood (adapted from fairseq).

    Args:
        lprobs: log-probabilities; the last dimension indexes the classes.
        target: gold class indices, either with one dimension fewer than ``lprobs``
            or already carrying a trailing singleton dimension.
        epsilon: label-smoothing weight; ``0.0`` reduces to plain NLL.
        ignore_index: target value whose positions contribute nothing to either
            loss term (e.g. ``-100`` or a pad id). ``None`` disables masking.

    Returns:
        ``(loss, nll_loss)``, both summed (not averaged) over all positions.
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        # gather() rejects out-of-range indices such as the default -100, so look up
        # a valid placeholder (class 0) at ignored positions; they are zeroed below.
        safe_target = target.masked_fill(pad_mask, 0)
        nll_loss = -lprobs.gather(dim=-1, index=safe_target)
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
        smooth_loss = -lprobs.sum(dim=-1)

    nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
    smooth_loss = smooth_loss.sum()
    # The smoothing mass epsilon is spread uniformly over all classes.
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss
41
42


43
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
    """Tokenize a single ``line`` as a batch of one.

    Args:
        tokenizer: a tokenizer callable; ``BartTokenizer`` instances additionally
            receive ``add_prefix_space=True``.
        line: the text to encode.
        max_length: truncation length (and padding length when padding is on).
        pad_to_max_length: pad to exactly ``max_length`` when True; no padding otherwise.
        return_tensors: forwarded to the tokenizer (``"pt"`` by default).

    Returns:
        Whatever the tokenizer call returns for ``[line]``.
    """
    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
    return tokenizer(
        [line],
        max_length=max_length,
        # False (not None) is the "no padding" value; None is not a valid padding strategy.
        padding="max_length" if pad_to_max_length else False,
        truncation=True,
        return_tensors=return_tensors,
        **extra_kw,
    )
53
54


55
56
def lmap(f: Callable, x: Iterable) -> List:
    """Apply ``f`` to each element of ``x``, returning the results as a list."""
    return [f(item) for item in x]


60
def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
    """Corpus BLEU from sacrebleu's ``corpus_bleu``, rounded to 4 decimals.

    ``refs_lns`` is a single reference stream; it is wrapped in a list as sacrebleu
    expects a list of reference streams. Extra keyword arguments go to ``corpus_bleu``.
    """
    bleu = corpus_bleu(output_lns, [refs_lns], **kwargs)
    return {"bleu": round(bleu.score, 4)}
63
64


65
def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Drop columns in which every row equals ``pad_token_id``.

    Returns the trimmed ``input_ids`` alone, or the pair
    ``(input_ids, attention_mask)`` with the same columns removed when an
    ``attention_mask`` is supplied.
    """
    has_real_token = input_ids.ne(pad_token_id).any(dim=0)
    trimmed_ids = input_ids[:, has_real_token]
    if attention_mask is not None:
        return trimmed_ids, attention_mask[:, has_real_token]
    return trimmed_ids


78
class Seq2SeqDataset(Dataset):
    """Dataset over line-aligned ``<type_path>.source`` / ``<type_path>.target`` files.

    Line ``i`` of the source file is paired with line ``i`` of the target file. Lines
    are fetched lazily with ``linecache`` and tokenized one example at a time via
    ``encode_line``; ``collate_fn`` stacks examples and trims all-padding columns.
    """

    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        src_lang=None,
        tgt_lang=None,
        prefix="",
    ):
        """
        Args:
            tokenizer: passed to ``encode_line``; its ``pad_token_id`` is used to trim batches.
            data_dir: directory holding ``<type_path>.source`` and ``<type_path>.target``.
            max_source_length: ``max_length`` used when encoding source lines.
            max_target_length: ``max_length`` used when encoding target lines.
            type_path: file stem, e.g. ``"train"``.
            n_obs: if given, keep only the first ``n_obs`` examples.
            src_lang: stored on the instance; not used by this class's own methods.
            tgt_lang: stored on the instance; not used by this class's own methods.
            prefix: string prepended to every source line.
        """
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        # Per-line character counts (trailing "\n" included); they define the dataset
        # length and are the sort keys handed to SortishSampler.
        self.src_lens = self.get_char_lens(self.src_file)
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        # NOTE: because lengths include the newline, a blank line ("\n") has length 1
        # and passes this check; blank lines are only caught by __getitem__'s asserts.
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix
        if n_obs is not None:
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

    def __len__(self):
        return len(self.src_lens)

    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        """Read, prefix, and tokenize the ``index``-th (0-based) source/target pair."""
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)

        # encode_line tokenizes [line], a batch of one: drop only that leading dim so a
        # length-1 encoding does not collapse to a 0-d tensor (which squeeze() would do).
        source_ids = source_inputs["input_ids"].squeeze(0)
        target_ids = target_inputs["input_ids"].squeeze(0)
        src_mask = source_inputs["attention_mask"].squeeze(0)
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "decoder_input_ids": target_ids,
        }

    @staticmethod
    def get_char_lens(data_file):
        """Return the character length of every line in ``data_file`` (newline included)."""
        # Use a context manager so the file handle is closed deterministically.
        with Path(data_file).open() as f:
            return [len(x) for x in f]

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Stack examples, then drop columns that are padding in every row."""
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "decoder_input_ids": y,
        }
        return batch

    def make_sortish_sampler(self, batch_size):
        """Sampler that groups examples of similar source length into batches."""
        return SortishSampler(self.src_lens, batch_size)


149
class TranslationDataset(Seq2SeqDataset):
    """Seq2SeqDataset variant that yields raw text and tokenizes a whole batch at once.

    Items are ``{"tgt_texts": ..., "src_texts": ...}`` string pairs; ``collate_fn``
    hands them to ``tokenizer.prepare_seq2seq_batch`` with the language codes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        lengths_differ = self.max_source_length != self.max_target_length
        if lengths_differ:
            message = (
                f"Mbart is using sequence lengths {self.max_source_length}, {self.max_target_length}. "
                "Imbalanced sequence lengths may be undesired for translation tasks"
            )
            warnings.warn(message)

    def __getitem__(self, index) -> Dict[str, str]:
        line_no = index + 1  # linecache numbers lines from 1
        src_path = str(self.src_file)
        tgt_path = str(self.tgt_file)
        source_line = self.prefix + linecache.getline(src_path, line_no).rstrip("\n")
        tgt_line = linecache.getline(tgt_path, line_no).rstrip("\n")
        assert source_line, f"empty source line for index {line_no}"
        assert tgt_line, f"empty tgt line for index {line_no}"
        return {"tgt_texts": tgt_line, "src_texts": source_line}

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        src_texts = [item["src_texts"] for item in batch]
        tgt_texts = [item["tgt_texts"] for item in batch]
        encoded = self.tokenizer.prepare_seq2seq_batch(
            src_texts,
            src_lang=self.src_lang,
            tgt_texts=tgt_texts,
            tgt_lang=self.tgt_lang,
            max_length=self.max_source_length,
            max_target_length=self.max_target_length,
        )
        return encoded.data
181
182
183
184
185
186
187
188
189


class SortishSampler(Sampler):
    """Go through the text data by order of src length with a bit of randomness. From fastai repo.

    ``data`` holds one sort key (e.g. a length) per example; iteration yields example
    indices grouped into ``batch_size`` chunks of similar keys, with the chunk holding
    the largest key first and the remaining chunks in random order.
    """

    def __init__(self, data, batch_size):
        self.data, self.bs = data, batch_size

    def key(self, i):
        """Sort key of example ``i``."""
        return self.data[i]

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        if len(self.data) == 0:
            return iter([])  # np.concatenate([]) would raise on empty data
        idxs = np.random.permutation(len(self.data))
        # Sort (longest first) within shuffled mega-chunks of 50 batches.
        sz = self.bs * 50
        ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
        # Split into batches; mega-chunks are multiples of bs, so each batch's first
        # element is its largest key.
        sz = self.bs
        ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
        ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
        # Shuffle the remaining batches by position. Permuting the list of batch arrays
        # directly fails when the last batch is shorter (ragged arrays are rejected by
        # NumPy >= 1.24), and the old empty-array fallback used the removed np.int.
        order = np.random.permutation(len(ck_idx) - 1) + 1
        sort_idx = np.concatenate([ck_idx[0]] + [ck_idx[i] for i in order])
        return iter(sort_idx)


209
210
211
logger = getLogger(name=__name__)


def use_task_specific_params(model, task):
    """Merge ``model.config.task_specific_params[task]`` into ``model.config``.

    Does nothing when the config has no ``task_specific_params``. An unknown ``task``
    results in an update with an empty dict.
    """
    all_task_params = model.config.task_specific_params
    if all_task_params is None:
        return
    chosen = all_task_params.get(task, {})
    logger.info(f"using task specific params for {task}: {chosen}")
    model.config.update(chosen)
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237


def pickle_load(path):
    """Unpickle and return the object stored in the file at ``path``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, mode="rb") as handle:
        loaded = pickle.load(handle)
    return loaded


def pickle_save(obj, path):
    """Pickle ``obj`` into the file at ``path``, overwriting it. Returns None."""
    with open(path, mode="wb") as handle:
        pickle.dump(obj, handle)


def flatten_list(summary_ids: List[List]):
    """Concatenate an iterable of iterables into one flat list."""
    return list(itertools.chain.from_iterable(summary_ids))


238
239
def save_git_info(folder_path: str) -> None:
    """Write the output of ``get_git_info()`` to ``<folder_path>/git_log.json``."""
    info = get_git_info()
    destination = os.path.join(folder_path, "git_log.json")
    save_json(info, destination)
242

243
244
245
246
247
248
249
250
251

def save_json(content, path):
    """Serialize ``content`` as 4-space-indented JSON into ``path``, overwriting it."""
    with open(path, mode="w") as handle:
        json.dump(content, handle, indent=4)


def load_json(path):
    """Parse and return the JSON document stored at ``path``."""
    with open(path, mode="r") as handle:
        parsed = json.load(handle)
    return parsed
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266


def get_git_info():
    """Describe the git repository found via ``git.Repo(search_parent_directories=True)``.

    Returns:
        dict with ``repo_id`` (``str(repo)``), ``repo_sha`` (hex SHA of the HEAD
        commit), and ``repo_branch`` (active branch name, or ``None`` when HEAD is
        detached, e.g. in CI checkouts of a specific commit).
    """
    repo = git.Repo(search_parent_directories=True)
    repo_id = str(repo)
    repo_sha = str(repo.head.object.hexsha)
    try:
        repo_branch = str(repo.active_branch)
    except TypeError:
        # GitPython raises TypeError from active_branch when HEAD is detached;
        # the SHA above is still valid, so record the branch as unknown instead of failing.
        repo_branch = None
    return {
        "repo_id": repo_id,
        "repo_sha": repo_sha,
        "repo_branch": repo_branch,
    }


ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]


def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
    """ROUGE F-measures (x100, rounded to 4 places) for each key in ``ROUGE_KEYS``.

    Outputs and references are paired positionally with ``zip``, so surplus lines in
    the longer list are ignored. Each score is the ``mid`` value reported by
    rouge_score's ``BootstrapAggregator``.
    """
    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()
    for pair in zip(reference_lns, output_lns):
        aggregator.add_scores(scorer.score(*pair))

    summary = {}
    for key, score in aggregator.aggregate().items():
        summary[key] = round(score.mid.fmeasure * 100, 4)
    return summary
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302


def freeze_params(model: nn.Module):
    """Mark every parameter of ``model`` as not requiring gradients (in place)."""
    params = model.parameters()
    for param in params:
        param.requires_grad = False


def grad_status(model: nn.Module) -> Iterable:
    """Lazy iterator over the ``requires_grad`` flag of each parameter of ``model``."""
    params = model.parameters()
    return (param.requires_grad for param in params)


def any_requires_grad(model: nn.Module) -> bool:
    """True if at least one parameter of ``model`` requires gradients."""
    return any(param.requires_grad for param in model.parameters())


def assert_all_frozen(model):
    """Raise AssertionError if any parameter of ``model`` still requires gradients."""
    flags: List[bool] = [param.requires_grad for param in model.parameters()]
    n_total = len(flags)
    n_trainable = sum(1 for flag in flags if flag)
    assert not any(flags), f"{n_trainable/n_total:.1%} of {n_total} weights require grad"


def assert_not_all_frozen(model):
    """Raise AssertionError if no parameter of ``model`` requires gradients."""
    flags: List[bool] = [param.requires_grad for param in model.parameters()]
    assert any(flags), f"none of {len(flags)} weights require grad"