import itertools
import json
import linecache
import os
import pickle
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Union

import git
import numpy as np
import torch
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler

from transformers import BartTokenizer


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """From fairseq"""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)

    nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss
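
# Illustrative usage (a sketch, not part of the original module): `lm_logits`,
# `labels` and `pad_token_id` are assumed to come from the surrounding training loop.
#
#     lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
#     loss, nll_loss = label_smoothed_nll_loss(
#         lprobs, labels, epsilon=0.1, ignore_index=pad_token_id
#     )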


def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
    """Only used by LegacyDataset"""
    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
    return tokenizer(
        [line],
        max_length=max_length,
        padding="max_length" if pad_to_max_length else None,
        truncation=True,
        return_tensors=return_tensors,
        **extra_kw,
    )
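
# Illustrative usage (sketch): encoding a single line with a BART tokenizer. The
# checkpoint name and the sentence below are assumptions for demonstration only.
#
#     tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
#     enc = encode_line(tokenizer, "A short source sentence.", max_length=1024)
#     input_ids, attention_mask = enc["input_ids"], enc["attention_mask"]  # shape (1, max_length)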


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
    """Uses sacrebleu's corpus_bleu implementation."""
    return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}


def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
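
# Illustrative usage (sketch): with pad_token_id = 0, columns containing only padding
# are dropped for every row of the hypothetical batch below.
#
#     ids = torch.tensor([[5, 6, 0, 0], [7, 0, 0, 0]])
#     trim_batch(ids, pad_token_id=0)
#     # -> tensor([[5, 6], [7, 0]])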


class AbstractSeq2SeqDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        src_lang=None,
        tgt_lang=None,
        prefix="",
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.src_lens = self.get_char_lens(self.src_file)
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix
        if n_obs is not None:
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)

    def __len__(self):
        return len(self.src_lens)

    @staticmethod
    def get_char_lens(data_file):
        return [len(x) for x in Path(data_file).open().readlines()]

    def make_sortish_sampler(self, batch_size):
        return SortishSampler(self.src_lens, batch_size)

    def __getitem__(self, item):
        raise NotImplementedError("You must implement this")

    def collate_fn(self, batch):
        raise NotImplementedError("You must implement this")


class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        """Call tokenizer on src and tgt_lines"""
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)

        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "labels": target_ids,
        }

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["labels"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "labels": y,
        }
        return batch


class Seq2SeqDataset(AbstractSeq2SeqDataset):
    """A dataset that calls prepare_seq2seq_batch."""

    def __getitem__(self, index) -> Dict[str, str]:
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        return {
            "tgt_texts": tgt_line,
            "src_texts": source_line,
        }

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Call prepare_seq2seq_batch."""
        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            src_lang=self.src_lang,
            tgt_texts=[x["tgt_texts"] for x in batch],
            tgt_lang=self.tgt_lang,
            max_length=self.max_source_length,
            max_target_length=self.max_target_length,
            return_tensors="pt",
            add_prefix_space=self.add_prefix_space,
        )
        return batch_encoding.data
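
# Illustrative usage (sketch): wiring the dataset into a DataLoader. The tokenizer,
# data_dir and batch size are assumptions that depend on the surrounding training script.
#
#     ds = Seq2SeqDataset(tokenizer, data_dir="data/", max_source_length=512, max_target_length=56)
#     loader = torch.utils.data.DataLoader(ds, batch_size=8, collate_fn=ds.collate_fn)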


class SortishSampler(Sampler):
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."

    def __init__(self, data, batch_size):
        self.data, self.bs = data, batch_size

    def key(self, i):
        return self.data[i]

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        idxs = np.random.permutation(len(self.data))
        sz = self.bs * 50
        ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
        sz = self.bs
        ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
        ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
        sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int64)
        sort_idx = np.concatenate((ck_idx[0], sort_idx))
        return iter(sort_idx)
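
# Illustrative usage (sketch): the sampler is normally obtained via
# AbstractSeq2SeqDataset.make_sortish_sampler and handed to a DataLoader; `ds` is an
# assumed dataset instance.
#
#     sampler = ds.make_sortish_sampler(batch_size=8)
#     loader = torch.utils.data.DataLoader(ds, batch_size=8, sampler=sampler, collate_fn=ds.collate_fn)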


logger = getLogger(__name__)


def use_task_specific_params(model, task):
    """Update config with summarization specific params."""
    task_specific_params = model.config.task_specific_params

    if task_specific_params is not None:
        pars = task_specific_params.get(task, {})
        logger.info(f"using task specific params for {task}: {pars}")
        model.config.update(pars)
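
# Illustrative usage (sketch): for a checkpoint whose config defines task_specific_params
# (e.g. facebook/bart-large-cnn), this copies the task's generation settings such as
# num_beams and max_length into the top-level config. The checkpoint name is an assumption.
#
#     model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
#     use_task_specific_params(model, "summarization")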


def pickle_load(path):
    """pickle.load(path)"""
    with open(path, "rb") as f:
        return pickle.load(f)


def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)


def flatten_list(summary_ids: List[List]):
    return [x for x in itertools.chain.from_iterable(summary_ids)]


def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json"""
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))


def save_json(content, path):
    with open(path, "w") as f:
        json.dump(content, f, indent=4)


def load_json(path):
    with open(path) as f:
        return json.load(f)


def get_git_info():
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": str(repo.active_branch),
    }
    return repo_infos


ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]


def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()

    for reference_ln, output_ln in zip(reference_lns, output_lns):
        scores = scorer.score(reference_ln, output_ln)
        aggregator.add_scores(scores)

    result = aggregator.aggregate()
    return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
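
# Illustrative usage (sketch): like calculate_bleu, this takes plain strings and returns
# mid-point F-measures scaled to 0-100; the example strings are made up.
#
#     calculate_rouge(["the cat sat on the mat"], ["the cat sat on the mat"])
#     # -> {"rouge1": 100.0, "rouge2": 100.0, "rougeL": 100.0}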


# Utilities for freezing parameters and checking whether they are frozen


def freeze_params(model: nn.Module):
    """Set requires_grad=False for each of model.parameters()"""
    for par in model.parameters():
        par.requires_grad = False


def grad_status(model: nn.Module) -> Iterable:
    return (par.requires_grad for par in model.parameters())


def any_requires_grad(model: nn.Module) -> bool:
    return any(grad_status(model))


def assert_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    n_require_grad = sum(lmap(int, model_grads))
    npars = len(model_grads)
    assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"


def assert_not_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"
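
# Illustrative usage (sketch): a common pattern in the seq2seq scripts is to freeze the
# encoder and check that the rest of the model still trains; `model` is assumed to be a
# BART/T5-style model exposing .get_encoder().
#
#     freeze_params(model.get_encoder())
#     assert_all_frozen(model.get_encoder())
#     assert_not_all_frozen(model)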


# CLI Parsing utils


def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float]]:
    """Parse an argv list of unspecified command line args to a dict. Assumes all values are numeric."""
    result = {}
    assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
    num_pairs = len(unparsed_args) // 2
    for pair_num in range(num_pairs):
        i = 2 * pair_num
        assert unparsed_args[i].startswith("--")
        try:
            value = int(unparsed_args[i + 1])
        except ValueError:
            value = float(unparsed_args[i + 1])  # this can raise another informative ValueError

        result[unparsed_args[i][2:]] = value
    return result
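
# Illustrative usage (sketch): the argv tail below is made up; it parses to a dict of
# numeric keyword arguments.
#
#     parse_numeric_cl_kwargs(["--num_beams", "4", "--length_penalty", "0.8"])
#     # -> {"num_beams": 4, "length_penalty": 0.8}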