trainer.py 34.2 KB
Newer Older
Julien Chaumond's avatar
Julien Chaumond committed
1
import logging
2
import math
Julien Chaumond's avatar
Julien Chaumond committed
3
4
5
6
import os
import random
import re
import shutil
7
import warnings
Julien Chaumond's avatar
Julien Chaumond committed
8
9
from contextlib import contextmanager
from pathlib import Path
Lysandre's avatar
Lysandre committed
10
from typing import Callable, Dict, List, Optional, Tuple
Julien Chaumond's avatar
Julien Chaumond committed
11
12
13

import numpy as np
import torch
14
from packaging import version
Julien Chaumond's avatar
Julien Chaumond committed
15
16
17
18
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
19
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
20
from tqdm.auto import tqdm, trange
Julien Chaumond's avatar
Julien Chaumond committed
21

22
from .data.data_collator import DataCollator, default_data_collator
Julien Chaumond's avatar
Julien Chaumond committed
23
24
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
25
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput, is_wandb_available
26
from .training_args import TrainingArguments, is_torch_tpu_available
Julien Chaumond's avatar
Julien Chaumond committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40


try:
    # NVIDIA apex mixed precision; `amp` is used below for fp16 training.
    from apex import amp
except ImportError:
    _has_apex = False
else:
    _has_apex = True


def is_apex_available():
    """Return whether NVIDIA apex (``apex.amp``) could be imported when this module was loaded."""
    return _has_apex


41
if is_torch_tpu_available():
    # PyTorch/XLA modules are imported only when a TPU is reported available; every use of
    # xm / met / pl in this module must therefore sit behind an is_torch_tpu_available() check.
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

Julien Chaumond's avatar
Julien Chaumond committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    # Fall back to the standalone tensorboardX package.
    try:
        from tensorboardX import SummaryWriter
    except ImportError:
        _has_tensorboard = False
    else:
        _has_tensorboard = True
else:
    _has_tensorboard = True


def is_tensorboard_available():
    """Return whether a ``SummaryWriter`` (from torch or tensorboardX) was importable at load time."""
    return _has_tensorboard


63
if is_wandb_available():
    # Optional Weights & Biases integration; `wandb` is only used behind is_wandb_available() checks.
    import wandb


Julien Chaumond's avatar
Julien Chaumond committed
67
68
69
70
71
72
73
74
75
76
77
78
79
80
logger = logging.getLogger(__name__)


