import argparse
import glob
import logging
import os
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import (
    ROUGE_KEYS,
    LegacySeq2SeqDataset,
    Seq2SeqDataset,
    assert_all_frozen,
    calculate_bleu,
    calculate_rouge,
    flatten_list,
    freeze_params,
    get_git_info,
    label_smoothed_nll_loss,
    lmap,
    pickle_save,
    save_git_info,
    save_json,
    use_task_specific_params,
)


logger = logging.getLogger(__name__)


class SummarizationModule(BaseTransformer):
    mode = "summarization"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS
    default_val_metric = "rouge2"

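    # Constructor notes: when sortish_sampler is used with more than one GPU, Lightning's default
    # DistributedSampler replacement is disabled so the dataset's own distributed sortish sampler
    # can be used; token-budget ("dynamic") batching is only supported on a single GPU.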
    def __init__(self, hparams, **kwargs):
        if hparams.sortish_sampler and hparams.gpus > 1:
            hparams.replace_sampler_ddp = False
        elif hparams.max_tokens_per_batch is not None:
            if hparams.gpus > 1:
                raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
            if hparams.sortish_sampler:
                raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")

        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)
        self.model_type = self.config.model_type
        self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
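        # Per-split example caps (negative values mean "use the whole split") and per-split
        # maximum target lengths used when instantiating the datasets below.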
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None  # default to config
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
            self.model.config.decoder_start_token_id = self.decoder_start_token_id
        self.dataset_class = (
            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
        )
        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
        assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer >= 1"
        if self.hparams.eval_max_gen_length is not None:
            self.eval_max_length = self.hparams.eval_max_gen_length
        else:
            self.eval_max_length = self.model.config.max_length
        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        if self.model_type == "t5":
            freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                freeze_params(d.embed_tokens)
        elif self.model_type == "fsmt":
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        else:
            freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

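    # _step computes the training/validation loss for one batch: decoder inputs are the labels
    # shifted right (teacher forcing), and the loss is either padding-ignoring cross-entropy or
    # label-smoothed NLL depending on --label_smoothing.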
    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs[0]
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
            )
        return (loss,)

    @property
    def pad(self) -> int:
        return self.tokenizer.pad_token_id

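    # training_step returns the loss and logs simple batch statistics: tokens per batch, batch
    # size, and source-side padding counts/fractions.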
    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)

        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch
        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
        logs["bs"] = batch["input_ids"].shape[0]
        logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
        logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
        # TODO(SS): make a wandb summary metric for this
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)

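    # validation_epoch_end averages losses and generative metrics across steps, appends them to
    # metrics.json via save_metrics, and exposes {prefix}_{val_metric} for checkpointing and
    # early stopping.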
    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        generative_metrics = {
            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
        }
        metric_val = (
            generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
        )
        metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
        generative_metrics.update({k: v.item() for k, v in losses.items()})
        losses.update(generative_metrics)
        all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        all_metrics["step_count"] = self.step_count
        self.save_metrics(all_metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {
            "log": all_metrics,
            "preds": preds,
            f"{prefix}_loss": loss,
            f"{prefix}_{self.val_metric}": metric_tensor,
        }

    def save_metrics(self, latest_metrics, type_path) -> None:
        self.metrics[type_path].append(latest_metrics)
        save_json(self.metrics, self.metrics_save_path)

    def calc_generative_metrics(self, preds, target) -> Dict:
        return calculate_rouge(preds, target)

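    # _generative_step runs beam search generation for a batch, decodes predictions and targets,
    # and returns the loss together with generation time per example, mean generated length, and
    # ROUGE (or BLEU for translation).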
    def _generative_step(self, batch: dict) -> dict:
        t0 = time.time()

        # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
            num_beams=self.eval_beams,
            max_length=self.eval_max_length,
        )
        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["labels"])
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = self.calc_generative_metrics(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)

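        # Three dataloader variants: a (possibly distributed) sortish sampler that groups examples
        # of similar length, a dynamic token-budget batch sampler, or a plain fixed-size loader.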
        if self.hparams.sortish_sampler and type_path != "test":
            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=False,
                num_workers=self.num_workers,
                sampler=sampler,
            )

        elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
            batch_sampler = dataset.make_dynamic_sampler(
                self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
            )
            return DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                collate_fn=dataset.collate_fn,
                # shuffle=False,
                num_workers=self.num_workers,
                # batch_size=None,
            )
        else:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=shuffle,
                num_workers=self.num_workers,
                sampler=None,
            )

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
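        # Seq2seq-specific CLI arguments; generic training arguments come from
        # BaseTransformer.add_model_specific_args and add_generic_args above.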
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--max_tokens_per_batch", type=int, default=None)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task", type=str, default="summarization", required=False, help="Task name: 'summarization' (default) or 'translation'."
        )
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)
        parser.add_argument("--eval_beams", type=int, default=None, required=False)
        parser.add_argument(
            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
        )
        parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
        )
        return parser


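# TranslationModule reuses the summarization training loop but validates with BLEU and forwards
# src_lang / tgt_lang to the dataset.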
class TranslationModule(SummarizationModule):
    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)


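# main() builds the module, picks a logger, wires up checkpointing and optional early stopping,
# trains via generic_train, and, when --do_predict is set, runs trainer.test() (which loads the
# best checkpoint automatically).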
def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        if "summarization" in args.task:
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)
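    # Choose a Lightning logger: plain default logging for quick/dev/tmp runs, otherwise
    # Weights & Biases (optionally a shared "hf_<dataset>" project).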
    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    lower_is_better = args.val_metric == "loss"
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
        ),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

    args = parser.parse_args()

    main(args)
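
# Example invocation (a sketch only; the data/output paths are placeholders, and
# --model_name_or_path / --learning_rate are assumed to come from the generic argument
# parsers in lightning_base rather than being defined in this file):
#
#   python finetune.py \
#       --model_name_or_path facebook/bart-large \
#       --data_dir ./cnn_dm \
#       --output_dir ./bart_cnndm \
#       --do_train --do_predict \
#       --gpus 1 \
#       --train_batch_size 4 --eval_batch_size 4 \
#       --max_source_length 1024 --max_target_length 56 \
#       --learning_rate 3e-5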