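"""Finetune a Hugging Face seq2seq model (BART, T5, mBART, FSMT, ...) with PyTorch Lightning.

SummarizationModule handles data loading, teacher-forced training with optional label
smoothing, and generation-based validation scored with ROUGE; TranslationModule subclasses
it to evaluate with BLEU instead.
"""
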
import argparse
import glob
import logging
import os
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right

from utils import (
    ROUGE_KEYS,
    LegacySeq2SeqDataset,
    Seq2SeqDataset,
    assert_all_frozen,
    calculate_bleu,
    calculate_rouge,
    flatten_list,
    freeze_params,
    get_git_info,
    label_smoothed_nll_loss,
    lmap,
    pickle_save,
    save_git_info,
    use_task_specific_params,
)


logger = logging.getLogger(__name__)


class SummarizationModule(BaseTransformer):
    mode = "summarization"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS
    default_val_metric = "rouge2"

    def __init__(self, hparams, **kwargs):
        if hparams.sortish_sampler and hparams.gpus > 1:
            hparams.replace_sampler_ddp = False
        elif hparams.max_tokens_per_batch is not None:
            if hparams.gpus > 1:
                raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
            if hparams.sortish_sampler:
                raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")

        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)
        self.model_type = self.config.model_type
        self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
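        # A split size of -1 from the CLI means "use every example": store None so the dataset is not truncated.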
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None  # default to config
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            # mBART starts generation with the target-language code rather than the usual BOS token.
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
            self.model.config.decoder_start_token_id = self.decoder_start_token_id
        self.dataset_class = (
            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
        )
        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
        assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer >= 1"
        if self.hparams.eval_max_gen_length is not None:
            self.eval_max_length = self.hparams.eval_max_gen_length
        else:
            self.eval_max_length = self.model.config.max_length
        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        if self.model_type == "t5":
            freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                freeze_params(d.embed_tokens)
        elif self.model_type == "fsmt":
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        else:
            freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs[0]
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
            )
        return (loss,)

    @property
    def pad(self) -> int:
        return self.tokenizer.pad_token_id

    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)

        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch
        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
        logs["bs"] = batch["input_ids"].shape[0]
        logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
        logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
        # TODO(SS): make a wandb summary metric for this
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)


    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        generative_metrics = {
            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
        }
        metric_val = (
            generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
        )
        metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
        generative_metrics.update({k: v.item() for k, v in losses.items()})
        losses.update(generative_metrics)
        all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        all_metrics["step_count"] = self.step_count
        self.metrics[prefix].append(all_metrics)  # callback writes this to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {
            "log": all_metrics,
            "preds": preds,
            f"{prefix}_loss": loss,
            f"{prefix}_{self.val_metric}": metric_tensor,
        }

    def calc_generative_metrics(self, preds, target) -> Dict:
        return calculate_rouge(preds, target)

    def _generative_step(self, batch: dict) -> dict:
        t0 = time.time()

        # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
            num_beams=self.eval_beams,
            max_length=self.eval_max_length,
        )
        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["labels"])
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = self.calc_generative_metrics(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)

        if self.hparams.sortish_sampler and type_path != "test":
            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=False,
                num_workers=self.num_workers,
                sampler=sampler,
            )

        elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
            batch_sampler = dataset.make_dynamic_sampler(
                self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
            )
            return DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                collate_fn=dataset.collate_fn,
                # shuffle=False,
                num_workers=self.num_workers,
                # batch_size=None,
            )
        else:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=shuffle,
                num_workers=self.num_workers,
                sampler=None,
            )

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--max_tokens_per_batch", type=int, default=None)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task",
            type=str,
            default="summarization",
            required=False,
            help="Finetuning task; a value containing 'summarization' selects SummarizationModule, anything else selects TranslationModule.",
        )
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)
        parser.add_argument("--eval_beams", type=int, default=None, required=False)
        parser.add_argument(
            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
        )
        parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs, so val_check_interval will affect it.",
        )
        return parser


class TranslationModule(SummarizationModule):
    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)


def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        if "summarization" in args.task:
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)
    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    lower_is_better = args.val_metric == "loss"
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
        ),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

    args = parser.parse_args()

    main(args)
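

# Example invocation (illustrative only; --model_name_or_path, --data_dir, --output_dir,
# --do_train/--do_predict, and the learning-rate/batch-size flags are defined in
# lightning_base.py, and --gpus/--val_check_interval come from pl.Trainer, so names and
# defaults may differ between versions):
#
#   python finetune.py \
#       --model_name_or_path facebook/bart-large \
#       --data_dir ./cnn_dm \
#       --output_dir ./bart_cnn_finetuned \
#       --do_train --do_predict \
#       --gpus 1 \
#       --learning_rate 3e-5 \
#       --train_batch_size 4 --eval_batch_size 4 \
#       --val_check_interval 0.25 --n_val 500 \
#       --max_target_length 56 --val_max_target_length 142 --test_max_target_length 142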