#!/usr/bin/env python

import argparse
import glob
import logging
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import (
    ROUGE_KEYS,
    LegacySeq2SeqDataset,
    Seq2SeqDataset,
    assert_all_frozen,
    calculate_bleu,
    calculate_rouge,
    check_output_dir,
    flatten_list,
    freeze_embeds,
    freeze_params,
    get_git_info,
    label_smoothed_nll_loss,
    lmap,
    pickle_save,
    save_git_info,
    save_json,
    use_task_specific_params,
)


# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import BaseTransformer, add_generic_args, generic_train  # noqa


logger = logging.getLogger(__name__)


class SummarizationModule(BaseTransformer):
    mode = "summarization"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS
    default_val_metric = "rouge2"

    def __init__(self, hparams, **kwargs):
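        # Lightning's default DistributedSampler must be disabled when we supply our own sortish
        # sampler under multi-GPU training; token-based dynamic batching is single-GPU only.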
        if hparams.sortish_sampler and hparams.gpus > 1:
            hparams.replace_sampler_ddp = False
        elif hparams.max_tokens_per_batch is not None:
            if hparams.gpus > 1:
                raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
            if hparams.sortish_sampler:
                raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")

        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)
        self.model_type = self.config.model_type
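        # FSMT keeps separate source/target vocabularies, so the LM head size is tgt_vocab_size there.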
        self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
        if self.hparams.freeze_embeds:
            freeze_embeds(self.model)
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None  # default to config
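        # mBART expects the target-language code as the first decoder token; if the config does not
        # pin decoder_start_token_id, derive it from --tgt_lang via the tokenizer.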
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
            self.model.config.decoder_start_token_id = self.decoder_start_token_id
        self.dataset_class = (
            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
        )
        self.already_saved_batch = False
        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
        if self.hparams.eval_max_gen_length is not None:
            self.eval_max_length = self.hparams.eval_max_gen_length
        else:
            self.eval_max_length = self.model.config.max_length
        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

    def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
        """A debugging utility"""
        readable_batch = {
            k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items()
        }
        save_json(readable_batch, Path(self.output_dir) / "text_batch.json")
        save_json({k: v.tolist() for k, v in batch.items()}, Path(self.output_dir) / "tok_batch.json")

        self.already_saved_batch = True
        return readable_batch

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
        if not self.already_saved_batch:  # This would be slightly better if it only happened on rank zero
            batch["decoder_input_ids"] = decoder_input_ids
            self.save_readable_batch(batch)

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs[0]
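        # Two loss paths: plain token-level cross-entropy that ignores pad positions, or
        # label-smoothed NLL computed on log-probabilities when --label_smoothing > 0.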
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
            )
        return (loss,)

    @property
    def pad(self) -> int:
        return self.tokenizer.pad_token_id

    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)

        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch
        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
        logs["bs"] = batch["input_ids"].shape[0]
        logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
        logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
        # TODO(SS): make a wandb summary metric for this
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
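        # Average losses and generative metrics (ROUGE/BLEU, gen_time, gen_len) over all validation
        # batches, then pull out the monitored metric (--val_metric) as a tensor for checkpointing.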
        generative_metrics = {
            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
        }
        metric_val = (
            generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
        )
        metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
        generative_metrics.update({k: v.item() for k, v in losses.items()})
        losses.update(generative_metrics)
        all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        all_metrics["step_count"] = self.step_count
        self.metrics[prefix].append(all_metrics)  # callback writes this to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {
            "log": all_metrics,
            "preds": preds,
            f"{prefix}_loss": loss,
            f"{prefix}_{self.val_metric}": metric_tensor,
        }

    def calc_generative_metrics(self, preds, target) -> Dict:
        return calculate_rouge(preds, target)

    def _generative_step(self, batch: dict) -> dict:
        t0 = time.time()

        # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
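        # Beam search with the evaluation settings; gen_time below is wall-clock seconds per example.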
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
            num_beams=self.eval_beams,
            max_length=self.eval_max_length,
        )
        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["labels"])
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = self.calc_generative_metrics(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)

        if self.hparams.sortish_sampler and type_path != "test":
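            # Sortish sampler: groups examples of similar source length to cut padding, so the
            # DataLoader itself must not reshuffle (shuffle=False below).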
            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=False,
                num_workers=self.num_workers,
                sampler=sampler,
            )

        elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
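            # Dynamic batching: the batch sampler packs examples up to ~max_tokens_per_batch tokens,
            # so no fixed batch_size is passed to the DataLoader.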
            batch_sampler = dataset.make_dynamic_sampler(
                self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
            )
            return DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                collate_fn=dataset.collate_fn,
                # shuffle=False,
                num_workers=self.num_workers,
                # batch_size=None,
            )
        else:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=shuffle,
                num_workers=self.num_workers,
                sampler=None,
            )

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
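        # Generic flags (model name, learning rate, etc.) are added by lightning_base; everything
        # below is seq2seq-specific.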
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target (output) sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target (output) sequence length after tokenization during validation. "
            "Sequences longer than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target (output) sequence length after tokenization during testing. "
            "Sequences longer than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--overwrite_output_dir", action="store_true", default=False)
        parser.add_argument("--max_tokens_per_batch", type=int, default=None)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task",
            type=str,
            default="summarization",
            required=False,
            help="Task to fine-tune on; any value containing 'summarization' uses SummarizationModule, "
            "otherwise TranslationModule.",
        )
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)
        parser.add_argument("--eval_beams", type=int, default=None, required=False)
        parser.add_argument(
            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
        )
        parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs, so val_check_interval will affect it.",
        )
        return parser


class TranslationModule(SummarizationModule):
    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)


def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
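    # check_output_dir refuses to reuse a populated --output_dir unless --overwrite_output_dir is set.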
    check_output_dir(args, expected_items=3)

    if model is None:
        if "summarization" in args.task:
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)
    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    lower_is_better = args.val_metric == "loss"
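    # The checkpoint callback keeps the --save_top_k best checkpoints monitored on the validation
    # metric; "loss" is minimized, ROUGE/BLEU metrics are maximized.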
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
        ),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
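    # Use the last checkpoint in sorted order both as the reported test checkpoint and as the
    # trainer's resume point before running test().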
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    main(args)
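
# Example invocation (illustrative only; paths and hyperparameters are placeholders, and generic
# flags such as --model_name_or_path, --do_train and --do_predict are defined in lightning_base):
#
#   python finetune.py \
#       --model_name_or_path facebook/bart-large \
#       --data_dir ./cnn_dm \
#       --output_dir ./bart_cnndm_ft \
#       --gpus 1 --do_train --do_predict \
#       --train_batch_size 8 --eval_batch_size 8 \
#       --max_source_length 1024 --max_target_length 56 \
#       --val_metric rouge2 --n_val 500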