trainer.py 34 KB
Newer Older
Julien Chaumond's avatar
Julien Chaumond committed
1
import logging
2
import math
Julien Chaumond's avatar
Julien Chaumond committed
3
4
5
6
7
8
import os
import random
import re
import shutil
from contextlib import contextmanager
from pathlib import Path
Lysandre's avatar
Lysandre committed
9
from typing import Callable, Dict, List, Optional, Tuple
Julien Chaumond's avatar
Julien Chaumond committed
10
11
12

import numpy as np
import torch
13
from packaging import version
Julien Chaumond's avatar
Julien Chaumond committed
14
15
16
17
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
18
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
19
from tqdm.auto import tqdm, trange
Julien Chaumond's avatar
Julien Chaumond committed
20

21
from .data.data_collator import DataCollator, default_data_collator
Julien Chaumond's avatar
Julien Chaumond committed
22
23
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
Julien Plu's avatar
Julien Plu committed
24
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
25
from .training_args import TrainingArguments, is_torch_tpu_available
Julien Chaumond's avatar
Julien Chaumond committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39


# apex is an optional dependency; remember whether `amp` could be imported.
try:
    from apex import amp
except ImportError:
    _has_apex = False
else:
    _has_apex = True


def is_apex_available():
    """Return True when `from apex import amp` succeeded at module import time."""
    return _has_apex


40
if is_torch_tpu_available():
Lysandre Debut's avatar
Lysandre Debut committed
41
42
43
44
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

Julien Chaumond's avatar
Julien Chaumond committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Look for a SummaryWriter: first the one shipped with PyTorch, then the
# standalone tensorboardX package. Record whether either import worked.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    try:
        from tensorboardX import SummaryWriter
    except ImportError:
        _has_tensorboard = False
    else:
        _has_tensorboard = True
else:
    _has_tensorboard = True


def is_tensorboard_available():
    """Return True when a `SummaryWriter` class could be imported at module import time."""
    return _has_tensorboard


62
63
64
# Weights & Biases is optional. It is considered available only when the package
# imports, an API key is configured, and the WANDB_DISABLED env variable is unset/empty.
try:
    import wandb

    wandb.ensure_configured()
    if wandb.api.api_key is None:
        _has_wandb = False
        wandb.termwarn("W&B installed but not logged in.  Run `wandb login` or set the WANDB_API_KEY env variable.")
    else:
        # Any non-empty value of WANDB_DISABLED turns the integration off.
        _has_wandb = not os.getenv("WANDB_DISABLED")
except (ImportError, AttributeError):
    # NOTE(review): AttributeError presumably covers wandb builds that lack
    # `ensure_configured` / `api` -- confirm against the supported wandb versions.
    _has_wandb = False


def is_wandb_available():
    """Return True when the wandb integration was found usable at module import time."""
    return _has_wandb


Julien Chaumond's avatar
Julien Chaumond committed
79
80
81
82
83
84
85
86
87
88
89
90
91
92
logger = logging.getLogger(__name__)


