#!/usr/bin/env python
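"""Fine-tune a pretrained sequence-to-sequence model (e.g. BART, T5, mBART) for summarization
or translation with PyTorch Lightning.

Illustrative invocation (a sketch only: --data_dir, --output_dir and --do_predict are assumed to
be registered by lightning_base.add_generic_args; the task-specific flags are defined in
SummarizationModule.add_model_specific_args below):

    python finetune.py --data_dir=$DATA_DIR --output_dir=$OUTPUT_DIR --do_predict --task summarization
"""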

import argparse
import glob
import logging
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from torch import nn
from torch.utils.data import DataLoader

from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.models.bart.modeling_bart import shift_tokens_right
from utils import (
    ROUGE_KEYS,
    LegacySeq2SeqDataset,
    Seq2SeqDataset,
    assert_all_frozen,
    calculate_bleu,
    calculate_rouge,
    check_output_dir,
    flatten_list,
    freeze_embeds,
    freeze_params,
    get_git_info,
    label_smoothed_nll_loss,
    lmap,
    pickle_save,
    save_git_info,
    save_json,
    use_task_specific_params,
)


# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import BaseTransformer, add_generic_args, generic_train  # noqa


logger = logging.getLogger(__name__)


class SummarizationModule(BaseTransformer):
    mode = "summarization"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS
    default_val_metric = "rouge2"

    def __init__(self, hparams, **kwargs):
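        # A custom sortish sampler conflicts with Lightning's automatic DDP sampler replacement,
        # and token-count-based dynamic batching is only supported on a single GPU.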
        if hparams.sortish_sampler and hparams.gpus > 1:
            hparams.replace_sampler_ddp = False
        elif hparams.max_tokens_per_batch is not None:
            if hparams.gpus > 1:
                raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
            if hparams.sortish_sampler:
                raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")

        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)
        self.model_type = self.config.model_type
        self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size

        self.dataset_kwargs: dict = {
            "data_dir": self.hparams.data_dir,
            "max_source_length": self.hparams.max_source_length,
            "prefix": self.model.config.prefix or "",
        }
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
        if self.hparams.freeze_embeds:
            freeze_embeds(self.model)
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None  # default to config
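        # mBART expects the target-language code as the first decoder token, so set it explicitly
        # when the config does not already provide a decoder_start_token_id.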
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
            self.model.config.decoder_start_token_id = self.decoder_start_token_id
        self.dataset_class = (
            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
        )
        self.already_saved_batch = False
        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
        if self.hparams.eval_max_gen_length is not None:
            self.eval_max_length = self.hparams.eval_max_gen_length
        else:
            self.eval_max_length = self.model.config.max_length
        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

    def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
        """A debugging utility"""
        readable_batch = {
            k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items()
        }
        save_json(readable_batch, Path(self.output_dir) / "text_batch.json")
        save_json({k: v.tolist() for k, v in batch.items()}, Path(self.output_dir) / "tok_batch.json")

        self.already_saved_batch = True
        return readable_batch

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
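        # Teacher forcing: the decoder is fed the labels shifted one position to the right;
        # T5 and BART-style models expose different helpers for the shift.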
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
        if not self.already_saved_batch:  # This would be slightly better if it only happened on rank zero
            batch["decoder_input_ids"] = decoder_input_ids
            self.save_readable_batch(batch)

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs["logits"]
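        # Loss: plain token-level cross-entropy that ignores pad positions, or the label-smoothed
        # NLL variant when --label_smoothing > 0.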
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        else:
            lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
            )
        return (loss,)

    @property
    def pad(self) -> int:
        return self.tokenizer.pad_token_id

    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)

        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch
        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
        logs["bs"] = batch["input_ids"].shape[0]
        logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
        logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
        # TODO(SS): make a wandb summary metric for this
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
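        # Average ROUGE/BLEU, generation time and generation length across validation batches, then
        # pick the monitored metric (self.val_metric) from either the generative metrics or the losses.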
        generative_metrics = {
            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
        }
        metric_val = (
            generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
        )
        metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
        generative_metrics.update({k: v.item() for k, v in losses.items()})
        losses.update(generative_metrics)
        all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        all_metrics["step_count"] = self.step_count
        self.metrics[prefix].append(all_metrics)  # callback writes this to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {
            "log": all_metrics,
            "preds": preds,
            f"{prefix}_loss": loss,
            f"{prefix}_{self.val_metric}": metric_tensor,
        }

    def calc_generative_metrics(self, preds, target) -> Dict:
        return calculate_rouge(preds, target)

    def _generative_step(self, batch: dict) -> dict:
        t0 = time.time()

        # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
            num_beams=self.eval_beams,
            max_length=self.eval_max_length,
        )
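        # Generation time is reported per example (total wall-clock time divided by batch size).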
        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["labels"])
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = self.calc_generative_metrics(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)
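
        # Three loader flavours for training: a sortish sampler (batches sources of similar length
        # to reduce padding), a dynamic sampler capped by --max_tokens_per_batch, or a plain
        # (optionally shuffled) DataLoader. The val and test splits always use the plain loader.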
        if self.hparams.sortish_sampler and type_path != "test" and type_path != "val":
            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=False,
                num_workers=self.num_workers,
                sampler=sampler,
            )

        elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val":
            batch_sampler = dataset.make_dynamic_sampler(
                self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
            )
            return DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                collate_fn=dataset.collate_fn,
                # shuffle=False,
                num_workers=self.num_workers,
                # batch_size=None,
            )
        else:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=shuffle,
                num_workers=self.num_workers,
                sampler=None,
            )

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help=(
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help=(
                "The maximum total target sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help=(
                "The maximum total target sequence length after tokenization for validation. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help=(
                "The maximum total target sequence length after tokenization for the test set. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            ),
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--overwrite_output_dir", action="store_true", default=False)
        parser.add_argument("--max_tokens_per_batch", type=int, default=None)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task", type=str, default="summarization", required=False, help="Task name: summarization or translation."
        )
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)
        parser.add_argument("--eval_beams", type=int, default=None, required=False)
        parser.add_argument(
            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
        )
        parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help=(
                "-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs, so"
                " val_check_interval will affect it."
            ),
        )
        return parser


class TranslationModule(SummarizationModule):
    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)


def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    check_output_dir(args, expected_items=3)

    if model is None:
        if "summarization" in args.task:
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)
    dataset = Path(args.data_dir).name
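    # Quick or debug runs fall back to Lightning's default logger (logger=True) so real wandb
    # projects are not polluted.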
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    lower_is_better = args.val_metric == "loss"
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
        ),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
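    # Resume from the most recent (lexicographically last) checkpoint, if any, before testing.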
    checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

    args = parser.parse_args()

    main(args)