"test/pipelines/pipelines-it-local.yml" did not exist on "f60bf1d97bd2159c01facdb3ca8b0a61363daf52"
lightning_base.py 12 KB
Newer Older
1
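"""Lightning glue shared by the transformers example scripts.

Provides ``BaseTransformer`` (a ``pl.LightningModule`` that wraps an Auto*
model, tokenizer and config), a rank-zero ``LoggingCallback``, and the
``add_generic_args``/``generic_train`` helpers used to build training CLIs.
"""
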
import argparse
import logging
import os
import random
from pathlib import Path
from typing import Any, Dict

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoModelWithLMHead,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
    get_linear_schedule_with_warmup,
)


logger = logging.getLogger(__name__)


MODEL_MODES = {
    "base": AutoModel,
    "sequence-classification": AutoModelForSequenceClassification,
    "question-answering": AutoModelForQuestionAnswering,
    "pretraining": AutoModelForPreTraining,
    "token-classification": AutoModelForTokenClassification,
    "language-modeling": AutoModelWithLMHead,
    "summarization": AutoModelForSeq2SeqLM,
    "translation": AutoModelForSeq2SeqLM,
}


def set_seed(args: argparse.Namespace):
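    """Seed Python, NumPy and PyTorch (including CUDA) RNGs from args.seed for reproducibility."""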
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.gpus > 0:
        torch.cuda.manual_seed_all(args.seed)


class BaseTransformer(pl.LightningModule):
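    """Base ``pl.LightningModule`` wrapping a transformers Auto* model.

    Subclasses pick a head via ``mode`` (see ``MODEL_MODES``) and must
    implement ``load_dataset(mode, batch_size)``, which the dataloader
    hooks below rely on.
    """
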
    def __init__(
        self,
        hparams: argparse.Namespace,
        num_labels=None,
        mode="base",
        config=None,
        tokenizer=None,
        model=None,
        **config_kwargs
    ):
        """Initialize a model, tokenizer and config."""
        super().__init__()
        self.hparams = hparams  # TODO: move to self.save_hyperparameters()
        self.step_count = 0
        self.tfmr_ckpts = {}
        self.output_dir = Path(self.hparams.output_dir)
        cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
        if config is None:
            self.config = AutoConfig.from_pretrained(
                self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
                **({"num_labels": num_labels} if num_labels is not None else {}),
                cache_dir=cache_dir,
                **config_kwargs,
            )
        else:
            self.config: PretrainedConfig = config
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
                cache_dir=cache_dir,
            )
        else:
            self.tokenizer: PreTrainedTokenizer = tokenizer
        self.model_type = MODEL_MODES[mode]
        if model is None:
            self.model = self.model_type.from_pretrained(
                self.hparams.model_name_or_path,
                from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
                config=self.config,
                cache_dir=cache_dir,
            )
        else:
            self.model = model

    def load_hf_checkpoint(self, *args, **kwargs):
        self.model = self.model_type.from_pretrained(*args, **kwargs)

    def configure_optimizers(self):
        "Prepare optimizer and schedule (linear warmup and decay)"
        model = self.model
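        # Two parameter groups: weight decay is applied everywhere except
        # biases and LayerNorm weights, the usual transformer fine-tuning setup.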
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
        self.opt = optimizer
        return [optimizer]

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
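        """Step the optimizer (via xm.optimizer_step on TPU), zero grads, then step the LR scheduler per batch."""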
        if self.trainer.use_tpu:
            xm.optimizer_step(optimizer)
        else:
            optimizer.step()
        optimizer.zero_grad()
        self.lr_scheduler.step()  # By default, PL will only step every epoch.
        lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
        self.logger.log_metrics(lrs)

    def test_step(self, batch, batch_nb):
        return self.validation_step(batch, batch_nb)

    def test_epoch_end(self, outputs):
        return self.validation_end(outputs)

    def train_dataloader(self):
        train_batch_size = self.hparams.train_batch_size
        dataloader = self.load_dataset("train", train_batch_size)

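        # Estimate total optimizer steps: dataset size over the effective batch
        # size (per-GPU batch * num GPUs), divided by gradient accumulation
        # steps, times the number of epochs.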
        t_total = (
            (len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self):
        return self.load_dataset("dev", self.hparams.eval_batch_size)

    def test_dataloader(self):
        return self.load_dataset("test", self.hparams.eval_batch_size)

    def _feature_file(self, mode):
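        """Path of the cached-features file for `mode`, keyed by model name and max sequence length."""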
        return os.path.join(
            self.hparams.data_dir,
            "cached_{}_{}_{}".format(
                mode,
                list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
                str(self.hparams.max_seq_length),
            ),
        )

    @rank_zero_only
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
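        """Alongside each Lightning checkpoint, save the HF model and tokenizer under output_dir/best_tfmr."""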
        save_path = self.output_dir.joinpath("best_tfmr")
        save_path.mkdir(exist_ok=True)
        self.model.config.save_step = self.step_count
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)
        self.tfmr_ckpts[self.step_count] = save_path

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        parser.add_argument(
            "--model_name_or_path",
            default=None,
            type=str,
            required=True,
            help="Path to pretrained model or model identifier from huggingface.co/models",
        )
        parser.add_argument(
            "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
        )
        parser.add_argument(
            "--tokenizer_name",
            default=None,
            type=str,
            help="Pretrained tokenizer name or path if not the same as model_name",
        )
        parser.add_argument(
            "--cache_dir",
            default="",
            type=str,
            help="Where to store the pretrained models downloaded from s3.",
        )
        parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
        parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
        parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
        parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
        parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
        parser.add_argument(
            "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
        )

        parser.add_argument("--train_batch_size", default=32, type=int)
        parser.add_argument("--eval_batch_size", default=32, type=int)


class LoggingCallback(pl.Callback):
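    """Rank-zero callback that prints validation metrics and writes test metrics to test_results.txt."""
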
    @rank_zero_only
    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        rank_zero_info("***** Validation results *****")
        metrics = trainer.callback_metrics
        # Log results
        for key in sorted(metrics):
            if key not in ["log", "progress_bar"]:
                rank_zero_info("{} = {}\n".format(key, str(metrics[key])))

    @rank_zero_only
    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        logger.info("***** Test results *****")
        metrics = trainer.callback_metrics
        # Log and save results to file
        output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
        with open(output_test_results_file, "w") as writer:
            for key in sorted(metrics):
                if key not in ["log", "progress_bar"]:
                    logger.info("{} = {}\n".format(key, str(metrics[key])))
                    writer.write("{} = {}\n".format(key, str(metrics[key])))


def add_generic_args(parser, root_dir) -> None:
    #  TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )

    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--fast_dev_run", action="store_true")
    parser.add_argument("--gpus", type=int, default=1)
    parser.add_argument("--n_tpu_cores", type=int, default=0)
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )

    parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization.")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    parser.add_argument("--val_check_interval", default=1.0, type=float)


def generic_train(
    model: BaseTransformer,
    args: argparse.Namespace,
    early_stopping_callback=False,
    logger=True,  # can pass WandbLogger() here
    extra_callbacks=[],
    checkpoint_callback=None,
    logging_callback=None,
    **extra_train_kwargs
):
    # init model
    set_seed(args)
    odir = Path(model.hparams.output_dir)
    odir.mkdir(exist_ok=True)
    if checkpoint_callback is None:
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
        )
    if logging_callback is None:
        logging_callback = LoggingCallback()

    train_params = {}

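    # Mixed-precision training through NVIDIA apex (see --fp16/--fp16_opt_level above).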
    if args.fp16:
        train_params["use_amp"] = args.fp16
        train_params["amp_level"] = args.fp16_opt_level

    if args.n_tpu_cores > 0:
        global xm
        import torch_xla.core.xla_model as xm

        train_params["num_tpu_cores"] = args.n_tpu_cores
        train_params["gpus"] = 0

    if args.gpus > 1:
        train_params["distributed_backend"] = "ddp"

    trainer = pl.Trainer(
        logger=logger,
        accumulate_grad_batches=args.gradient_accumulation_steps,
        gpus=args.gpus,
        max_epochs=args.num_train_epochs,
        early_stop_callback=early_stopping_callback,
        gradient_clip_val=args.max_grad_norm,
        checkpoint_callback=checkpoint_callback,
        callbacks=[logging_callback] + extra_callbacks,
        fast_dev_run=args.fast_dev_run,
        val_check_interval=args.val_check_interval,
        weights_summary=None,
        resume_from_checkpoint=args.resume_from_checkpoint,
        **train_params,
    )

    if args.do_train:
        trainer.fit(model)
    trainer.logger.log_hyperparams(args)
    trainer.logger.save()
    return trainer
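

# Usage sketch (hypothetical): a script built on this module would subclass
# BaseTransformer, implement load_dataset() plus the training/validation steps,
# and hand the parsed args to generic_train(). `MyTaskModel` below is
# illustrative only, not part of this module.
#
#     parser = argparse.ArgumentParser()
#     add_generic_args(parser, os.getcwd())
#     MyTaskModel.add_model_specific_args(parser, os.getcwd())
#     args = parser.parse_args()
#     model = MyTaskModel(args)
#     trainer = generic_train(model, args)
#     if args.do_predict:
#         trainer.test(model)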