# utils.py (13.3 KB)
1
2
import itertools
import json
3
import linecache
4
import math
5
import os
6
import pickle
7
from logging import getLogger
8
from pathlib import Path
9
from typing import Callable, Dict, Iterable, List, Union
10

11
12
import git
import numpy as np
13
import torch
14
import torch.distributed as dist
15
from rouge_score import rouge_scorer, scoring
16
from sacrebleu import corpus_bleu
17
18
from torch import nn
from torch.utils.data import Dataset, Sampler
19

20
21
from transformers import BartTokenizer

22

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """Label-smoothed negative log-likelihood loss (adapted from fairseq).

    Args:
        lprobs: log-probabilities; the last dimension indexes the classes.
        target: gold class indices, either with one dimension fewer than ``lprobs``
            or already carrying a trailing singleton dimension.
        epsilon: smoothing factor; the loss is
            ``(1 - epsilon) * nll + (epsilon / lprobs.size(-1)) * smooth``.
        ignore_index: target value whose positions contribute nothing to either
            term. ``None`` disables masking.

    Returns:
        ``(loss, nll_loss)`` as scalar tensors, both summed (not averaged) over positions.
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    pad_mask = None
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        # gather() rejects out-of-range indices such as the default -100, so point the
        # ignored positions at a valid class; their contributions are zeroed below.
        target = target.masked_fill(pad_mask, 0)

    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if pad_mask is not None:
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)

    # Both terms are reduced with sum(), so the trailing singleton dim needs no squeeze.
    nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss
42
43


44
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
    """Tokenize one line as a single-element batch (used by ``LegacySeq2SeqDataset``).

    Args:
        tokenizer: tokenizer called as ``tokenizer([line], ...)``.
        line: the text to encode.
        max_length: truncation length, and the padding length when ``pad_to_max_length``.
        pad_to_max_length: pad to ``max_length`` if True, otherwise do not pad.
        return_tensors: forwarded to the tokenizer (``"pt"`` by default).

    Returns:
        Whatever the tokenizer call returns for the one-element batch ``[line]``.
    """
    # BART tokenizers are called with add_prefix_space=True; other tokenizers get no extra kwargs.
    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
    return tokenizer(
        [line],
        max_length=max_length,
        # "No padding" is spelled False for the tokenizer's `padding` argument; None is
        # not a valid padding strategy.
        padding="max_length" if pad_to_max_length else False,
        truncation=True,
        return_tensors=return_tensors,
        **extra_kw,
    )
55
56


57
58
def lmap(f: Callable, x: Iterable) -> List:
    """Return ``list(map(f, x))``: apply ``f`` to every item of ``x`` eagerly."""
    return list(map(f, x))


62
def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
    """Corpus BLEU via sacrebleu's ``corpus_bleu`` against a single reference stream.

    Args:
        output_lns: hypothesis strings.
        refs_lns: reference strings, presumably aligned one-to-one with ``output_lns``;
            wrapped in a list because ``corpus_bleu`` takes a list of reference streams.
        **kwargs: forwarded unchanged to ``corpus_bleu``.

    Returns:
        ``{"bleu": score}`` with the score rounded to 4 decimal places.
    """
    return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
65
66


67
def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id.

    Args:
        input_ids: 2-D tensor, presumably (batch, seq_len); a column is kept if any
            row holds a non-pad token there.
        pad_token_id: id treated as padding.
        attention_mask: optional tensor trimmed with the same column mask.

    Returns:
        The trimmed ``input_ids``, or the tuple ``(input_ids, attention_mask)`` when a mask
        is given.
    """
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])


80
class AbstractSeq2SeqDataset(Dataset):
    """Base dataset over line-aligned ``<type_path>.source`` / ``<type_path>.target`` files.

    Subclasses must implement ``__getitem__`` and ``collate_fn``.
    """

    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        src_lang=None,
        tgt_lang=None,
        prefix="",
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        # Per-line character counts of the source file; used as sort keys by the samplers.
        self.src_lens = self.get_char_lens(self.src_file)
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        # NOTE(review): lengths include the trailing newline, so a blank line ("\n") has
        # length 1 and passes this check; an empty file makes min() raise ValueError.
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix if prefix is not None else ""
        if n_obs is not None:
            # __len__ is len(self.src_lens), so truncating it caps the number of examples.
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)

    def __len__(self):
        return len(self.src_lens)

    @staticmethod
    def get_char_lens(data_file):
        """Return the length of every line in ``data_file``, trailing newline included."""
        # Use a context manager so the file handle is closed deterministically.
        with Path(data_file).open() as f:
            return [len(x) for x in f]

    def make_sortish_sampler(self, batch_size, distributed=False):
        """Return a DistributedSortishSampler over ``self``, or a SortishSampler over ``self.src_lens``."""
        if distributed:
            return DistributedSortishSampler(self, batch_size)
        else:
            return SortishSampler(self.src_lens, batch_size)

    def __getitem__(self, item):
        raise NotImplementedError("You must implement this")

    def collate_fn(self, batch):
        raise NotImplementedError("You must implement this")


