import argparse
import glob
import logging
import os
import time
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import get_linear_schedule_with_warmup


try:
    from .utils import (
        use_task_specific_params,
        SummarizationDataset,
        lmap,
        flatten_list,
        pickle_save,
        save_git_info,
        freeze_params,
        calculate_rouge,
        get_git_info,
        ROUGE_KEYS,
    )
    from .callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
except ImportError:
    from utils import (
        use_task_specific_params,
        SummarizationDataset,
        lmap,
        flatten_list,
        pickle_save,
        save_git_info,
        freeze_params,
        calculate_rouge,
        get_git_info,
        ROUGE_KEYS,
    )
    from callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback

logger = logging.getLogger(__name__)


class SummarizationModule(BaseTransformer):
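    """Lightning module that finetunes a pretrained seq2seq transformer (e.g. BART) for summarization."""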
    mode = "summarization"
    loss_names = ["loss"]

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.pkl"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        self.step_count = 0
        self.metrics = {"train": [], "val": [], "test": []}

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
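        # A negative count (the CLI default for train/test) means "use every example in the split".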
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
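        # Train-time target truncation must be at least as aggressive as eval-time truncation.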
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.model.encoder)  # TODO: this will break for t5
        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = 4 if self.hparams.gpus <= 1 else 0  # extra workers break lightning for multigpu; 0 is the DataLoader default

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        if self.model.config.model_type == "bart":
            freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        else:
            freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                freeze_params(d.embed_tokens)

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
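        """Decode generated token ids into whitespace-stripped strings."""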
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
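        # Teacher forcing: the decoder is fed y[:, :-1] and the loss is computed on y[:, 1:];
        # pad positions in the labels are set to -100 so cross-entropy ignores them.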
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone()
        lm_labels[y[:, 1:] == pad_token_id] = -100
        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,)
        loss = outputs[0]
        return (loss,)

    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)
        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)

    def validation_end(self, outputs, prefix="val") -> Dict:
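        # Average per-batch losses and ROUGE/generation stats over the whole epoch, then log and pickle them.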
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in ROUGE_KEYS + ["gen_time", "summ_len"]}
        rouge_tensor: torch.FloatTensor = torch.tensor(rouges["rouge2"]).type_as(loss)
        rouges.update({k: v.item() for k, v in losses.items()})
        losses.update(rouges)
        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        metrics["step_count"] = self.step_count
        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_rouge": rouge_tensor}

    def save_metrics(self, metrics, prefix) -> None:
        self.metrics[prefix].append(metrics)
        pickle_save(self.metrics, self.metrics_save_path)

    def _generative_step(self, batch):
        pad_token_id = self.tokenizer.pad_token_id
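        # Strip columns that are all padding across the batch before generating, to speed up decoding.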
        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
        # TODO(SS): task specific params

        t0 = time.time()
        generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
        gen_time = time.time() - t0
        preds = self.ids_to_clean_text(generated_ids)
        target = self.ids_to_clean_text(y)
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = calculate_rouge(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_end(self, outputs):
        return self.validation_end(outputs, prefix="test")

    def test_epoch_end(self, outputs):
        output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
        output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
        # write predictions and targets for later rouge evaluation.
        with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer:
            for output_batch in outputs:
                p_writer.writelines(s + "\n" for s in output_batch["preds"])
                t_writer.writelines(s + "\n" for s in output_batch["target"])

        return self.test_end(outputs)

    def validation_epoch_end(self, outputs):
        return self.validation_end(outputs, "val")

    def get_dataset(self, type_path) -> SummarizationDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = SummarizationDataset(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)
        sampler = None
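        # The sortish sampler groups examples of roughly similar length to reduce padding; it shuffles
        # internally, so DataLoader-level shuffling is turned off when it is used.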
        if self.hparams.sortish_sampler and type_path == "train":
            assert self.hparams.gpus <= 1  # TODO: assert earlier
            sampler = dataset.make_sortish_sampler(batch_size)
            shuffle = False

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
            sampler=sampler,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
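        # Estimate the total number of optimizer steps so the linear warmup/decay schedule spans the whole run.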
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target (summary) sequence length after tokenization for training examples. "
            "Sequences longer than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target (summary) sequence length after tokenization for validation examples. "
            "Sequences longer than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target (summary) sequence length after tokenization for test examples. "
            "Sequences longer than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--data_dir",
            type=str,
            required=True,
            help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        return parser


def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        model: BaseTransformer = SummarizationModule(args)
    if (
        args.logger == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name)
    elif args.logger == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        # TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB.
        logger = WandbLogger(name=model.output_dir.name, project="hf_summarization")
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
        logger=logger,
        # TODO: early stopping callback seems messed up
    )
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
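    # If any checkpoints were saved, evaluate the last one in sorted order and resume the trainer from it.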
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)
    trainer.test(model)  # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    main(args)
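
# Illustrative invocation (a sketch, not a verified command; the paths are placeholders and
# --model_name_or_path is assumed to be added by BaseTransformer.add_model_specific_args in
# lightning_base, alongside the generic --output_dir / --do_train / --do_predict flags used above):
#
#   python finetune.py \
#       --data_dir=./cnn_dm \
#       --model_name_or_path=facebook/bart-large \
#       --output_dir=./cnn_dm_finetune \
#       --do_train --do_predict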