#!/usr/bin/env python

import argparse
import glob
import logging
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import (
    ROUGE_KEYS,
    LegacySeq2SeqDataset,
    Seq2SeqDataset,
    assert_all_frozen,
    calculate_bleu,
    calculate_rouge,
    flatten_list,
    freeze_embeds,
    freeze_params,
    get_git_info,
    label_smoothed_nll_loss,
    lmap,
    pickle_save,
    save_git_info,
    use_task_specific_params,
)


# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import BaseTransformer, add_generic_args, generic_train  # noqa


logger = logging.getLogger(__name__)


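# SummarizationModule fine-tunes a pretrained seq2seq model (BART, T5, mBART, FSMT, ...)
# with pytorch-lightning: it builds Seq2SeqDataset dataloaders from --data_dir, trains with
# an (optionally label-smoothed) cross-entropy loss, and evaluates by generating summaries
# and scoring them with ROUGE. Per-epoch metrics are accumulated in self.metrics and
# persisted to metrics.json by the logging callback.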
class SummarizationModule(BaseTransformer):
    mode = "summarization"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS
    default_val_metric = "rouge2"

    def __init__(self, hparams, **kwargs):
        if hparams.sortish_sampler and hparams.gpus > 1:
            hparams.replace_sampler_ddp = False
        elif hparams.max_tokens_per_batch is not None:
            if hparams.gpus > 1:
                raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
            if hparams.sortish_sampler:
                raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")

        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)
        self.model_type = self.config.model_type
        self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
        if self.hparams.freeze_embeds:
            freeze_embeds(self.model)
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None  # default to config
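        # mBART starts decoding from a target-language code; if the config does not define
        # decoder_start_token_id, derive it from --tgt_lang via the tokenizer.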
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
            self.model.config.decoder_start_token_id = self.decoder_start_token_id
        self.dataset_class = (
            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
        )
        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
        if self.hparams.eval_max_gen_length is not None:
            self.eval_max_length = self.hparams.eval_max_gen_length
        else:
            self.eval_max_length = self.model.config.max_length
        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

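    # Teacher-forced loss: the labels are shifted right to form decoder_input_ids, and the
    # resulting LM logits are scored against the labels with CrossEntropyLoss (pad tokens
    # ignored) or, when --label_smoothing > 0, with label_smoothed_nll_loss from utils.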
    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        tgt_ids = batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs[0]
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
            )
        return (loss,)

    @property
    def pad(self) -> int:
        return self.tokenizer.pad_token_id

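    # Besides the loss, each training step logs tokens-per-batch (tpb), the batch size and
    # how much of the source batch is padding, which helps judge the sampler/batching choice.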
    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)

        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch
        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
        logs["bs"] = batch["input_ids"].shape[0]
        logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
        logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
        # TODO(SS): make a wandb summary metric for this
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)

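    # Aggregates the per-batch outputs into f"{prefix}_avg_*" metrics, selects self.val_metric
    # as the monitored value, and appends everything to self.metrics for the logging callback.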
    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        generative_metrics = {
            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
        }
        metric_val = (
            generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
        )
        metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
        generative_metrics.update({k: v.item() for k, v in losses.items()})
        losses.update(generative_metrics)
        all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        all_metrics["step_count"] = self.step_count
        self.metrics[prefix].append(all_metrics)  # callback writes this to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {
            "log": all_metrics,
            "preds": preds,
            f"{prefix}_loss": loss,
            f"{prefix}_{self.val_metric}": metric_tensor,
        }

    def calc_generative_metrics(self, preds, target) -> Dict:
        return calculate_rouge(preds, target)

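    # Generates with beam search using the eval-time settings (--eval_beams, eval_max_length),
    # decodes predictions and references, and returns the loss together with gen_time, gen_len
    # and the generative metrics (ROUGE here, BLEU in TranslationModule).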
    def _generative_step(self, batch: dict) -> dict:
        t0 = time.time()

        # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
            num_beams=self.eval_beams,
            max_length=self.eval_max_length,
        )
        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["labels"])
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = self.calc_generative_metrics(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

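    # Three batching strategies: a (possibly distributed) sortish sampler that groups examples
    # of similar length, a dynamic batch sampler capped at --max_tokens_per_batch, or a plain
    # (optionally shuffled) DataLoader. The test set always takes the plain path.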
    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)

        if self.hparams.sortish_sampler and type_path != "test":
            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=False,
                num_workers=self.num_workers,
                sampler=sampler,
            )

        elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
            batch_sampler = dataset.make_dynamic_sampler(
                self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
            )
            return DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                collate_fn=dataset.collate_fn,
                # shuffle=False,
                num_workers=self.num_workers,
                # batch_size=None,
            )
        else:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=shuffle,
                num_workers=self.num_workers,
                sampler=None,
            )

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

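    # Extends the generic/BaseTransformer CLI with the seq2seq-specific options: sequence
    # lengths, dataset truncation (--n_train/--n_val/--n_test), freezing, sampler choice,
    # label smoothing, evaluation and early-stopping settings.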
    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--max_tokens_per_batch", type=int, default=None)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task", type=str, default="summarization", required=False, help="Task name; any value containing 'summarization' selects SummarizationModule, otherwise TranslationModule."
        )
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)
        parser.add_argument("--eval_beams", type=int, default=None, required=False)
        parser.add_argument(
            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
        )
        parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will affect it.",
        )
        return parser


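# TranslationModule reuses the training loop above but validates with BLEU instead of ROUGE
# and passes --src_lang/--tgt_lang through to the dataset.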
class TranslationModule(SummarizationModule):
    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)


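# main() wires everything together: it instantiates the module matching --task, picks a logger
# (the default one, or Weights & Biases for --logger_name wandb/wandb_shared), trains via
# generic_train with checkpointing and optional early stopping, and, if --do_predict is set,
# records the latest *.ckpt and runs trainer.test().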
def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        if "summarization" in args.task:
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)
    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    lower_is_better = args.val_metric == "loss"
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
        ),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

    args = parser.parse_args()

    main(args)
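# Illustrative invocation. This is a sketch, not the canonical command: --data_dir is assumed
# to contain the train/val/test source and target files expected by Seq2SeqDataset, and flags
# such as --model_name_or_path, --data_dir, --output_dir, --gpus and --do_train come from
# lightning_base / pl.Trainer, which are not shown in this file, so run
# `python finetune.py --help` for the exact set.
#
#   python finetune.py \
#       --model_name_or_path facebook/bart-large \
#       --data_dir $DATA_DIR \
#       --output_dir ./summarization_run \
#       --do_train --do_predict \
#       --max_target_length 56 --val_max_target_length 142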