class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
    """Dataset that tokenizes each line pair in ``__getitem__`` via ``encode_line``."""

    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        """Call tokenizer on src and tgt_lines"""
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)

        # encode_line tokenizes the one-element batch [line]; drop only that leading batch
        # axis so a length-1 sequence is not collapsed to a 0-d tensor by a bare squeeze().
        source_ids = source_inputs["input_ids"].squeeze(0)
        target_ids = target_inputs["input_ids"].squeeze(0)
        src_mask = source_inputs["attention_mask"].squeeze(0)
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "labels": target_ids,
        }

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Stack per-example tensors, then drop columns that are padding in every row.

        torch.stack needs equal-length tensors; encode_line pads to max_length by default.
        """
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["labels"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "labels": y,
        }
        return batch

164

165
class Seq2SeqDataset(AbstractSeq2SeqDataset):
    """A dataset that calls prepare_seq2seq_batch.

    ``__getitem__`` returns raw text; tokenization happens per batch in ``collate_fn``.
    """

    def __getitem__(self, index) -> Dict[str, str]:
        """Return the (prefixed) source line and the target line at 0-based ``index``."""
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        return {
            "tgt_texts": tgt_line,
            "src_texts": source_line,
        }

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Call prepare_seq2seq_batch.

        Returns the ``.data`` dict of the tokenizer's result (presumably a BatchEncoding).
        """
        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            src_lang=self.src_lang,
            tgt_texts=[x["tgt_texts"] for x in batch],
            tgt_lang=self.tgt_lang,
            max_length=self.max_source_length,
            max_target_length=self.max_target_length,
            return_tensors="pt",
            add_prefix_space=self.add_prefix_space,
        )
        return batch_encoding.data
192
193
194
195
196
197
198
199
200
201
202
203


class SortishSampler(Sampler):
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."

    def __init__(self, data, batch_size):
        # data: per-example sort keys (the dataset passes its src_lens).
        self.data, self.bs = data, batch_size

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        # A fresh ordering is drawn from NumPy's global RNG on every iteration.
        return iter(sortish_sampler_indices(self.data, self.bs))


def sortish_sampler_indices(data: List, bs: int) -> np.array:
    """Go through the text data by order of src length with a bit of randomness. From fastai repo.

    Indices are shuffled, then split into chunks of ``50 * bs``. Each chunk is sorted by
    key in descending order and cut into batches of ``bs``. The batch holding the
    largest key is placed first, and the remaining batches follow in random order.

    Args:
        data: sort key per example (e.g. character lengths).
        bs: batch size.

    Returns:
        A 1-D integer array that is a permutation of ``range(len(data))``.
    """
    if len(data) == 0:
        return np.array([], dtype=int)

    def key_fn(i):
        return data[i]

    idxs = np.random.permutation(len(data))
    sz = bs * 50
    ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
    sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
    sz = bs
    ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
    max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
    ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
    # Shuffle the order of the remaining batches by permuting their positions. Permuting the
    # list of chunk arrays directly fails when the last chunk is shorter than bs, because
    # NumPy cannot build a regular array from ragged chunks.
    rest_order = np.random.permutation(len(ck_idx) - 1) + 1
    sort_idx = np.concatenate([ck_idx[0]] + [ck_idx[j] for j in rest_order])
    return sort_idx


class DistributedSortishSampler(Sampler):
    """Sortish sampling per process; rank/replica bookkeeping copied from torch DistributedSampler.

    Each rank takes every ``num_replicas``-th index (offset by ``rank``) and orders its
    share with ``sortish_sampler_indices``. ``dataset`` must support ``len()``, and
    ``__iter__`` reads ``dataset.src_lens``.
    """

    def __init__(self, dataset, batch_size, num_replicas=None, rank=None):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        # NOTE(review): stored for the set_epoch() API but not read by this class; the
        # shuffle comes from NumPy's global RNG inside sortish_sampler_indices.
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.batch_size = batch_size

    def __iter__(self) -> Iterable:
        available_indices = self.get_indices_for_rank()  # indices[self.rank: self.total_size: self.num_replicas]

        sortish_data = [self.dataset.src_lens[i] for i in available_indices]
        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
        # Map positions within this rank's share back to dataset indices.
        indices = [available_indices[i] for i in sortish_indices]
        assert len(indices) == self.num_samples
        return iter(indices)

    def get_indices_for_rank(self) -> np.array:
        """Return the dataset indices assigned to ``self.rank`` (padded so all ranks get equally many)."""
        indices = list(range(len(self.dataset)))
        # add extra samples to make it evenly divisible
        padding_size = self.total_size - len(indices)
        if padding_size > 0:
            # Repeat the index list as often as needed: when num_replicas exceeds twice the
            # dataset size, a single wrap-around does not supply enough padding.
            indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        assert len(indices) == self.total_size
        # subsample
        available_indices = indices[self.rank : self.total_size : self.num_replicas]
        return available_indices

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
271
272


273
274
275
logger = getLogger(__name__)