def set_seed(seed: int):
    """
    Seed every random number generator used during training, for reproducibility.

    Seeds, in this order: Python's ``random``, NumPy's global RNG, torch's CPU RNG and
    the RNGs of all CUDA devices.

    Args:
        seed: the value passed to each generator's seeding function.
    """
    # torch.cuda.manual_seed_all is safe to call even when CUDA is not available.
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager that lets the local master (local rank 0) run its body before the other
    local processes run theirs.

    Ranks other than -1 and 0 wait on a barrier before entering the body. Rank 0 enters at once
    and reaches the matching barrier after its body finishes, which releases the waiting ranks.
    With ``local_rank == -1`` (no distributed training) no barrier is used.

    NOTE: there is no try/finally, so if the body raises on rank 0 the second barrier is never
    reached.

    Args:
        local_rank: local rank of this process, or -1 when not running distributed.
    """
    must_wait = local_rank not in (-1, 0)
    if must_wait:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        # Rank 0 is done: meet the processes parked at the barrier above.
        torch.distributed.barrier()


90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
class SequentialDistributedSampler(Sampler):
    """
    Distributed sampler that splits the dataset indices into contiguous, sequential chunks (one per
    replica), making it easier to collate all results at the end.

    Even though we only use this sampler for eval and predict (no training), which means that the model
    params won't have to be synced (i.e. will not hang for synchronization even if varied number of forward
    passes), we still add extra samples to make the index list evenly divisible (like in
    `DistributedSampler`) to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
    The extra samples repeat indices from the start of the dataset, wrapping around as often as needed.

    Args:
        dataset: a sized dataset; only ``len(dataset)`` is used.
        num_replicas: number of processes; defaults to ``torch.distributed.get_world_size()``.
        rank: rank of this process; defaults to ``torch.distributed.get_rank()``.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        # Exact integer ceil division (avoids float rounding on very large datasets).
        self.num_samples = (len(self.dataset) + self.num_replicas - 1) // self.num_replicas
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # Add extra samples to make it evenly divisible. The padding may be longer than the dataset
        # itself (e.g. 1 sample on 3 replicas), so wrap around instead of slicing the list once.
        padding = self.total_size - len(indices)
        if padding > 0:
            indices += [indices[i % len(indices)] for i in range(padding)]
        assert len(indices) == self.total_size

        # Subsample: each replica takes one contiguous chunk.
        start = self.rank * self.num_samples
        indices = indices[start : start + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples


Lysandre Debut's avatar
Lysandre Debut committed
134
135
136
137
138
139
def get_tpu_sampler(dataset: Dataset):
    """
    Choose the training sampler for TPU runs.

    Returns a ``DistributedSampler`` sharded by XLA ordinal when more than one TPU process
    participates, otherwise a plain ``RandomSampler``.
    """
    world_size = xm.xrt_world_size()
    if world_size > 1:
        return DistributedSampler(dataset, num_replicas=world_size, rank=xm.get_ordinal())
    return RandomSampler(dataset)


Julien Chaumond's avatar
Julien Chaumond committed
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for Transformers.
    """

    model: PreTrainedModel
    args: TrainingArguments
    data_collator: DataCollator
    train_dataset: Optional[Dataset]
    eval_dataset: Optional[Dataset]
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
    prediction_loss_only: bool
    tb_writer: Optional["SummaryWriter"] = None
154
    optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None
155
156
    global_step: Optional[int] = None
    epoch: Optional[float] = None
Julien Chaumond's avatar
Julien Chaumond committed
157
158
159
160
161
162
163
164
165
166

    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only=False,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Optional[Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]] = None,
    ):
        """
        Trainer is a simple but feature-complete training and eval loop for PyTorch,
        optimized for Transformers.

        Args:
            model: the model to train / evaluate; it is moved to ``args.device``.
            args: the training arguments.
            data_collator:
                (Optional) callable building a batch from a list of dataset items; defaults to
                ``default_data_collator``. Objects that only expose a ``collate_batch`` method are
                still accepted but deprecated.
            train_dataset: (Optional) dataset used by ``train()``.
            eval_dataset: (Optional) default dataset used for evaluation.
            compute_metrics: (Optional) function computing a metrics dict from an ``EvalPrediction``.
            prediction_loss_only:
                (Optional) in evaluation and prediction, only return the loss
            tb_writer:
                (Optional) TensorBoard writer to use; when omitted, one is created on the world
                master if TensorBoard is installed.
            optimizers:
                (Optional) ``(optimizer, scheduler)`` pair that ``get_optimizers()`` returns as-is
                instead of building the default AdamW + linear warmup schedule.
        """
        self.model = model.to(args.device)
        self.args = args
        self.data_collator = data_collator if data_collator is not None else default_data_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
        self.optimizers = optimizers
        if tb_writer is not None:
            self.tb_writer = tb_writer
        elif is_tensorboard_available() and self.is_world_master():
            # Only the world master writes TensorBoard events; other processes keep tb_writer unset.
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
        if is_wandb_available():
            self._setup_wandb()
        else:
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
        set_seed(self.args.seed)
        # Create output directory if needed
        if self.is_world_master():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available():
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant and not need to do this in the future.
            self.model.config.xla_device = True
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            # Backward compatibility: adapt old-style collator objects to the plain-callable interface.
            self.data_collator = self.data_collator.collate_batch
            warnings.warn(
                (
                    "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
                    + "with a `collate_batch` are deprecated and won't be supported in a future version."
                ),
                FutureWarning,
            )
Julien Chaumond's avatar
Julien Chaumond committed
218
219
220
221

    def get_train_dataloader(self) -> DataLoader:
        """
        Build the training ``DataLoader``.

        The sampler is TPU-aware on TPU (``get_tpu_sampler``). Otherwise it is a ``RandomSampler`` for
        single-process training (``local_rank == -1``) and a ``DistributedSampler`` in distributed mode.

        Raises:
            ValueError: if the Trainer has no ``train_dataset``.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        if is_torch_tpu_available():
            train_sampler = get_tpu_sampler(self.train_dataset)
        else:
            train_sampler = (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

        data_loader = DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader

Julien Chaumond's avatar
Julien Chaumond committed
241
242
243
    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Build the evaluation ``DataLoader``.

        Args:
            eval_dataset: (Optional) dataset to evaluate on; defaults to ``self.eval_dataset``.

        Raises:
            ValueError: if neither ``eval_dataset`` nor ``self.eval_dataset`` is set.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        # Samples are always visited in order. On TPU / in distributed mode each process gets a
        # contiguous, equally sized (padded) shard so that results can be gathered in order.
        if is_torch_tpu_available():
            sampler = SequentialDistributedSampler(
                eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
            )
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(eval_dataset)
        else:
            sampler = SequentialSampler(eval_dataset)

        data_loader = DataLoader(
            eval_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader

Julien Chaumond's avatar
Julien Chaumond committed
266
267
    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Build the prediction ``DataLoader`` for ``test_dataset``.

        Uses the same sampling scheme and batch size as evaluation.

        Args:
            test_dataset: the dataset to run predictions on.
        """
        # We use the same batch_size as for eval.
        if is_torch_tpu_available():
            sampler = SequentialDistributedSampler(
                test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
            )
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(test_dataset)
        else:
            sampler = SequentialSampler(test_dataset)

        data_loader = DataLoader(
            test_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader

Julien Chaumond's avatar
Julien Chaumond committed
287
288
289
    def get_optimizers(
        self, num_training_steps: int
    ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well.
        If you want to use something else, you can pass a tuple in the Trainer's init,
        or override this method in a subclass.

        Args:
            num_training_steps: total number of optimizer steps, passed to the linear warmup/decay schedule.

        Returns:
            ``(optimizer, scheduler)``: ``self.optimizers`` when it was given to the constructor,
            otherwise AdamW plus ``get_linear_schedule_with_warmup``.
        """
        if self.optimizers is not None:
            return self.optimizers
        # Prepare optimizer and schedule (linear warmup and decay).
        # Parameters whose name contains "bias" or "LayerNorm.weight" get no weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
        )
        return optimizer, scheduler

317
    def _setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        Only the world master calls ``wandb.init`` (logging the training arguments as the run config)
        and, unless disabled, ``wandb.watch`` on the model.

        One can override this method to customize the setup if needed. Find more information at
        https://docs.wandb.com/huggingface
        You can also override the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable
                gradient logging or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a
                different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
                (not read here — presumably checked by ``is_wandb_available``; confirm)
        """
        if self.is_world_master():
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
            # keep track of model topology and gradients
            if os.getenv("WANDB_WATCH") != "false":
                wandb.watch(
                    self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
                )
343

344
    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get num of examples from a DataLoader, by accessing its Dataset.

        Returns ``len(dataloader.dataset)``: the size of the whole dataset, not just the shard
        seen by this process in distributed mode.
        """
        return len(dataloader.dataset)
349

Julien Chaumond's avatar
Julien Chaumond committed
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to model if model to train has been instantiated from a local path
                If present, we will try reloading the optimizer/scheduler states from there. When the path
                ends in ``-<number>`` (e.g. a ``checkpoint-500`` directory saved by this method), training
                resumes from that global step.

        Returns:
            ``TrainOutput(global_step, training_loss)``. ``training_loss`` is the summed loss of this call
            divided by the number of optimizer steps it performed, or 0.0 if it performed none.
        """
        train_dataloader = self.get_train_dataloader()
        # Optimizer updates per epoch, clamped to 1. With fewer batches than gradient_accumulation_steps,
        # the "last step in epoch" rule below still performs one update per epoch. A value of 0 would
        # cause ZeroDivisionErrors below and a zero-length learning-rate schedule (t_total == 0).
        num_update_steps_per_epoch = max(len(train_dataloader) // self.args.gradient_accumulation_steps, 1)
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + 1
        else:
            t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        # Effective batch size per optimizer step; gradient accumulation applies on TPU as well.
        if is_torch_tpu_available():
            world_size = xm.xrt_world_size()
        elif self.args.local_rank != -1:
            world_size = torch.distributed.get_world_size()
        else:
            world_size = 1
        total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps * world_size
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        # Number of *batches* (not optimizer steps) to skip at the start of the first epoch when resuming.
        batches_to_skip = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // num_update_steps_per_epoch
                # Each optimizer step consumed gradient_accumulation_steps batches.
                batches_to_skip = (
                    self.global_step % num_update_steps_per_epoch
                ) * self.args.gradient_accumulation_steps

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d batches in the first epoch", batches_to_skip)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        # Global step at the start of this call (for the final average) and at the last log (so the
        # logged loss is averaged over the steps actually elapsed since then).
        starting_global_step = self.global_step
        last_logged_step = self.global_step
        model.zero_grad()
        train_iterator = trange(
            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
        )
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                # Give each epoch a different shuffle in distributed mode.
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
            else:
                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained batches if resuming training
                if batches_to_skip > 0:
                    batches_to_skip -= 1
                    continue

                tr_loss += self._training_step(model, inputs, optimizer)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(optimizer)
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs: Dict[str, float] = {}
                        # Average over the optimizer steps since the previous log (always >= 1 here),
                        # not a fixed logging_steps, which is wrong for the first step / after resuming
                        # and divides by zero when logging_steps == 0.
                        logs["loss"] = (tr_loss - logging_loss) / (self.global_step - last_logged_step)
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else scheduler.get_lr()[0]
                        )
                        logging_loss = tr_loss
                        last_logged_step = self.global_step

                        self._log(logs)

                        if self.args.evaluate_during_training:
                            self.evaluate()

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.model
                        else:
                            assert model is self.model
                        # Save model checkpoint
                        output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)

                        if self.is_world_master():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_master():
                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

                # Stop as soon as max_steps updates are done (">" used to run one extra step).
                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    # `xm` / `met` are only imported on TPU; using them here would raise NameError.
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )

        if self.tb_writer:
            self.tb_writer.close()

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        # tr_loss only covers the steps of this call (it starts at 0 even when resuming).
        steps_this_run = self.global_step - starting_global_step
        training_loss = tr_loss / steps_this_run if steps_this_run > 0 else 0.0
        return TrainOutput(self.global_step, training_loss)

    def _log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        """
        Send ``logs`` to every configured sink.

        The sinks are TensorBoard (numeric values only), W&B (world master only), and then either
        the given tqdm ``iterator`` or the module logger. ``logs`` is updated in place with
        ``"epoch"`` when ``self.epoch`` is set.

        Args:
            logs: mapping of metric name to value.
            iterator: (Optional) tqdm progress bar to print through instead of the logger.
        """
        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_master():
                wandb.log(logs, step=self.global_step)
        output = {**logs, "step": self.global_step}
        if iterator is not None:
            # tqdm.write writes its argument to a text stream, so it must be a str, not a dict.
            iterator.write(str(output))
        else:
            logger.info(output)
Julien Chaumond's avatar
Julien Chaumond committed
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601

    def _training_step(
        self, model: nn.Module, inputs: Dict[str, torch.Tensor], optimizer: torch.optim.Optimizer
    ) -> float:
        """
        Run forward and backward on one batch and return the (scaled) loss as a Python float.

        The loss is averaged over GPUs when ``n_gpu > 1``. It is divided by
        ``gradient_accumulation_steps`` before backward, so the returned values summed over one
        accumulation window add up to the loss of that optimizer step. No optimizer step is taken here.

        Args:
            model: the (possibly wrapped) model; it is put in train mode.
            inputs: batch from the data collator. Tensor values are moved to ``args.device`` in place;
                other values are passed to the model unchanged.
            optimizer: only used by apex ``amp.scale_loss`` when fp16 is enabled.
        """
        model.train()
        for k, v in inputs.items():
            # Collators may emit non-tensor entries, which have no `.to()`.
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        outputs = model(**inputs)
        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()

Lysandre Debut's avatar
Lysandre Debut committed
602
    def is_local_master(self) -> bool:
        """
        Return True for the main process on the current machine.

        On TPU this is the local master ordinal. Otherwise it is the process with
        ``local_rank`` -1 (not distributed) or 0.
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

Julien Chaumond's avatar
Julien Chaumond committed
608
609
610
611
612
    def is_world_master(self) -> bool:
        """
        This will be True only in one process, even in distributed mode,
        even when training on multiple machines.

        On TPU this is the global master ordinal. Otherwise it is the only process when not
        distributed (``local_rank == -1``), or global rank 0 of ``torch.distributed``.
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
Julien Chaumond's avatar
Julien Chaumond committed
617
618
619
620
621
622

    def save_model(self, output_dir: Optional[str] = None):
        """
        Saving best-practices: if you use default names for the model,
        you can reload it using from_pretrained().

        Will only save from the world_master process (unless in TPUs, where every process
        takes part in the save).

        Args:
            output_dir: (Optional) target directory; ``args.output_dir`` is used when omitted.
        """
        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_master():
            self._save(output_dir)

631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
    def _save_tpu(self, output_dir: Optional[str] = None):
        """
        TPU variant of ``_save``; intended to be called by every TPU process.

        Only the process for which ``xm.is_master_ordinal()`` is true creates the
        directory and writes ``training_args.bin``. All processes then meet at
        ``xm.rendezvous("saving_checkpoint")`` before ``save_pretrained`` is called.

        Args:
            output_dir: Target directory; ``self.args.output_dir`` is used when None.

        Raises:
            ValueError: if ``self.model`` is not a ``PreTrainedModel``.
        """
        target_dir = self.args.output_dir if output_dir is None else output_dir
        logger.info("Saving model checkpoint to %s", target_dir)

        if xm.is_master_ordinal():
            os.makedirs(target_dir, exist_ok=True)
            args_path = os.path.join(target_dir, "training_args.bin")
            torch.save(self.args, args_path)

        # save_pretrained() writes model + configuration so from_pretrained() can reload them.
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")

        xm.rendezvous("saving_checkpoint")
        self.model.save_pretrained(target_dir)

Julien Chaumond's avatar
Julien Chaumond committed
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
    def _save(self, output_dir: Optional[str] = None):
        """
        Write the model and the training arguments to a directory.

        The directory is created if needed. The model is stored with
        ``save_pretrained()`` so it can be reloaded via ``from_pretrained()``.
        ``self.args`` is then written next to it as ``training_args.bin`` with
        ``torch.save``.

        Args:
            output_dir: Target directory; ``self.args.output_dir`` is used when None.

        Raises:
            ValueError: if ``self.model`` is not a ``PreTrainedModel`` (the directory
                has already been created at that point).
        """
        target_dir = self.args.output_dir if output_dir is None else output_dir
        os.makedirs(target_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", target_dir)

        model = self.model
        if not isinstance(model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")
        model.save_pretrained(target_dir)

        # Good practice: keep the training arguments alongside the trained model.
        args_path = os.path.join(target_dir, "training_args.bin")
        torch.save(self.args, args_path)

    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
        """
        List checkpoint entries in ``self.args.output_dir`` in ascending order.

        Candidates are the entries matching the glob ``f"{checkpoint_prefix}-*"``.

        Args:
            checkpoint_prefix: Name prefix of checkpoint directories.
            use_mtime: If True, every glob match is kept and ordered by modification time.
                Otherwise only entries named exactly ``<checkpoint_prefix>-<digits>`` are
                kept, ordered by that integer.

        Returns:
            The checkpoint paths as strings, sorted ascending by the chosen key.
        """
        # Parse the step from the entry's own name, not the full path. A greedy
        # ``.*prefix-([0-9]+)`` over the whole path can borrow a number from a parent
        # directory (``runs/checkpoint-5/out/checkpoint-old`` -> step 5). Such an entry
        # could then be deleted by ``_rotate_checkpoints``. ``re.escape`` keeps regex
        # metacharacters in the prefix literal.
        step_pattern = re.compile(re.escape(checkpoint_prefix) + r"-([0-9]+)")

        ordering_and_checkpoint_path = []
        for checkpoint in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*"):
            path = str(checkpoint)
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
                continue
            regex_match = step_pattern.fullmatch(checkpoint.name)
            if regex_match is not None:
                ordering_and_checkpoint_path.append((int(regex_match.group(1)), path))

        return [path for _, path in sorted(ordering_and_checkpoint_path)]

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        """
        Delete old checkpoints so that at most ``args.save_total_limit`` remain.

        Nothing happens when ``save_total_limit`` is None or not positive. The entries
        removed are the first ones returned by ``_sorted_checkpoints(use_mtime=use_mtime)``,
        i.e. the lowest step numbers (or the oldest modification times).
        """
        limit = self.args.save_total_limit
        if limit is None or limit <= 0:
            return

        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        excess = len(checkpoints_sorted) - limit
        if excess <= 0:
            return

        for checkpoint in checkpoints_sorted[:excess]:
            logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(
        self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None,
    ) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent.

        Args:
            eval_dataset: (Optional) Pass a dataset if you wish to override
                the one on the instance.
            prediction_loss_only: (Optional) Overrides ``self.prediction_loss_only`` for this
                call. When True, predictions are not accumulated, so only the loss is
                reported.
        Returns:
            A dict containing:
                - the eval loss
                - the potential metrics computed from the predictions
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        # Forward the caller's override so the argument actually takes effect.
        output = self._prediction_loop(
            eval_dataloader, description="Evaluation", prediction_loss_only=prediction_loss_only
        )

        self._log(output.metrics)

        if self.args.tpu_metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction on ``test_dataset`` and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in evaluate().
        """
        loader = self.get_test_dataloader(test_dataset)
        result = self._prediction_loop(loader, description="Prediction")
        return result

    def _prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with or without labels.

        Args:
            dataloader: Batches to run the model on. Each batch is a dict of tensors
                passed to the model as keyword arguments.
            description: Label used in the log lines and the progress bar.
            prediction_loss_only: If True, logits and labels are not accumulated, so no
                predictions are returned and ``compute_metrics`` is not called. Defaults to
                ``self.prediction_loss_only`` when None.

        Returns:
            A ``PredictionOutput`` whose ``predictions`` / ``label_ids`` are numpy arrays
            (or None when not collected). Every key of its ``metrics`` dict is prefixed
            with ``eval_``.
        """
        if prediction_loss_only is None:
            prediction_loss_only = self.prediction_loss_only

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        # Count once, on the original dataloader, before it may be wrapped for TPU below.
        num_examples = self.num_examples(dataloader)
        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", num_examples)
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: Optional[torch.Tensor] = None
        label_ids: Optional[torch.Tensor] = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        for inputs in tqdm(dataloader, desc=description):
            # When any label key is present, the first model output is treated as the loss.
            has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)

            with torch.no_grad():
                outputs = model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]

            if not prediction_loss_only:
                step_preds = logits.detach()
                preds = step_preds if preds is None else torch.cat((preds, step_preds), dim=0)
                # Only the "labels" key is collected for metrics; "lm_labels" and
                # "masked_lm_labels" count for the loss above but not for label_ids.
                if inputs.get("labels") is not None:
                    step_labels = inputs["labels"].detach()
                    label_ids = step_labels if label_ids is None else torch.cat((label_ids, step_labels), dim=0)

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = self.distributed_concat(preds, num_total_examples=num_examples)
            if label_ids is not None:
                label_ids = self.distributed_concat(label_ids, num_total_examples=num_examples)
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            # NOTE(review): eval_losses is not gathered across processes, so in distributed
            # mode this is the mean over this process's batches only.
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
821
822
823
824
825
826
827
828
829
830
831
832

    def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
        """
        Gather ``tensor`` from every process and concatenate the copies along dim 0.

        Only valid in distributed mode (``args.local_rank != -1``); this is checked with
        an ``assert``. The concatenation is cut to its first ``num_total_examples`` rows.
        This drops the dummy elements added by SequentialDistributedSampler.
        """
        assert self.args.local_rank != -1

        world_size = torch.distributed.get_world_size()
        gathered = [tensor.clone() for _ in range(world_size)]
        torch.distributed.all_gather(gathered, tensor)

        # truncate the dummy elements added by SequentialDistributedSampler
        return torch.cat(gathered, dim=0)[:num_total_examples]