def set_seed(seed: int):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        # safe to call this one even if cuda is not available
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager that lets the local master (``local_rank == 0``) run the wrapped body first.

    Processes whose ``local_rank`` is neither -1 nor 0 call ``torch.distributed.barrier()``
    before entering the body; the process with ``local_rank == 0`` calls it after leaving
    the body. With ``local_rank == -1`` no barrier is called at all.
    """
    must_wait_for_master = local_rank not in (-1, 0)
    if must_wait_for_master:
        torch.distributed.barrier()

    yield

    is_local_master = local_rank == 0
    if is_local_master:
        torch.distributed.barrier()


102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
class SequentialDistributedSampler(Sampler):
    """
    Distributed Sampler that subsamples indices sequentially,
    making it easier to collate all results at the end.

    Even though we only use this sampler for eval and predict (no training),
    which means that the model params won't have to be synced (i.e. will not hang
    for synchronization even if varied number of forward passes), we still add extra
    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.

    Args:
        dataset: Sized collection to sample from; only ``len(dataset)`` is used.
        num_replicas: Number of processes sharing the dataset. Defaults to
            ``torch.distributed.get_world_size()``.
        rank: Index of the current process among the replicas. Defaults to
            ``torch.distributed.get_rank()``.

    Raises:
        RuntimeError: If a default is needed but ``torch.distributed`` is not available.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None or rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            if num_replicas is None:
                num_replicas = torch.distributed.get_world_size()
            if rank is None:
                rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        # Every replica receives the same number of samples (rounded up).
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # Add extra samples to make it evenly divisible. The head of the index list is
        # repeated as many times as needed, so this also works when the dataset holds
        # fewer samples than the padding that is required.
        padding = self.total_size - len(indices)
        if padding > 0 and indices:
            repeats = int(math.ceil(padding / len(indices)))
            indices += (indices * repeats)[:padding]
        assert len(indices) == self.total_size

        # Each rank takes one contiguous slice of the padded index list.
        start = self.rank * self.num_samples
        indices = indices[start : start + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples


Lysandre Debut's avatar
Lysandre Debut committed
146
147
148
149
150
151
def get_tpu_sampler(dataset: Dataset):
    """
    Pick the training sampler for TPU runs.

    Returns a `RandomSampler` when the XLA world size is at most 1, otherwise a
    `DistributedSampler` sharded by the XLA world size and this process' ordinal.
    """
    world_size = xm.xrt_world_size()
    if world_size > 1:
        return DistributedSampler(dataset, num_replicas=world_size, rank=xm.get_ordinal())
    return RandomSampler(dataset)


Julien Chaumond's avatar
Julien Chaumond committed
152
153
154
155
156
157
158
159
160
161
162
163
164
165
class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for Transformers.
    """

    model: PreTrainedModel
    args: TrainingArguments
    data_collator: DataCollator
    train_dataset: Optional[Dataset]
    eval_dataset: Optional[Dataset]
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
    prediction_loss_only: bool
    tb_writer: Optional["SummaryWriter"] = None
166
    optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None
167
168
    global_step: Optional[int] = None
    epoch: Optional[float] = None
Julien Chaumond's avatar
Julien Chaumond committed
169
170
171
172
173
174
175
176
177
178

    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only=False,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Optional[Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]] = None,
    ):
        """
        Trainer is a simple but feature-complete training and eval loop for PyTorch,
        optimized for Transformers.

        Args:
            model:
                Model to train; it is moved to `args.device`.
            args:
                Training arguments (device, output/logging directories, seed, ...).
            data_collator:
                (Optional) Collate function handed to the DataLoaders; `default_data_collator` when None.
            train_dataset:
                (Optional) Dataset used by `get_train_dataloader`.
            eval_dataset:
                (Optional) Default dataset used by `get_eval_dataloader`.
            compute_metrics:
                (Optional) Callable taking an `EvalPrediction` and returning a dict; stored on the instance.
            prediction_loss_only:
                (Optional) in evaluation and prediction, only return the loss
            tb_writer:
                (Optional) Tensorboard writer to use. When None, one is created on the world master
                process if Tensorboard is available.
            optimizers:
                (Optional) `(optimizer, scheduler)` tuple returned as-is by `get_optimizers`.
        """
        self.model = model.to(args.device)
        self.args = args
        self.data_collator = data_collator if data_collator is not None else default_data_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
        self.optimizers = optimizers
        if tb_writer is not None:
            self.tb_writer = tb_writer
        elif is_tensorboard_available() and self.is_world_master():
            # Only the world master writes Tensorboard events; elsewhere the class default (None) stays.
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
        if is_wandb_available():
            self._setup_wandb()
        else:
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
        set_seed(self.args.seed)
        # Create output directory if needed
        if self.is_world_master():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available():
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant and not need to do this in the future.
            self.model.config.xla_device = True
Julien Chaumond's avatar
Julien Chaumond committed
221
222
223
224

    def get_train_dataloader(self) -> DataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
225
        if is_torch_tpu_available():
Lysandre Debut's avatar
Lysandre Debut committed
226
227
228
229
230
231
232
233
234
            train_sampler = get_tpu_sampler(self.train_dataset)
        else:
            train_sampler = (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

        data_loader = DataLoader(
Julien Chaumond's avatar
Julien Chaumond committed
235
236
237
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
238
            collate_fn=self.data_collator,
Setu Shah's avatar
Setu Shah committed
239
            drop_last=self.args.dataloader_drop_last,
Julien Chaumond's avatar
Julien Chaumond committed
240
241
        )

Lysandre Debut's avatar
Lysandre Debut committed
242
243
        return data_loader

Julien Chaumond's avatar
Julien Chaumond committed
244
245
246
    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
Lysandre Debut's avatar
Lysandre Debut committed
247

248
249
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

250
        if is_torch_tpu_available():
251
252
253
254
255
256
257
            sampler = SequentialDistributedSampler(
                eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
            )
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(eval_dataset)
        else:
            sampler = SequentialSampler(eval_dataset)
Lysandre Debut's avatar
Lysandre Debut committed
258
259

        data_loader = DataLoader(
260
            eval_dataset,
Lysandre Debut's avatar
Lysandre Debut committed
261
            sampler=sampler,
Julien Chaumond's avatar
Julien Chaumond committed
262
            batch_size=self.args.eval_batch_size,
263
            collate_fn=self.data_collator,
Setu Shah's avatar
Setu Shah committed
264
            drop_last=self.args.dataloader_drop_last,
Julien Chaumond's avatar
Julien Chaumond committed
265
266
        )

Lysandre Debut's avatar
Lysandre Debut committed
267
268
        return data_loader

Julien Chaumond's avatar
Julien Chaumond committed
269
270
    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Build the prediction `DataLoader` over `test_dataset`.

        Samples are visited in order; in TPU/distributed runs each process gets a sequential shard.
        """
        # We use the same batch_size as for eval.
        if is_torch_tpu_available():
            sampler = SequentialDistributedSampler(
                test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
            )
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(test_dataset)
        else:
            sampler = SequentialSampler(test_dataset)

        data_loader = DataLoader(
            test_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader

Julien Chaumond's avatar
Julien Chaumond committed
290
291
292
    def get_optimizers(
        self, num_training_steps: int
    ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
293
294
295
296
297
298
299
300
301
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well.
        If you want to use something else, you can pass a tuple in the Trainer's init,
        or override this method in a subclass.
        """
        if self.optimizers is not None:
            return self.optimizers
Julien Chaumond's avatar
Julien Chaumond committed
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
        )
        return optimizer, scheduler

320
    def _setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can override this method to customize the setup if needed.  Find more information at https://docs.wandb.com/huggingface
        You can also override the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
                or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
        """
        # Only the world master process initialises a run.
        if self.is_world_master():
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
            # keep track of model topology and gradients
            if os.getenv("WANDB_WATCH") != "false":
                wandb.watch(
                    self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
                )
346

347
    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get num of examples from a DataLoader, by accessing its Dataset.

        Returns `len(dataloader.dataset)`, i.e. the dataset size rather than the number of batches.
        """
        return len(dataloader.dataset)
352

Julien Chaumond's avatar
Julien Chaumond committed
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to model if model to train has been instantiated from a local path
                If present, we will try reloading the optimizer/scheduler states from there.

        Returns:
            A `TrainOutput` built from the final global step and the loss accumulated during
            this call divided by that global step (0.0 when no optimizer update was performed).
        """
        train_dataloader = self.get_train_dataloader()
        # Optimizer updates per epoch. An epoch shorter than `gradient_accumulation_steps` still
        # performs one update on its last batch (see the accumulation check in the loop), hence
        # the floor of 1, which also keeps the divisions below from raising ZeroDivisionError.
        num_update_steps_per_epoch = max(len(train_dataloader) // self.args.gradient_accumulation_steps, 1)
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + 1
        else:
            t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        batches_to_skip = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // num_update_steps_per_epoch
                steps_trained_in_current_epoch = self.global_step % num_update_steps_per_epoch
                # `global_step` counts optimizer updates, and each update consumed
                # `gradient_accumulation_steps` batches, so convert updates to batches.
                batches_to_skip = steps_trained_in_current_epoch * self.args.gradient_accumulation_steps

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                # The path does not end in "-<step>": treat it as a fresh run.
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        # Global step at which the loss was last logged, so the logged value is averaged over
        # the updates actually performed since then (first-step logging, resumed runs).
        last_logged_step = self.global_step
        model.zero_grad()
        train_iterator = trange(
            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
        )
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
            else:
                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if batches_to_skip > 0:
                    batches_to_skip -= 1
                    continue

                tr_loss += self._training_step(model, inputs, optimizer)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(optimizer)
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs: Dict[str, float] = {}
                        # Always >= 1 here because global_step was just incremented.
                        steps_since_last_log = self.global_step - last_logged_step
                        logs["loss"] = (tr_loss - logging_loss) / steps_since_last_log
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else scheduler.get_lr()[0]
                        )
                        logging_loss = tr_loss
                        last_logged_step = self.global_step

                        self._log(logs)

                        if self.args.evaluate_during_training:
                            self.evaluate()

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.model
                        else:
                            assert model is self.model
                        # Save model checkpoint
                        output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)

                        if self.is_world_master():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_master():
                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

                # Stop as soon as `max_steps` updates are done (">" would run one update too many).
                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                train_iterator.close()
                break
            # The XLA modules are only imported when a TPU is available, so guard their use.
            if self.args.tpu_metrics_debug and is_torch_tpu_available():
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.tb_writer:
            self.tb_writer.close()

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        # Guard the average against runs in which no optimizer update happened.
        average_loss = tr_loss / self.global_step if self.global_step > 0 else 0.0
        return TrainOutput(self.global_step, average_loss)

    def _log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        """
        Send `logs` to Tensorboard, W&B (world master only) and the console.

        Args:
            logs:
                Values to report. The dict is updated in place with an "epoch" entry
                when `self.epoch` is set.
            iterator:
                (Optional) tqdm bar whose `write` is used for the console line; when None
                the line goes through `logger.info`.
        """
        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_master():
                wandb.log(logs, step=self.global_step)
        output = {**logs, **{"step": self.global_step}}
        if iterator is not None:
            # tqdm's write() hands its argument to a text stream, which only accepts str.
            iterator.write(str(output))
        else:
            logger.info(output)
Julien Chaumond's avatar
Julien Chaumond committed
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604

    def _training_step(
        self, model: nn.Module, inputs: Dict[str, torch.Tensor], optimizer: torch.optim.Optimizer
    ) -> float:
        model.train()
        for k, v in inputs.items():
            inputs[k] = v.to(self.args.device)

        outputs = model(**inputs)
        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()

Lysandre Debut's avatar
Lysandre Debut committed
605
    def is_local_master(self) -> bool:
        """
        Whether this process is the local master: the XLA local master ordinal on TPU,
        otherwise a process whose `args.local_rank` is -1 or 0.
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

Julien Chaumond's avatar
Julien Chaumond committed
611
612
613
614
615
    def is_world_master(self) -> bool:
        """
        This will be True only in one process, even in distributed mode,
        even when training on multiple machines.

        On TPU this is the XLA global master ordinal; otherwise it is the process with
        `args.local_rank == -1` (no distributed training) or with distributed rank 0.
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
Julien Chaumond's avatar
Julien Chaumond committed
620
621
622
623
624
625

    def save_model(self, output_dir: Optional[str] = None):
        """
        Saving best-practices: if you use default names for the model,
        you can reload it using from_pretrained().

        Will only save from the world_master process (unless in TPUs).

        Args:
            output_dir:
                (Optional) Target directory, forwarded to `_save_tpu` / `_save`.
        """

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_master():
            self._save(output_dir)

634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
    def _save_tpu(self, output_dir: Optional[str] = None):
        """
        Save the model (and, on the master ordinal, the training arguments) when running on TPU.

        Args:
            output_dir: Target directory; defaults to ``self.args.output_dir`` when ``None``.

        Raises:
            ValueError: If ``self.model`` is not a ``PreTrainedModel``.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        # Only the master ordinal creates the directory and writes the training arguments.
        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            args_file = os.path.join(output_dir, "training_args.bin")
            torch.save(self.args, args_file)

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model = self.model
        if not isinstance(model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")

        # NOTE(review): presumably this rendezvous makes every ordinal wait until the master
        # has created the directory before the model is written; confirm against torch_xla docs.
        xm.rendezvous("saving_checkpoint")
        model.save_pretrained(output_dir)

Julien Chaumond's avatar
Julien Chaumond committed
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
    def _save(self, output_dir: Optional[str] = None):
        """
        Write the model and the training arguments to a directory.

        Args:
            output_dir: Target directory; defaults to ``self.args.output_dir`` when ``None``.
                It is created if it does not exist.

        Raises:
            ValueError: If ``self.model`` is not a ``PreTrainedModel``.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model = self.model
        if not isinstance(model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")
        model.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        args_file = os.path.join(output_dir, "training_args.bin")
        torch.save(self.args, args_file)

    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
        """
        List the checkpoint paths found in ``self.args.output_dir``, in ascending order.

        Args:
            checkpoint_prefix: Name prefix of checkpoint entries; entries are globbed as
                ``"<prefix>-*"``.
            use_mtime: When True, order by file modification time. Otherwise order by the
                integer that follows ``"<prefix>-"`` in the path; entries without such an
                integer are skipped.

        Returns:
            The matching paths as strings, sorted by the chosen key (ties broken by path).
        """
        candidates = [str(path) for path in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

        # Escape the prefix so that regex metacharacters in it (e.g. ".") are matched
        # literally, and build the pattern once instead of once per path.
        step_pattern = re.compile(f".*{re.escape(checkpoint_prefix)}-([0-9]+)")

        ordering_and_checkpoint_path = []
        for path in candidates:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
                continue
            step_match = step_pattern.match(path)
            if step_match is not None:
                ordering_and_checkpoint_path.append((int(step_match.group(1)), path))

        return [path for _, path in sorted(ordering_and_checkpoint_path)]

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        """
        Delete the oldest checkpoints so that at most ``args.save_total_limit`` remain.

        Does nothing when ``args.save_total_limit`` is ``None`` or not positive, or when the
        number of checkpoints does not exceed it.

        Args:
            use_mtime: Forwarded to ``_sorted_checkpoints`` to choose the ordering key.
        """
        limit = self.args.save_total_limit
        if limit is None or limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        number_of_checkpoints_to_delete = len(checkpoints_sorted) - limit
        if number_of_checkpoints_to_delete <= 0:
            return

        # The list is sorted ascending, so the leading entries are the oldest ones.
        for checkpoint in checkpoints_sorted[:number_of_checkpoints_to_delete]:
            logger.info("Deleting older checkpoint [%s] due to args.save_total_limit", checkpoint)
            shutil.rmtree(checkpoint)

    def evaluate(
        self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None,
    ) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent.

        Args:
            eval_dataset: (Optional) Pass a dataset if you wish to override
                the one on the instance.
            prediction_loss_only: (Optional) Forwarded to the prediction loop. When ``None``,
                the loop falls back to ``self.prediction_loss_only``.
        Returns:
            A dict containing:
                - the eval loss
                - the potential metrics computed from the predictions
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        # Forward the caller's override; it used to be accepted but silently dropped.
        output = self._prediction_loop(
            eval_dataloader, description="Evaluation", prediction_loss_only=prediction_loss_only
        )

        self._log(output.metrics)

        if self.args.tpu_metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in evaluate().

        Args:
            test_dataset: Dataset handed to ``self.get_test_dataloader`` to build the dataloader.

        Returns:
            The ``PredictionOutput`` produced by the shared prediction loop.
        """
        test_dataloader = self.get_test_dataloader(test_dataset)
        return self._prediction_loop(test_dataloader, description="Prediction")

    def _prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with or without labels.

        Args:
            dataloader: Batches to run the model on; each batch is a dict of tensors.
            description: Label used in the log header and as the progress-bar description.
            prediction_loss_only: When true, predictions and label ids are not collected.
                ``None`` falls back to ``self.prediction_loss_only``.

        Returns:
            A ``PredictionOutput`` holding the predictions and label ids as numpy arrays
            (``None`` when not collected) and a metrics dict whose keys all start with ``eval_``.
        """
        if prediction_loss_only is None:
            prediction_loss_only = self.prediction_loss_only

        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(self.model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        # Computed once, before the dataloader is possibly rewrapped for TPU below.
        num_examples = self.num_examples(dataloader)
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", num_examples)
        logger.info("  Batch size = %d", batch_size)

        eval_losses: List[float] = []
        # Per-batch tensors are collected and concatenated once after the loop, instead of
        # re-copying the accumulated tensor on every batch.
        pred_batches: List[torch.Tensor] = []
        label_batches: List[torch.Tensor] = []
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        for inputs in tqdm(dataloader, desc=description):
            # A loss is only unpacked from the outputs when the batch carries a label key.
            has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)

            with torch.no_grad():
                outputs = model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]

            if not prediction_loss_only:
                pred_batches.append(logits.detach())
                # NOTE(review): only the "labels" key feeds label_ids; batches keyed by
                # "lm_labels"/"masked_lm_labels" yield a loss but no label_ids. Confirm intended.
                if inputs.get("labels") is not None:
                    label_batches.append(inputs["labels"].detach())

        preds: Optional[torch.Tensor] = torch.cat(pred_batches, dim=0) if pred_batches else None
        label_ids: Optional[torch.Tensor] = torch.cat(label_batches, dim=0) if label_batches else None

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = self.distributed_concat(preds, num_total_examples=num_examples)
            if label_ids is not None:
                label_ids = self.distributed_concat(label_ids, num_total_examples=num_examples)
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
824
825
826
827
828
829
830
831
832
833
834
835

    def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
        """
        Gather ``tensor`` from every process and return the concatenation along dim 0.

        Args:
            tensor: This process's tensor; one clone per process is allocated as the receive buffer.
            num_total_examples: Number of leading rows to keep from the concatenated result.

        Returns:
            The first ``num_total_examples`` rows of the gathered tensors.

        Only valid in distributed mode: asserts that ``self.args.local_rank`` is not -1.
        """
        assert self.args.local_rank != -1

        world_size = torch.distributed.get_world_size()
        gathered = [tensor.clone() for _ in range(world_size)]
        torch.distributed.all_gather(gathered, tensor)

        # truncate the dummy elements added by SequentialDistributedSampler
        return torch.cat(gathered, dim=0)[:num_total_examples]