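"""finetune.py

Fine-tune a seq2seq model (BART/T5-style, or mBART for translation) for summarization or
translation using PyTorch Lightning.

Illustrative invocation only -- generic flags such as --model_name_or_path, --data_dir,
--output_dir, --do_train, --do_predict and --gpus come from lightning_base.py and
add_model_specific_args() below; adjust names and values to your setup:

    python finetune.py \
        --model_name_or_path facebook/bart-large \
        --data_dir $DATA_DIR \
        --output_dir $OUTPUT_DIR \
        --do_train --do_predict \
        --gpus 1
"""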
import argparse
import glob
import logging
import os
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right


try:
    from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
    from .utils import (
        ROUGE_KEYS,
        LegacySeq2SeqDataset,
        Seq2SeqDataset,
        assert_all_frozen,
        calculate_bleu,
        calculate_rouge,
        flatten_list,
        freeze_params,
        get_git_info,
        label_smoothed_nll_loss,
        lmap,
        pickle_save,
        save_git_info,
        save_json,
        use_task_specific_params,
    )
except ImportError:
    from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
    from utils import (
        ROUGE_KEYS,
        LegacySeq2SeqDataset,
        Seq2SeqDataset,
        assert_all_frozen,
        calculate_bleu,
        calculate_rouge,
        flatten_list,
        freeze_params,
        get_git_info,
        label_smoothed_nll_loss,
        lmap,
        pickle_save,
        save_git_info,
        save_json,
        use_task_specific_params,
    )

logger = logging.getLogger(__name__)


class SummarizationModule(BaseTransformer):
    mode = "summarization"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS
    default_val_metric = "rouge2"

    def __init__(self, hparams, **kwargs):
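        # The sortish sampler buckets examples by approximate length; Lightning must not replace
        # it with its own DistributedSampler, hence replace_sampler_ddp = False for multi-gpu runs.
        # Token-count-based dynamic batching is only implemented for single-gpu training.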
        if hparams.sortish_sampler and hparams.gpus > 1:
            hparams.replace_sampler_ddp = False
        elif hparams.max_tokens_per_batch is not None:
            if hparams.gpus > 1:
                raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
            if hparams.sortish_sampler:
                raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")

        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        if self.hparams.sortish_sampler and self.hparams.max_tokens_per_batch is not None:
            raise AssertionError("max tokens per batch and sortish sampler are incompatible.")
        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None  # default to config
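        # mBART has no fixed decoder_start_token_id: decoding must start with the target-language
        # code, looked up from --tgt_lang. Other models keep their config default.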
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
            self.model.config.decoder_start_token_id = self.decoder_start_token_id
        self.dataset_class = (
            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
        )
        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
        assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer >= 1"
        if self.hparams.eval_max_gen_length is not None:
            self.eval_max_length = self.hparams.eval_max_gen_length
        else:
            self.eval_max_length = self.model.config.max_length
        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
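        # BART-like models nest their embeddings under self.model.model; T5 exposes them directly
        # on self.model, which is handled by the AttributeError fallback below.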
        try:
            freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        except AttributeError:
            freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                freeze_params(d.embed_tokens)

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
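        # Teacher forcing: decoder inputs are the labels shifted one position to the right.
        # T5 does the shift via its private _shift_right; BART-style models use shift_tokens_right.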
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs[0]
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.model.config.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        else:
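            # Label smoothing: utils.label_smoothed_nll_loss mixes the gold-token NLL with a
            # uniform penalty over the vocabulary (roughly (1 - eps) * nll + (eps / V) * sum(-log p)),
            # ignoring padded target positions.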
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
            )
        return (loss,)

    @property
    def pad(self) -> int:
        return self.tokenizer.pad_token_id

    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)

        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch
        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
        logs["bs"] = batch["input_ids"].shape[0]
        logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
        logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
        # TODO(SS): make a wandb summary metric for this
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
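        # Average the generative metrics (ROUGE/BLEU, generation time and length) over all
        # validation batches; if the chosen val_metric is not a generative metric (e.g. "loss"),
        # fall back to the averaged losses computed above.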
        generative_metrics = {
            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
        }
        metric_val = (
            generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
        )
        metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
        generative_metrics.update({k: v.item() for k, v in losses.items()})
        losses.update(generative_metrics)
        all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        all_metrics["step_count"] = self.step_count
        self.save_metrics(all_metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {
            "log": all_metrics,
            "preds": preds,
            f"{prefix}_loss": loss,
            f"{prefix}_{self.val_metric}": metric_tensor,
        }

    def save_metrics(self, latest_metrics, type_path) -> None:
        self.metrics[type_path].append(latest_metrics)
        save_json(self.metrics, self.metrics_save_path)

    def calc_generative_metrics(self, preds, target) -> Dict:
        return calculate_rouge(preds, target)

    def _generative_step(self, batch: dict) -> dict:
        t0 = time.time()

        # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
            num_beams=self.eval_beams,
            max_length=self.eval_max_length,
        )
        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["labels"])
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = self.calc_generative_metrics(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)
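        # Three cases: a (possibly distributed) sortish sampler, a dynamic token-count batch
        # sampler when --max_tokens_per_batch is set, or a plain (optionally shuffled) DataLoader.
        # The test set always takes the plain path.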

        if self.hparams.sortish_sampler and type_path != "test":
            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=False,
                num_workers=self.num_workers,
                sampler=sampler,
            )

        elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
            batch_sampler = dataset.make_dynamic_sampler(
                self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
            )
            return DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                collate_fn=dataset.collate_fn,
                # shuffle=False,
                num_workers=self.num_workers,
                # batch_size=None,
            )
        else:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=shuffle,
                num_workers=self.num_workers,
                sampler=None,
            )

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--max_tokens_per_batch", type=int, default=None)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task", type=str, default="summarization", required=False, help="Task to fine-tune: summarization or translation."
        )
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)
        parser.add_argument("--eval_beams", type=int, default=None, required=False)
        parser.add_argument(
            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
        )
        parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will affect it.",
        )
        return parser


class TranslationModule(SummarizationModule):
    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)


def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        if "summarization" in args.task:
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)
    dataset = Path(args.data_dir).name
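    # Pick a logger: the default Lightning logger for smoke tests and throwaway output dirs
    # (to keep wandb clean), otherwise a personal or shared Weights & Biases project.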
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    lower_is_better = args.val_metric == "loss"
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
        ),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
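    # Record the most recent checkpoint in output_dir and resume from it, so that the
    # trainer.test() call below evaluates the fine-tuned weights.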
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

    args = parser.parse_args()

    main(args)