"""finetune.py: fine-tune a seq2seq model for summarization or translation with PyTorch Lightning."""
import argparse
import glob
import logging
import os
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration


try:
    from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
    from .utils import (
        ROUGE_KEYS,
        Seq2SeqDataset,
        TranslationDataset,
        assert_all_frozen,
        calculate_bleu_score,
        calculate_rouge,
        flatten_list,
        freeze_params,
        get_git_info,
        label_smoothed_nll_loss,
        lmap,
        pickle_save,
        save_git_info,
        save_json,
        use_task_specific_params,
    )
except ImportError:
    from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
    from utils import (
        ROUGE_KEYS,
        Seq2SeqDataset,
        TranslationDataset,
        assert_all_frozen,
        calculate_bleu_score,
        calculate_rouge,
        flatten_list,
        freeze_params,
        get_git_info,
        label_smoothed_nll_loss,
        lmap,
        pickle_save,
        save_git_info,
        save_json,
        use_task_specific_params,
    )

# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)


class SummarizationModule(BaseTransformer):
    """Lightning module that fine-tunes a seq2seq model for summarization.

    Training minimizes cross-entropy (optionally label-smoothed). Validation and test
    generate text with ``model.generate`` and score it via ``calc_generative_metrics``
    (ROUGE here; ``TranslationModule`` overrides it). The metrics of every
    validation/test epoch are appended to ``<output_dir>/metrics.json``.
    """

    mode = "summarization"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS
    val_metric = "rouge2"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)

        # Keyword arguments shared by every split's dataset (see get_dataset).
        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        # A negative count (e.g. -1) means "use every example" and is stored as None.
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            # Use the target-language code token as the first decoder token, both for
            # generation (passed explicitly in _generative_step) and in the model config.
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
            self.model.config.decoder_start_token_id = self.decoder_start_token_id

        if isinstance(self.tokenizer, (MBartTokenizer, MarianTokenizer)):
            self.dataset_class = TranslationDataset
        else:
            self.dataset_class = Seq2SeqDataset

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        try:
            freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        except AttributeError:
            # Models without the inner `.model` wrapper (e.g. T5) expose the modules directly.
            freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                freeze_params(d.embed_tokens)

    def forward(self, input_ids, **kwargs):
        """Delegate to the wrapped transformer model."""
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids) -> List[str]:
        """Decode a batch of token-id sequences into whitespace-stripped strings.

        Special tokens are dropped and tokenization spaces are cleaned up.
        """
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        """Compute the teacher-forced training loss for one batch.

        Returns:
            A 1-tuple ``(loss,)`` aligned with ``loss_names``.
        """
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]

        if isinstance(self.model, T5ForConditionalGeneration):
            # T5's `_shift_right` builds the decoder inputs from the targets,
            # so the unshifted targets serve directly as labels.
            decoder_input_ids = self.model._shift_right(target_ids)
            lm_labels = target_ids
        else:
            # Teacher forcing: the decoder reads tokens [0, n-1) and learns to predict
            # tokens [1, n), i.e. the same sequence shifted left by one position.
            # Slicing the last dim yields a non-contiguous view; .contiguous() makes a dense copy.
            decoder_input_ids = target_ids[:, :-1].contiguous()
            # clone() likewise yields a contiguous tensor, which `lm_labels.view(-1)` below
            # requires (view() raises on a non-contiguous slice of a multi-row tensor).
            lm_labels = target_ids[:, 1:].clone()

        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs[0]

        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
            assert lm_logits.shape[-1] == self.model.config.vocab_size
            loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, _nll_loss = label_smoothed_nll_loss(
                lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id
            )
        return (loss,)

    @property
    def pad(self) -> int:
        """The tokenizer's padding token id."""
        return self.tokenizer.pad_token_id

    def training_step(self, batch, batch_idx) -> Dict:
        """Run one training step and log the loss plus the non-pad tokens in the batch."""
        loss_tensors = self._step(batch)

        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch
        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["decoder_input_ids"].ne(self.pad).sum()
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        """Aggregate per-batch results of a validation (or test) epoch.

        Averages losses and generative metrics over batches, appends them to
        ``metrics.json`` under ``prefix``, and returns them for Lightning together with
        the concatenated predictions.
        """
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]}
        rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
        rouges.update({k: v.item() for k, v in losses.items()})
        losses.update(rouges)
        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        metrics["step_count"] = self.step_count
        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}

    def save_metrics(self, latest_metrics, type_path) -> None:
        """Append ``latest_metrics`` under ``type_path`` and rewrite the metrics JSON file."""
        self.metrics[type_path].append(latest_metrics)
        save_json(self.metrics, self.metrics_save_path)

    def calc_generative_metrics(self, preds, target) -> Dict:
        """Score generated summaries against references with ROUGE."""
        return calculate_rouge(preds, target)

    def _generative_step(self, batch: dict) -> dict:
        """Generate outputs for a batch and compute losses, generation stats and metrics."""
        t0 = time.time()
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
        )
        # Wall-clock generation time per example in the batch.
        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = self.calc_generative_metrics(preds, target)
        # NOTE(review): if generate() returns a padded 2-D tensor, every row has the same
        # len, so gen_len is the padded width rather than the mean unpadded length - confirm.
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        """Build the dataset for split ``type_path`` ("train", "val" or "test")."""
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        """Wrap the dataset for ``type_path`` in a DataLoader.

        With ``--sortish_sampler`` the training split uses the dataset's sortish sampler
        (so ``shuffle`` is disabled, as DataLoader forbids combining it with a sampler).
        """
        dataset = self.get_dataset(type_path)
        sampler = None
        if self.hparams.sortish_sampler and type_path == "train":
            # Only single-device training is supported with the sortish sampler.
            # `gpus` may be unset (None); treat that as zero GPUs instead of failing on `None <= 1`.
            n_gpus = self.hparams.gpus or 0
            assert n_gpus <= 1, "--sortish_sampler does not support more than one GPU"  # TODO: assert earlier
            sampler = dataset.make_sortish_sampler(batch_size)
            shuffle = False

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
            sampler=sampler,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        return self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Register the seq2seq fine-tuning CLI flags on ``parser`` and return it."""
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length for training after tokenization. "
            "Sequences longer than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target sequence length for validation after tokenization. "
            "Sequences longer than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target sequence length for testing after tokenization. "
            "Sequences longer than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task",
            type=str,
            default="summarization",
            required=False,
            help="'summarization' trains a SummarizationModule; any other value trains a TranslationModule.",
        )
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will affect it.",
        )
        return parser


class TranslationModule(SummarizationModule):
    """Seq2seq fine-tuning for translation, scored with BLEU instead of ROUGE."""

    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        # Forward the source/target language codes to every dataset built by get_dataset().
        self.dataset_kwargs.update(src_lang=hparams.src_lang, tgt_lang=hparams.tgt_lang)

    def calc_generative_metrics(self, preds, target) -> dict:
        """Score generated translations against the references with BLEU."""
        bleu_scores = calculate_bleu_score(preds, target)
        return bleu_scores


def main(args, model=None) -> SummarizationModule:
    """Train (and optionally test) a summarization or translation model.

    Args:
        args: Parsed command-line namespace (Trainer flags plus the flags registered by
            ``SummarizationModule.add_model_specific_args``).
        model: Optional pre-built module. When None, a ``SummarizationModule`` is built if
            ``args.task == "summarization"``, otherwise a ``TranslationModule``.

    Returns:
        The (trained) module.

    Raises:
        ValueError: If training would write into a non-empty output directory, or if
            ``args.logger_name`` is not a recognized logger name.
    """
    # parents=True so nested output directories are created as needed.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError(f"Output directory ({args.output_dir}) already exists and is not empty.")
    if model is None:
        model = SummarizationModule(args) if args.task == "summarization" else TranslationModule(args)

    dataset = Path(args.data_dir).name
    # Named `training_logger` so it does not shadow the module-level `logger`.
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith(("/tmp", "/var"))
    ):
        training_logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        training_logger = WandbLogger(name=model.output_dir.name, project=project)
    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        training_logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
    else:
        # Previously this fell through and failed later with an UnboundLocalError.
        raise ValueError(f"Unknown logger_name: {args.logger_name!r}")

    # A negative patience disables early stopping.
    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
        early_stopping_callback=es_callback,
        logger=training_logger,
        # TODO: early stopping callback seems messed up
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    # The lexicographically last checkpoint file is used.
    # NOTE(review): whether that is the best checkpoint depends on the filename pattern
    # produced by get_checkpoint_callback - confirm.
    checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt")))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model


if __name__ == "__main__":
    # Trainer flags first, then the model/data flags registered by SummarizationModule.
    cli_parser = argparse.ArgumentParser()
    cli_parser = pl.Trainer.add_argparse_args(cli_parser)
    cli_parser = SummarizationModule.add_model_specific_args(cli_parser, os.getcwd())
    main(cli_parser.parse_args())