"pytorch_transformers/tokenization_gpt2.py" did not exist on "870b734bfd2cc83e43b29050fba03709a0c5b539"
finetune.py 17.5 KB
Newer Older
1
2
3
4
5
import argparse
import glob
import logging
import os
import time
6
import warnings
7
from collections import defaultdict
8
9
from pathlib import Path
from typing import Dict, List, Tuple
10

11
12
import numpy as np
import pytorch_lightning as pl
13
14
15
import torch
from torch.utils.data import DataLoader

16
from lightning_base import BaseTransformer, add_generic_args, generic_train
17
18
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
19
20
21


try:
22
    from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
23
    from .utils import (
24
        ROUGE_KEYS,
25
        LegacySeq2SeqDataset,
26
        Seq2SeqDataset,
27
        assert_all_frozen,
28
        calculate_bleu,
29
        calculate_rouge,
30
31
32
        flatten_list,
        freeze_params,
        get_git_info,
33
        label_smoothed_nll_loss,
34
35
36
37
38
        lmap,
        pickle_save,
        save_git_info,
        save_json,
        use_task_specific_params,
39
    )
40
except ImportError:
41
    from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
42
    from utils import (
43
        ROUGE_KEYS,
44
        LegacySeq2SeqDataset,
45
46
        Seq2SeqDataset,
        assert_all_frozen,
47
        calculate_bleu,
48
        calculate_rouge,
49
50
51
        flatten_list,
        freeze_params,
        get_git_info,
52
        label_smoothed_nll_loss,
53
54
55
56
57
        lmap,
        pickle_save,
        save_git_info,
        save_json,
        use_task_specific_params,
58
    )
59
60
61
62

# Standard module-level logger, named after this module.
logger = logging.getLogger(__name__)


63
64
65
class SummarizationModule(BaseTransformer):
    """Seq2seq fine-tuning module: trains with (optionally label-smoothed) cross-entropy, evaluates by generating.

    Validation/test steps run ``model.generate`` and score the decoded predictions with
    ``calc_generative_metrics`` (ROUGE here; subclasses override it). The model, tokenizer and
    ``output_dir`` are presumably created by ``BaseTransformer.__init__`` (not visible here).
    """

    mode = "summarization"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS
    default_val_metric = "rouge2"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        # type_path ("val" / "test") -> list of metric dicts, one appended per eval epoch.
        self.metrics = defaultdict(list)

        # Keyword arguments shared by every split's dataset; subclasses may add entries.
        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        # A negative count (the CLI uses -1) means "use every example" and is stored as None.
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        if self.hparams.sortish_sampler and self.hparams.gpus > 1:
            # Disabled here (rather than failing) so get_dataloader never builds it on multi-GPU.
            self.hparams.sortish_sampler = False
            warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None  # default to config
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            # For mBART without a configured start token, start decoding with the target-language code.
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
            self.model.config.decoder_start_token_id = self.decoder_start_token_id
        # Tokenizers exposing prepare_seq2seq_batch get the newer dataset implementation.
        self.dataset_class = (
            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
        )
        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
        assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer >= 1"
        if self.hparams.eval_max_gen_length is not None:
            self.eval_max_length = self.hparams.eval_max_gen_length
        else:
            self.eval_max_length = self.model.config.max_length
        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        try:
            # BART-style layout: model.model.{shared, encoder, decoder}.
            freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        except AttributeError:
            # T5-style layout: shared/encoder/decoder hang directly off the model (no embed_positions).
            freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                freeze_params(d.embed_tokens)

    def forward(self, input_ids, **kwargs):
        """Delegate straight to the wrapped Hugging Face model."""
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: torch.Tensor) -> List[str]:
        """Decode a batch of token-id sequences into stripped strings, dropping special tokens."""
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        """Compute the teacher-forced loss for one batch.

        Returns:
            A 1-tuple ``(loss,)`` aligned with ``self.loss_names``.
        """
        pad_token_id = self.pad
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
        # Decoder inputs are the labels shifted right; T5 ships its own shifting helper.
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs[0]
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.model.config.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, _nll_loss = label_smoothed_nll_loss(
                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
            )
        return (loss,)

    @property
    def pad(self) -> int:
        """The tokenizer's padding token id."""
        return self.tokenizer.pad_token_id

    def training_step(self, batch, batch_idx) -> Dict:
        """Return the training loss plus logged losses and a tokens-per-batch counter."""
        loss_tensors = self._step(batch)

        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch: non-pad source tokens + non-pad label tokens
        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        """Generate and score one validation batch."""
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        """Aggregate the per-batch dicts from ``_generative_step`` into epoch-level metrics.

        Averages losses and generative metrics, appends them (keys ``{prefix}_avg_*``) to
        ``self.metrics`` and writes ``metrics.json``.

        Returns:
            dict with ``log``, flattened ``preds``, ``{prefix}_loss`` and
            ``{prefix}_{self.val_metric}`` (a tensor on the same device/dtype as the loss).
        """
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        generative_metrics = {
            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
        }
        # val_metric is either a generative metric (e.g. rouge2/bleu) or a loss name (e.g. "loss").
        metric_val = (
            generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
        )
        metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
        generative_metrics.update({k: v.item() for k, v in losses.items()})
        losses.update(generative_metrics)
        all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        all_metrics["step_count"] = self.step_count
        self.save_metrics(all_metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {
            "log": all_metrics,
            "preds": preds,
            f"{prefix}_loss": loss,
            f"{prefix}_{self.val_metric}": metric_tensor,
        }

    def save_metrics(self, latest_metrics, type_path) -> None:
        """Record ``latest_metrics`` under ``type_path`` and rewrite the full history to disk."""
        self.metrics[type_path].append(latest_metrics)
        save_json(self.metrics, self.metrics_save_path)

    def calc_generative_metrics(self, preds, target) -> Dict:
        """Score decoded predictions against decoded references (ROUGE for summarization)."""
        return calculate_rouge(preds, target)

    def _generative_step(self, batch: dict) -> dict:
        """Generate for one batch; return losses, timing/length stats, texts and metric scores."""
        t0 = time.time()
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
            num_beams=self.eval_beams,
            max_length=self.eval_max_length,
        )
        # Average wall-clock generation time per example in the batch.
        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["labels"])
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        gen_metrics: Dict = self.calc_generative_metrics(preds, target)
        # NOTE(review): if generate() returns a padded 2-D tensor, len() of every row is the padded
        # width, so gen_len reflects the longest output in the batch rather than the mean — confirm intent.
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **gen_metrics)
        return base_metrics

    def test_step(self, batch, batch_idx):
        """Generate and score one test batch."""
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        """Aggregate test outputs exactly like validation, under the "test" prefix."""
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        """Build the dataset for split ``type_path`` ("train", "val" or "test")."""
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        """Build a DataLoader for ``type_path``, using the sortish sampler for training when enabled."""
        dataset = self.get_dataset(type_path)
        sampler = None
        if self.hparams.sortish_sampler and type_path == "train":
            # __init__ switches sortish_sampler off when gpus > 1, so this should always hold.
            assert self.hparams.gpus <= 1
            sampler = dataset.make_sortish_sampler(batch_size)
            shuffle = False  # DataLoader does not allow shuffle=True together with a custom sampler.

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
            sampler=sampler,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Shuffled training loader (unless the sortish sampler takes over ordering)."""
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        return dataloader

    def val_dataloader(self) -> DataLoader:
        """Unshuffled validation loader."""
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        """Unshuffled test loader."""
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Register base, generic and seq2seq-specific CLI arguments on ``parser`` and return it."""
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length for training after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target sequence length for validation after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target sequence length for testing after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task",
            type=str,
            default="summarization",
            required=False,
            help="Task name. Values containing 'summarization' train a summarization model; "
            "anything else trains a translation model.",
        )
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)
        parser.add_argument("--eval_beams", type=int, default=None, required=False)
        parser.add_argument(
            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
        )
        parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will affect it.",
        )
        return parser


