#!/usr/bin/env python
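# Fine-tunes a seq2seq model (e.g. BART, T5, mBART, FSMT) for summarization or
# translation with PyTorch Lightning, logging ROUGE/BLEU during validation.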

import argparse
import glob
import logging
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right

from utils import (
    ROUGE_KEYS,
    LegacySeq2SeqDataset,
    Seq2SeqDataset,
    assert_all_frozen,
    calculate_bleu,
    calculate_rouge,
    flatten_list,
    freeze_params,
    get_git_info,
    label_smoothed_nll_loss,
    lmap,
    pickle_save,
    save_git_info,
    use_task_specific_params,
)


# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import BaseTransformer, add_generic_args, generic_train  # noqa


logger = logging.getLogger(__name__)


class SummarizationModule(BaseTransformer):
    mode = "summarization"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS
    default_val_metric = "rouge2"

    def __init__(self, hparams, **kwargs):
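        # --sortish_sampler and --max_tokens_per_batch each supply their own batching
        # strategy, so they are mutually exclusive; with more than one GPU, sortish
        # sampling needs replace_sampler_ddp=False and dynamic batching is unsupported.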
        if hparams.sortish_sampler and hparams.gpus > 1:
            hparams.replace_sampler_ddp = False
        elif hparams.max_tokens_per_batch is not None:
            if hparams.gpus > 1:
                raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
            if hparams.sortish_sampler:
                raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")

        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)
        self.model_type = self.config.model_type
        self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None  # default to config
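        # mBART decodes from a target-language code token, so derive the decoder start
        # id from --tgt_lang when the model config does not already provide one.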
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
            self.model.config.decoder_start_token_id = self.decoder_start_token_id
        self.dataset_class = (
            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
        )
        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
        assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer >= 1"
        if self.hparams.eval_max_gen_length is not None:
            self.eval_max_length = self.hparams.eval_max_gen_length
        else:
            self.eval_max_length = self.model.config.max_length
        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        if self.model_type == "t5":
            freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                freeze_params(d.embed_tokens)
        elif self.model_type == "fsmt":
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        else:
            freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
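        # Teacher forcing: the decoder input is the label sequence shifted one position
        # to the right (T5 and BART expose different helpers for the shift).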
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs[0]
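        # Compute token-level loss against the labels, ignoring pad positions; either
        # plain cross-entropy or label-smoothed NLL depending on --label_smoothing.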
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
            )
        return (loss,)

    @property
    def pad(self) -> int:
        return self.tokenizer.pad_token_id

    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)

        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch
        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
        logs["bs"] = batch["input_ids"].shape[0]
        logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
        logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
        # TODO(SS): make a wandb summary metric for this
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
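        # Average generation metrics (ROUGE/BLEU, gen_time, gen_len) across batches and
        # select the configured validation metric used for checkpointing/early stopping.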
        generative_metrics = {
            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
        }
        metric_val = (
            generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
        )
        metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
        generative_metrics.update({k: v.item() for k, v in losses.items()})
        losses.update(generative_metrics)
        all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        all_metrics["step_count"] = self.step_count
        self.metrics[prefix].append(all_metrics)  # callback writes this to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {
            "log": all_metrics,
            "preds": preds,
            f"{prefix}_loss": loss,
            f"{prefix}_{self.val_metric}": metric_tensor,
        }

    def calc_generative_metrics(self, preds, target) -> Dict:
        return calculate_rouge(preds, target)

    def _generative_step(self, batch: dict) -> dict:
        t0 = time.time()

        # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
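        # Autoregressive generation with the beam count and max length resolved in __init__.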
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
            num_beams=self.eval_beams,
            max_length=self.eval_max_length,
        )
        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["labels"])
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = self.calc_generative_metrics(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)
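        # Three strategies: length-sorted (sortish) batches, token-budget dynamic batches,
        # or a plain fixed-size DataLoader; the test split always uses the plain loader.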

        if self.hparams.sortish_sampler and type_path != "test":
            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=False,
                num_workers=self.num_workers,
                sampler=sampler,
            )

        elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
            batch_sampler = dataset.make_dynamic_sampler(
                self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
            )
            return DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                collate_fn=dataset.collate_fn,
                # shuffle=False,
                num_workers=self.num_workers,
                # batch_size=None,
            )
        else:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=shuffle,
                num_workers=self.num_workers,
                sampler=None,
            )

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--max_tokens_per_batch", type=int, default=None)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task", type=str, default="summarization", required=False, help="Task name: summarization or translation."
        )
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)
        parser.add_argument("--eval_beams", type=int, default=None, required=False)
        parser.add_argument(
            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
        )
        parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will affect it.",
        )
        return parser


class TranslationModule(SummarizationModule):
    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)


def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        if "summarization" in args.task:
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)
    dataset = Path(args.data_dir).name
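    # Choose a logger: plain Lightning logging for default/smoke-test runs, otherwise
    # wandb (the "wandb_shared" option logs to a project prefixed with "hf_").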
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    lower_is_better = args.val_metric == "loss"
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
        ),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
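    # Record the newest *.ckpt (if any) in hparams and point the trainer at it before testing.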
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

    args = parser.parse_args()

    main(args)