"...git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "a96238547dcde03ec2ed5986de2ac1c94887ce08"
utils.py 16.2 KB
Newer Older
1
2
import itertools
import json
import linecache
import math
import os
import pickle
import socket
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Union

import git
import numpy as np
import torch
import torch.distributed as dist
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler

from transformers import BartTokenizer
from transformers.file_utils import cached_property


try:
    from fairseq.data.data_utils import batch_by_size

    FAIRSEQ_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    FAIRSEQ_AVAILABLE = False


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """From fairseq"""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)

    nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss
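

# Illustrative sketch (not part of the original pipeline): how the label-smoothed loss
# is typically applied to a model's log-probabilities. The shapes, the pad id, and the
# helper name `_demo_label_smoothed_loss` are invented for demonstration; a non-negative
# pad id is used because `gather` cannot take negative indices.
def _demo_label_smoothed_loss():
    batch_size, seq_len, vocab_size, pad_id = 2, 5, 11, 0
    lprobs = torch.log_softmax(torch.randn(batch_size, seq_len, vocab_size), dim=-1)
    target = torch.randint(1, vocab_size, (batch_size, seq_len))
    target[0, -1] = pad_id  # a padded position, masked out of both loss terms
    loss, nll_loss = label_smoothed_nll_loss(lprobs, target, epsilon=0.1, ignore_index=pad_id)
    return loss, nll_loss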


def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
    """Only used by LegacySeq2SeqDataset"""
    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
    return tokenizer(
        [line],
        max_length=max_length,
        padding="max_length" if pad_to_max_length else None,
        truncation=True,
        return_tensors=return_tensors,
        **extra_kw,
    )


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
    """Uses sacrebleu's corpus_bleu implementation."""
    return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}


def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
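

# Illustrative sketch of what trim_batch does: columns containing only the pad id are
# dropped from both the ids and the attention mask. The tensors and pad id below are
# made up.
def _demo_trim_batch():
    pad = 0
    ids = torch.tensor([[5, 6, pad, pad], [7, pad, pad, pad]])
    mask = ids.ne(pad).long()
    trimmed_ids, trimmed_mask = trim_batch(ids, pad, attention_mask=mask)
    return trimmed_ids.shape  # torch.Size([2, 2]); the two all-pad columns are gone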


class AbstractSeq2SeqDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        src_lang=None,
        tgt_lang=None,
        prefix="",
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.len_file = Path(data_dir).joinpath(type_path + ".len")
        if os.path.exists(self.len_file):
            self.src_lens = pickle_load(self.len_file)
            self.used_char_len = False
        else:
            self.src_lens = self.get_char_lens(self.src_file)
            self.used_char_len = True
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix if prefix is not None else ""

        if n_obs is not None:
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)

    def __len__(self):
        return len(self.src_lens)

    @staticmethod
    def get_char_lens(data_file):
        return [len(x) for x in Path(data_file).open().readlines()]

    @cached_property
    def tgt_lens(self):
        """Length in characters of target documents"""
        return self.get_char_lens(self.tgt_file)

    def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
        if distributed:
            return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
        else:
            return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)

    def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
        assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
        assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler"
        sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))

        def num_tokens_in_example(i):
            return min(self.src_lens[i], self.max_target_length)

        # call fairseq cython function
        batch_sampler: List[List[int]] = batch_by_size(
            sorted_indices,
            num_tokens_fn=num_tokens_in_example,
            max_tokens=max_tokens_per_batch,
            required_batch_size_multiple=64,
        )
        shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
        # move the largest batch to the front to OOM quickly (uses an approximation for padding)
        approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
        largest_batch_idx = np.argmax(approximate_toks_per_batch)
        shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
            shuffled_batches[largest_batch_idx],
            shuffled_batches[0],
        )
        return shuffled_batches

    def __getitem__(self, item):
        raise NotImplementedError("You must implement this")

    def collate_fn(self, batch):
        raise NotImplementedError("You must implement this")


class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        """Call tokenizer on src and tgt_lines"""
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)

        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "labels": target_ids,
        }

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["labels"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "labels": y,
        }
        return batch


class Seq2SeqDataset(AbstractSeq2SeqDataset):
    """A dataset that calls prepare_seq2seq_batch."""

    def __getitem__(self, index) -> Dict[str, str]:
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Call prepare_seq2seq_batch."""
        batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            src_lang=self.src_lang,
            tgt_texts=[x["tgt_texts"] for x in batch],
            tgt_lang=self.tgt_lang,
            max_length=self.max_source_length,
            max_target_length=self.max_target_length,
            return_tensors="pt",
            add_prefix_space=self.add_prefix_space,
        ).data
        batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
        return batch_encoding
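

# Sketch of one way a Seq2SeqDataset could be wired to a DataLoader with the sortish
# sampler defined below. The data_dir and model name are placeholders: data_dir is
# assumed to contain val.source / val.target files, and the tokenizer must support
# prepare_seq2seq_batch.
def _demo_dataloader_wiring(data_dir="path/to/data", model_name="facebook/bart-base"):
    from torch.utils.data import DataLoader

    tokenizer = BartTokenizer.from_pretrained(model_name)
    ds = Seq2SeqDataset(tokenizer, data_dir, max_source_length=512, max_target_length=56, type_path="val")
    sampler = ds.make_sortish_sampler(batch_size=8, distributed=False)
    return DataLoader(ds, batch_size=8, sampler=sampler, collate_fn=ds.collate_fn)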


class SortishSampler(Sampler):
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."

    def __init__(self, data, batch_size, shuffle=True):
        self.data, self.bs, self.shuffle = data, batch_size, shuffle

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))


def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."
    if not shuffle:
        return np.argsort(np.array(data) * -1)

    def key_fn(i):
        return data[i]

    idxs = np.random.permutation(len(data))
    sz = bs * 50
    ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
    sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
    sz = bs
    ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
    max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
    ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
    sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=int)
    sort_idx = np.concatenate((ck_idx[0], sort_idx))
    return sort_idx
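

# Small sketch of the sortish behaviour (the lengths are invented): with shuffle=False the
# indices simply sort the data from longest to shortest; with shuffle=True the data is
# sorted by length within large chunks and then re-chunked into batches that are shuffled,
# so neighbouring examples still have similar lengths and padding stays small.
def _demo_sortish_indices():
    lengths = [3, 10, 1, 7, 8, 2, 5, 9]
    order = sortish_sampler_indices(lengths, bs=2, shuffle=False)
    return [lengths[i] for i in order]  # -> [10, 9, 8, 7, 5, 3, 2, 1]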


class DistributedSortishSampler(Sampler):
    """Copied from torch DistributedSampler"""

    def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        if add_extra_examples:
            self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = len(dataset)
            self.num_samples = len(self.available_indices)
        self.batch_size = batch_size
        self.add_extra_examples = add_extra_examples
        self.shuffle = shuffle

    def __iter__(self) -> Iterable:
        g = torch.Generator()
        g.manual_seed(self.epoch)

        sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
        indices = [self.available_indices[i] for i in sortish_indices]
        assert len(indices) == self.num_samples
        return iter(indices)

    @cached_property
    def available_indices(self) -> np.array:
        indices = list(range(len(self.dataset)))
        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size
        # subsample
        available_indices = indices[self.rank : self.total_size : self.num_replicas]
        return available_indices

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
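

# Sketch of the rank sharding done by available_indices (everything below is invented;
# passing num_replicas and rank explicitly avoids needing torch.distributed to be
# initialised). Each replica sees an equally sized, interleaved slice of the dataset,
# padded with repeated examples so the shards divide evenly.
def _demo_distributed_shards():
    class _TinyDataset:
        src_lens = [5, 3, 9, 2, 7, 4, 6]

        def __len__(self):
            return len(self.src_lens)

    ds = _TinyDataset()
    shard0 = DistributedSortishSampler(ds, batch_size=2, num_replicas=2, rank=0, shuffle=False)
    shard1 = DistributedSortishSampler(ds, batch_size=2, num_replicas=2, rank=1, shuffle=False)
    return shard0.available_indices, shard1.available_indices  # [0, 2, 4, 6] and [1, 3, 5, 0]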


logger = getLogger(__name__)


def use_task_specific_params(model, task):
    """Update config with summarization specific params."""
    task_specific_params = model.config.task_specific_params

    if task_specific_params is not None:
        pars = task_specific_params.get(task, {})
        logger.info(f"using task specific params for {task}: {pars}")
        model.config.update(pars)
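

# Sketch of the effect of use_task_specific_params (the config values are invented and the
# bare namespace stands in for a real model): entries stored under
# config.task_specific_params["summarization"] are promoted to top-level config attributes.
def _demo_use_task_specific_params():
    from types import SimpleNamespace

    from transformers import BartConfig

    config = BartConfig(task_specific_params={"summarization": {"max_length": 142, "min_length": 56}})
    model = SimpleNamespace(config=config)
    use_task_specific_params(model, "summarization")
    return config.max_length, config.min_length  # -> (142, 56)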


def pickle_load(path):
    """pickle.load(path)"""
    with open(path, "rb") as f:
        return pickle.load(f)


def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)


def flatten_list(summary_ids: List[List]):
    return [x for x in itertools.chain.from_iterable(summary_ids)]


def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json"""
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))


def save_json(content, path, indent=4, **json_dump_kwargs):
    with open(path, "w") as f:
        json.dump(content, f, indent=indent, **json_dump_kwargs)


def load_json(path):
    with open(path) as f:
        return json.load(f)


def get_git_info():
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": str(repo.active_branch),
        "hostname": str(socket.gethostname()),
    }
    return repo_infos


ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]


def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()

    for reference_ln, output_ln in zip(reference_lns, output_lns):
        scores = scorer.score(reference_ln, output_ln)
        aggregator.add_scores(scores)

    result = aggregator.aggregate()
    return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
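

# Usage sketch (sentences invented): returns the mid f-measures, as percentages, for each ROUGE key.
def _demo_calculate_rouge():
    preds = ["the cat sat on the mat"]
    refs = ["a cat sat on the mat"]
    return calculate_rouge(preds, refs)  # e.g. {"rouge1": ..., "rouge2": ..., "rougeL": ...}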


# Utilities for freezing parameters and checking whether they are frozen


def freeze_params(model: nn.Module):
    """Set requires_grad=False for each of model.parameters()"""
    for par in model.parameters():
        par.requires_grad = False


def grad_status(model: nn.Module) -> Iterable:
    return (par.requires_grad for par in model.parameters())


def any_requires_grad(model: nn.Module) -> bool:
    return any(grad_status(model))


def assert_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    n_require_grad = sum(lmap(int, model_grads))
    npars = len(model_grads)
    assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"


def assert_not_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"
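

# Illustrative sketch of the freezing helpers on a toy model (the module below is made up):
# freeze only the embedding, then check that the rest of the model still trains.
def _demo_freeze_helpers():
    model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2))
    freeze_params(model[0])
    assert_all_frozen(model[0])   # every embedding weight now has requires_grad=False
    assert_not_all_frozen(model)  # the linear layer still requires grad
    return any_requires_grad(model)  # -> True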


# CLI Parsing utils


def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
    """
    Parse an argv list of unspecified command line args to a dict.
    Assumes all values are either numeric or boolean in the form of true/false.
    """
    result = {}
    assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
    num_pairs = len(unparsed_args) // 2
    for pair_num in range(num_pairs):
        i = 2 * pair_num
        assert unparsed_args[i].startswith("--")
        if unparsed_args[i + 1].lower() == "true":
            value = True
        elif unparsed_args[i + 1].lower() == "false":
            value = False
        else:
            try:
                value = int(unparsed_args[i + 1])
            except ValueError:
                value = float(unparsed_args[i + 1])  # this can raise another informative ValueError

        result[unparsed_args[i][2:]] = value
    return result
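

# Usage sketch (flag names invented): ints, floats and true/false strings are coerced,
# anything else raises a ValueError.
def _demo_parse_cl_kwargs():
    argv = ["--num_beams", "4", "--length_penalty", "0.8", "--early_stopping", "true"]
    return parse_numeric_n_bool_cl_kwargs(argv)
    # -> {"num_beams": 4, "length_penalty": 0.8, "early_stopping": True}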


def write_txt_file(ordered_tgt, path):
    with Path(path).open("w") as f:
        for ln in ordered_tgt:
            f.write(ln + "\n")
            f.flush()


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]
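

# For example, list(chunks(list(range(7)), 3)) yields [[0, 1, 2], [3, 4, 5], [6]].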