def use_task_specific_params(model, task):
    """Update config with summarization specific params.

    Looks up ``model.config.task_specific_params[task]`` (an empty dict if ``task`` is
    missing) and passes it to ``model.config.update``. Does nothing when
    ``task_specific_params`` is None.
    """
    task_specific_params = model.config.task_specific_params

    if task_specific_params is not None:
        pars = task_specific_params.get(task, {})
        logger.info("using task specific params for %s: %s", task, pars)
        model.config.update(pars)
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301


def pickle_load(path):
    """Deserialize and return the single pickled object stored in the file at ``path``.

    NOTE: unpickling can run arbitrary code; only load files from trusted sources.
    """
    with open(path, "rb") as handle:
        loaded = pickle.load(handle)
    return loaded


def pickle_save(obj, path):
    """Pickle ``obj`` into the file at ``path`` (overwriting it); returns ``pickle.dump``'s result."""
    with open(path, "wb") as handle:
        outcome = pickle.dump(obj, handle)
    return outcome


def flatten_list(summary_ids: List[List]):
    """Concatenate the inner lists of ``summary_ids`` into a single flat list."""
    flattened = itertools.chain.from_iterable(summary_ids)
    return list(flattened)


302
303
def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json

    Writes the dict returned by ``get_git_info()`` to ``<folder_path>/git_log.json``.
    """
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
306

307
308
309
310
311
312
313
314
315

def save_json(content, path):
    """Serialize ``content`` as 4-space-indented JSON into ``path``, replacing any existing file."""
    with open(path, "w") as out_file:
        json.dump(content, out_file, indent=4)


def load_json(path):
    """Parse the JSON document stored at ``path`` and return the resulting object."""
    with open(path) as in_file:
        parsed = json.load(in_file)
    return parsed
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330


def get_git_info():
    """Describe the git repository found from the current directory or its parents.

    Returns:
        dict with ``repo_id``, ``repo_sha`` (HEAD commit hash) and ``repo_branch``.
        ``repo_branch`` is None when HEAD is detached, because GitPython raises TypeError
        for ``active_branch`` in that state (e.g. CI checkouts of a specific commit).
    """
    repo = git.Repo(search_parent_directories=True)
    try:
        branch = str(repo.active_branch)
    except TypeError:
        branch = None
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": branch,
    }
    return repo_infos


ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]


def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
    """Score outputs against references with ``rouge_score`` for each of ``ROUGE_KEYS``.

    Pairs are scored with ``RougeScorer.score(reference, output)`` and combined with a
    ``BootstrapAggregator``. Note that ``zip`` silently stops at the shorter of the two lists.

    Returns:
        ``{rouge_key: mid F-measure * 100}``, rounded to 4 decimals.
    """
    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()

    for reference_ln, output_ln in zip(reference_lns, output_lns):
        scores = scorer.score(reference_ln, output_ln)
        aggregator.add_scores(scores)

    result = aggregator.aggregate()
    return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
341
342


343
344
345
# Utilities for freezing parameters and checking whether they are frozen


346
def freeze_params(model: nn.Module):
    """Set requires_grad=False for each of model.parameters()"""
    for par in model.parameters():
        par.requires_grad = False


def grad_status(model: nn.Module) -> Iterable:
    """Lazily yield each parameter's ``requires_grad`` flag, in ``model.parameters()`` order."""
    params = model.parameters()
    return (param.requires_grad for param in params)


def any_requires_grad(model: nn.Module) -> bool:
    """Return True as soon as one parameter of ``model`` is found to require grad."""
    for needs_grad in grad_status(model):
        if needs_grad:
            return True
    return False


def assert_all_frozen(model):
    """Fail with an AssertionError (reporting the trainable fraction) if any parameter requires grad."""
    flags: List[bool] = list(grad_status(model))
    total = len(flags)
    trainable = sum(int(flag) for flag in flags)
    assert not any(flags), f"{trainable/total:.1%} of {total} weights require grad"


def assert_not_all_frozen(model):
    """Fail with an AssertionError when no parameter of ``model`` requires grad."""
    flags: List[bool] = list(grad_status(model))
    assert any(flags), f"none of {len(flags)} weights require grad"
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390


# CLI Parsing utils


def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float]]:
    """Turn ``["--name", "value", ...]`` into ``{"name": number}``.

    Each value is parsed as int when possible, otherwise as float (which raises
    ValueError for non-numeric text). Every flag must start with ``--``.
    """
    assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
    parsed = {}
    flags = unparsed_args[0::2]
    raw_values = unparsed_args[1::2]
    for flag, raw in zip(flags, raw_values):
        assert flag.startswith("--")
        try:
            number = int(raw)
        except ValueError:
            number = float(raw)  # this can raise another informative ValueError
        parsed[flag[2:]] = number
    return parsed
391
392
393
394
395
396
397


def write_txt_file(ordered_tgt, path):
    """Write each string in ``ordered_tgt`` to ``path`` on its own line, overwriting the file.

    The file is flushed after every line, so partial output is visible while
    ``ordered_tgt`` is still being consumed. It is closed when done, including on error.
    """
    with Path(path).open("w") as f:
        for ln in ordered_tgt:
            f.write(ln + "\n")
            f.flush()