finetune.py
import argparse
import glob
import logging
import os
import time
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import get_linear_schedule_with_warmup


try:
    from .utils import (
        use_task_specific_params,
        SummarizationDataset,
        lmap,
        flatten_list,
        pickle_save,
        save_git_info,
        freeze_params,
        calculate_rouge,
        get_git_info,
        ROUGE_KEYS,
    )
    from .callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
except ImportError:
    from utils import (
        use_task_specific_params,
        SummarizationDataset,
        lmap,
        flatten_list,
        pickle_save,
        save_git_info,
        freeze_params,
        calculate_rouge,
        get_git_info,
        ROUGE_KEYS,
    )
    from callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback

logger = logging.getLogger(__name__)


class SummarizationModule(BaseTransformer):
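    """Fine-tunes a seq2seq model (BART-style; T5 is only partially handled, see TODOs below) for summarization."""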
    mode = "summarization"
    loss_names = ["loss"]

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.pkl"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        self.step_count = 0
        self.metrics = {"train": [], "val": [], "test": []}

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )

        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.model.encoder)  # TODO: this will break for t5
        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        if self.model.config.model_type == "bart":
            freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        else:
            freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                freeze_params(d.embed_tokens)

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
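        # Teacher forcing: the decoder input is the target shifted right (y[:, :-1]) and the
        # labels are the target shifted left (y[:, 1:]); padding positions are set to -100 so
        # the cross-entropy loss ignores them.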
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone()
        lm_labels[y[:, 1:] == pad_token_id] = -100
        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,)
        loss = outputs[0]
        return (loss,)

    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)
        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in ROUGE_KEYS + ["gen_time", "summ_len"]}
        rouge_tensor: torch.FloatTensor = torch.tensor(rouges["rouge2"]).type_as(loss)
        rouges.update({k: v.item() for k, v in losses.items()})
        losses.update(rouges)
        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        metrics["step_count"] = self.step_count
        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_rouge": rouge_tensor}

    def save_metrics(self, metrics, prefix) -> None:
        self.metrics[prefix].append(metrics)
        pickle_save(self.metrics, self.metrics_save_path)

    def _generative_step(self, batch: dict) -> dict:
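        # Generate summaries for the (trimmed) batch, decode predictions and references to text,
        # and compute ROUGE plus per-example generation time and mean summary length, together
        # with the regular teacher-forced loss from _step.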
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
        t0 = time.time()
        generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
        gen_time = (time.time() - t0) / source_ids.shape[0]
        preds = self.ids_to_clean_text(generated_ids)
        target = self.ids_to_clean_text(y)
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = calculate_rouge(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> SummarizationDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = SummarizationDataset(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
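        # The optional sortish sampler (train split only, single GPU here) groups examples of
        # similar length into the same batch to reduce padding; shuffle is disabled when it is used.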
        dataset = self.get_dataset(type_path)
        sampler = None
        if self.hparams.sortish_sampler and type_path == "train":
            assert self.hparams.gpus <= 1  # TODO: assert earlier
            sampler = dataset.make_sortish_sampler(batch_size)
            shuffle = False

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
            sampler=sampler,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--data_dir",
            type=str,
            required=True,
            help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        return parser


def main(args, model=None) -> SummarizationModule:
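    """Build (or reuse) a SummarizationModule, train it via generic_train with the ROUGE-2
    checkpoint callback, and, if --do_predict is set, run trainer.test on the last checkpoint."""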
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        model: BaseTransformer = SummarizationModule(args)
    if (
        args.logger == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name)
    elif args.logger == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        # TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB.
        logger = WandbLogger(name=model.output_dir.name, project="hf_summarization")
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
        logger=logger,
        # TODO: early stopping callback seems messed up
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)
    trainer.test(model)  # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    main(args)
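
    # Illustrative invocation (paths and values are placeholders; flags such as
    # --model_name_or_path, --gpus, --do_train and the batch-size options are assumed to be
    # defined by lightning_base's add_generic_args / BaseTransformer.add_model_specific_args):
    #   python finetune.py \
    #       --model_name_or_path facebook/bart-large \
    #       --data_dir ./cnn_dm \
    #       --output_dir ./bart_cnndm_ft \
    #       --do_train --do_predict \
    #       --gpus 1 --train_batch_size 4 --eval_batch_size 4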