"examples/vscode:/vscode.git/clone" did not exist on "4302ace5bd6a0dba6be90e580b4718e270384bb0"
utils.py 16 KB
Newer Older
1
2
import itertools
import json
import linecache
import math
import os
import pickle
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Union

import git
import numpy as np
import torch
import torch.distributed as dist
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler

from transformers import BartTokenizer
from transformers.file_utils import cached_property


try:
    from fairseq.data.data_utils import batch_by_size

    FAIRSEQ_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    FAIRSEQ_AVAILABLE = False


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """From fairseq"""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)

    nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss


def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
    """Only used by LegacyDataset"""
    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
    return tokenizer(
        [line],
        max_length=max_length,
        padding="max_length" if pad_to_max_length else None,
        truncation=True,
        return_tensors=return_tensors,
        **extra_kw,
    )


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
    """Uses sacrebleu's corpus_bleu implementation."""
    return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}


def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])


class AbstractSeq2SeqDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        src_lang=None,
        tgt_lang=None,
        prefix="",
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.len_file = Path(data_dir).joinpath(type_path + ".len")
        if os.path.exists(self.len_file):
            self.src_lens = pickle_load(self.len_file)
            self.used_char_len = False
        else:
            self.src_lens = self.get_char_lens(self.src_file)
            self.used_char_len = True
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix if prefix is not None else ""

        if n_obs is not None:
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)

    def __len__(self):
        return len(self.src_lens)

    @staticmethod
    def get_char_lens(data_file):
        return [len(x) for x in Path(data_file).open().readlines()]

    @cached_property
    def tgt_lens(self):
        """Length in characters of target documents"""
        return self.get_char_lens(self.tgt_file)

    def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
        if distributed:
            return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
        else:
            return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)

    def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
        assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
        assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler"
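        # Token-based batching needs real token counts, so src_lens must come from the
        # precomputed *.len file rather than the character-length fallback.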
        sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))

        def num_tokens_in_example(i):
            return min(self.src_lens[i], self.max_target_length)

        # call fairseq cython function
        batch_sampler: List[List[int]] = batch_by_size(
            sorted_indices,
            num_tokens_fn=num_tokens_in_example,
            max_tokens=max_tokens_per_batch,
            required_batch_size_multiple=64,
        )
        shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
        # move the largest batch to the front to OOM quickly (uses an approximation for padding)
        approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
        largest_batch_idx = np.argmax(approximate_toks_per_batch)
        shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
            shuffled_batches[largest_batch_idx],
            shuffled_batches[0],
        )
        return shuffled_batches

    def __getitem__(self, item):
        raise NotImplementedError("You must implement this")

    def collate_fn(self, batch):
        raise NotImplementedError("You must implement this")


class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        """Call tokenizer on src and tgt_lines"""
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)

        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "labels": target_ids,
        }

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["labels"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "labels": y,
        }
        return batch


class Seq2SeqDataset(AbstractSeq2SeqDataset):
    """A dataset that calls prepare_seq2seq_batch."""

    def __getitem__(self, index) -> Dict[str, str]:
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Call prepare_seq2seq_batch."""
        batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            src_lang=self.src_lang,
            tgt_texts=[x["tgt_texts"] for x in batch],
            tgt_lang=self.tgt_lang,
            max_length=self.max_source_length,
            max_target_length=self.max_target_length,
            return_tensors="pt",
            add_prefix_space=self.add_prefix_space,
        ).data
        batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
        return batch_encoding


class SortishSampler(Sampler):
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."

    def __init__(self, data, batch_size, shuffle=True):
        self.data, self.bs, self.shuffle = data, batch_size, shuffle

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))


def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."
    if not shuffle:
        return np.argsort(np.array(data) * -1)

    def key_fn(i):
        return data[i]

    idxs = np.random.permutation(len(data))
    sz = bs * 50
    ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
    sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
    sz = bs
    ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
    max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
    ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
    sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=int)
    sort_idx = np.concatenate((ck_idx[0], sort_idx))
    return sort_idx


class DistributedSortishSampler(Sampler):
    """Copied from torch DistributedSampler"""

    def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        if add_extra_examples:
            self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = len(dataset)
            self.num_samples = len(self.available_indices)
        self.batch_size = batch_size
        self.add_extra_examples = add_extra_examples
        self.shuffle = shuffle

    def __iter__(self) -> Iterable:
        g = torch.Generator()
        g.manual_seed(self.epoch)

        sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
        indices = [self.available_indices[i] for i in sortish_indices]
        assert len(indices) == self.num_samples
        return iter(indices)

    @cached_property
    def available_indices(self) -> np.array:
        indices = list(range(len(self.dataset)))
        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size
        # subsample
        available_indices = indices[self.rank : self.total_size : self.num_replicas]
        return available_indices

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


logger = getLogger(__name__)


def use_task_specific_params(model, task):
    """Update config with summarization specific params."""
    task_specific_params = model.config.task_specific_params

    if task_specific_params is not None:
        pars = task_specific_params.get(task, {})
        logger.info(f"using task specific params for {task}: {pars}")
        model.config.update(pars)


def pickle_load(path):
    """pickle.load(path)"""
    with open(path, "rb") as f:
        return pickle.load(f)


def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)


def flatten_list(summary_ids: List[List]):
    return [x for x in itertools.chain.from_iterable(summary_ids)]


def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json"""
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))


def save_json(content, path, indent=4, **json_dump_kwargs):
    with open(path, "w") as f:
        json.dump(content, f, indent=indent, **json_dump_kwargs)


def load_json(path):
    with open(path) as f:
        return json.load(f)


def get_git_info():
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": str(repo.active_branch),
    }
    return repo_infos


ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]


def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()

    for reference_ln, output_ln in zip(reference_lns, output_lns):
        scores = scorer.score(reference_ln, output_ln)
        aggregator.add_scores(scores)

    result = aggregator.aggregate()
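    # aggregate() returns low/mid/high bootstrap confidence intervals per ROUGE key;
    # report the mid F-measure, scaled to 0-100.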
    return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}


# Utilities for freezing parameters and checking whether they are frozen


def freeze_params(model: nn.Module):
    """Set requires_grad=False for each of model.parameters()"""
    for par in model.parameters():
        par.requires_grad = False


def grad_status(model: nn.Module) -> Iterable:
    return (par.requires_grad for par in model.parameters())


def any_requires_grad(model: nn.Module) -> bool:
    return any(grad_status(model))


def assert_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    n_require_grad = sum(lmap(int, model_grads))
    npars = len(model_grads)
    assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"


def assert_not_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"


# CLI Parsing utils


def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
    """
    Parse an argv list of unspecified command line args to a dict.
    Assumes all values are either numeric or boolean in the form of true/false.
    """
    result = {}
    assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
    num_pairs = len(unparsed_args) // 2
    for pair_num in range(num_pairs):
        i = 2 * pair_num
        assert unparsed_args[i].startswith("--")
        if unparsed_args[i + 1].lower() == "true":
            value = True
        elif unparsed_args[i + 1].lower() == "false":
            value = False
        else:
            try:
                value = int(unparsed_args[i + 1])
            except ValueError:
                value = float(unparsed_args[i + 1])  # this can raise another informative ValueError

        result[unparsed_args[i][2:]] = value
    return result


def write_txt_file(ordered_tgt, path):
    with Path(path).open("w") as f:
        for ln in ordered_tgt:
            f.write(ln + "\n")
            f.flush()