345
346
347
348
class TranslationModule(SummarizationModule):
    """Translation flavour of SummarizationModule: BLEU scoring and language-aware datasets."""

    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        # Extend the shared dataset kwargs with the language pair; get_dataset reads these
        # when each split's dataset is built.
        self.dataset_kwargs.update(src_lang=hparams.src_lang, tgt_lang=hparams.tgt_lang)

    def calc_generative_metrics(self, preds, target) -> dict:
        """Score decoded translations against decoded references with BLEU."""
        bleu_scores = calculate_bleu(preds, target)
        return bleu_scores
358
359


360
361
362
363
364
def main(args, model=None) -> SummarizationModule:
    """Train (and optionally test) a seq2seq module configured from parsed CLI ``args``.

    Args:
        args: argparse namespace built by the ``__main__`` parser.
        model: optional pre-built module. When None, a SummarizationModule is built if
            ``"summarization"`` occurs in ``args.task``, otherwise a TranslationModule.

    Returns:
        The (trained) module.

    Raises:
        ValueError: if ``args.do_train`` is set and ``args.output_dir`` already holds more than
            3 entries, or if ``args.logger_name`` is not a supported logger name.
    """
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        if "summarization" in args.task:
            model = SummarizationModule(args)
        else:
            model = TranslationModule(args)

    dataset = Path(args.data_dir).name
    # Named training_logger so it does not shadow the module-level `logger`.
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith(("/tmp", "/var"))
    ):
        training_logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        training_logger = WandbLogger(name=model.output_dir.name, project=project)
    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        training_logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(f"Unsupported logger_name: {args.logger_name!r}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    # Derive the direction from the metric the checkpoint callback actually monitors
    # (model.val_metric), which also covers a pre-built model or a defaulted --val_metric.
    lower_is_better = model.val_metric == "loss"
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
        ),
        early_stopping_callback=es_callback,
        logger=training_logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    # Lexicographically sorted; the last entry is treated as the checkpoint to resume/test from.
    checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt")))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
418
419
420
421


if __name__ == "__main__":
    # CLI is layered: Lightning Trainer flags first, then the model/data flags contributed
    # by SummarizationModule (which also registers the base and generic flags).
    parser = SummarizationModule.add_model_specific_args(
        pl.Trainer.add_argparse_args(argparse.ArgumentParser()), os.getcwd()
    )
    args = parser.parse_args()
    main(args)