import inspect
import json
import math
import os
import re
import shutil
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange

from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
from .integrations import (
    default_hp_search_backend,
    is_comet_available,
    is_optuna_available,
    is_ray_available,
    is_tensorboard_available,
    is_wandb_available,
    run_hp_search_optuna,
    run_hp_search_ray,
)
from .modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalPrediction,
    EvaluationStrategy,
    HPSearchBackend,
    PredictionOutput,
    TrainerState,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    distributed_broadcast_scalars,
    distributed_concat,
    nested_concat,
    nested_numpify,
    nested_xla_mesh_reduce,
    set_seed,
)
from .training_args import TrainingArguments
from .utils import logging


_use_native_amp = False
_use_apex = False

PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."

# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
    from .file_utils import is_apex_available

    if is_apex_available():
        from apex import amp
    _use_apex = True
else:
    _use_native_amp = True
    from torch.cuda.amp import autocast

if is_datasets_available():
    import datasets

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_tensorboard_available():
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        from tensorboardX import SummaryWriter

if is_wandb_available():
    import wandb

if is_comet_available():
    import comet_ml

if is_optuna_available():
    import optuna

if is_ray_available():
    from ray import tune

logger = logging.get_logger(__name__)


def reissue_pt_warnings(caught_warnings):
    # Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
    if len(caught_warnings) > 1:
        for w in caught_warnings:
            if w.category != UserWarning or str(w.message) != PT_LR_SCHEDULER_WARNING:
                warnings.warn(w.message, w.category)


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager to make all processes in distributed training wait for the local master to do something first.

    Args:
        local_rank (:obj:`int`): The rank of the local process.
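
    Example (illustrative sketch; ``prepare_dataset`` is a hypothetical, cache-producing helper)::

        with torch_distributed_zero_first(local_rank):
            # Only the local master runs this first; the other ranks wait at the barrier
            # and then reuse the cached result.
            dataset = prepare_dataset()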
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()


class SequentialDistributedSampler(Sampler):
    """
    Distributed Sampler that subsamples indices sequentially,
    making it easier to collate all results at the end.

    Even though we only use this sampler for eval and predict (no training),
    which means that the model params won't have to be synced (i.e. will not hang
    for synchronization even if varied number of forward passes), we still add extra
    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
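
    Example (illustrative; ``my_dataset`` is a placeholder dataset with 10 examples)::

        sampler = SequentialDistributedSampler(my_dataset, num_replicas=4, rank=3)
        # Each rank draws ceil(10 / 4) = 3 indices and the tail wraps around to the start,
        # so rank 3 yields [9, 0, 1]; the padded entries are discarded after gathering.
        indices = list(iter(sampler))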
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert (
            len(indices) == self.total_size
        ), f"Indices length {len(indices)} and total size {self.total_size} mismatched"

        # subsample
        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
        assert (
            len(indices) == self.num_samples
        ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"

        return iter(indices)

    def __len__(self):
        return self.num_samples


def get_tpu_sampler(dataset: Dataset):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for 🤗 Transformers.

    Args:
        model (:class:`~transformers.PreTrainedModel`, `optional`):
            The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
        args (:class:`~transformers.TrainingArguments`, `optional`):
            The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
            with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
        data_collator (:obj:`DataCollator`, `optional`):
            The function to use to form a batch from a list of elements of :obj:`train_dataset` or
            :obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is
            provided, an instance of :func:`~transformers.DataCollatorWithPadding` otherwise.
        train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
        tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to
            the maximum length when batching inputs, and it will be saved along the model to make it easier to rerun
            an interrupted training or reuse the fine-tuned model.
        model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
            A function that instantiates the model to be used. If provided, each call to
            :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
        compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take a
            :class:`~transformers.EvalPrediction` and return a dictionary mapping metric names to metric values.
        tb_writer (:obj:`SummaryWriter`, `optional`):
            Object to write to TensorBoard.
        optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, `optional`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of
            :class:`~transformers.AdamW` on your model and a scheduler given by
            :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
        kwargs:
            Deprecated keyword arguments.
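
    Example (a minimal sketch; ``my_model``, ``training_args``, ``train_ds`` and ``eval_ds`` are placeholders)::

        trainer = Trainer(
            model=my_model,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=eval_ds,
        )
        trainer.train()
        metrics = trainer.evaluate()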
    """

    def __init__(
        self,
        model: PreTrainedModel = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        **kwargs,
    ):
        if args is None:
            logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.")
            args = TrainingArguments("tmp_trainer")
        self.args = args
        # Seed must be set before instantiating the model when using model_init.
        set_seed(self.args.seed)
        assert (
            model is not None or model_init is not None
        ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
        if model is None and model_init is not None:
            model = model_init()
        self.model = model.to(args.device) if model is not None else None
        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer
        self.model_init = model_init
        self.compute_metrics = compute_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        self.tb_writer = tb_writer
        self.log_history = []
        if "prediction_loss_only" in kwargs:
            warnings.warn(
                "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
                FutureWarning,
            )
            self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create output directory if needed
        if self.is_world_process_zero():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel):
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant way and not need to do this in the future.
            self.model.config.xla_device = True
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            self.data_collator = self.data_collator.collate_batch
            warnings.warn(
                (
                    "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
                    + "with a `collate_batch` are deprecated and won't be supported in a future version."
                ),
                FutureWarning,
            )

        if is_datasets_available():
            if isinstance(train_dataset, datasets.Dataset):
                self._remove_unused_columns(self.train_dataset, description="training")
            if isinstance(eval_dataset, datasets.Dataset):
                self._remove_unused_columns(self.eval_dataset, description="evaluation")

        self.global_step = None
        self.epoch = None
        self.total_flos = None
        if self.args.fp16 and _use_native_amp:
            self.scaler = torch.cuda.amp.GradScaler()
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
        if self.args.label_names is None:
            self.args.label_names = (
                ["start_positions", "end_positions"]
                if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
                else ["labels"]
            )

    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return
        # Inspect model forward signature to keep only the arguments it accepts.
        signature = inspect.signature(self.model.forward)
        signature_columns = list(signature.parameters.keys())
        # Labels may be named label or label_ids, the default data collator handles that.
        signature_columns += ["label", "label_ids"]
        columns = [k for k in signature_columns if k in dataset.column_names]
        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        dset_description = "" if description is None else f"in the {description} set "
        logger.info(
            f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
        )
        dataset.set_format(type=dataset.format["type"], columns=columns)

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return get_tpu_sampler(self.train_dataset)
        else:
            return (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
        (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
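
        Example (illustrative sketch of a subclass that swaps in a different sampler; ``MyTrainer`` is hypothetical)::

            class MyTrainer(Trainer):
                def _get_train_sampler(self):
                    # Always iterate in dataset order instead of shuffling.
                    return SequentialSampler(self.train_dataset)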
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = self._get_train_sampler()

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
        )

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            return SequentialDistributedSampler(eval_dataset)
        else:
            return SequentialSampler(eval_dataset)

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
        sampler (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
                accepted by the ``model.forward()`` method are automatically removed.
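
        Example (illustrative; ``other_eval_ds`` is a hypothetical second validation set)::

            dataloader = trainer.get_eval_dataloader(other_eval_ds)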
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        elif eval_dataset is not None and is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            self._remove_unused_columns(eval_dataset, description="evaluation")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
        sampler (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (:obj:`torch.utils.data.dataset.Dataset`):
                The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed.
        """
        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            self._remove_unused_columns(test_dataset, description="test")
        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
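
        Example (sketch of passing a custom optimizer/scheduler pair at init; ``model`` and ``training_args`` are assumed to exist)::

            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
            trainer = Trainer(model=model, args=training_args, optimizers=(optimizer, scheduler))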
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            self.optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
            )
        if self.lr_scheduler is None:
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
            )

    def setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
                or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
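
        Example (illustrative only; values are placeholders)::

            import os

            os.environ["WANDB_PROJECT"] = "my-project"
            os.environ["WANDB_WATCH"] = "false"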
        """
        if hasattr(self, "_setup_wandb"):
            warnings.warn(
                "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
                FutureWarning,
            )
            return self._setup_wandb()

        if self.is_world_process_zero():
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**self.args.to_sanitized_dict()}
            if isinstance(self.model, PreTrainedModel):
                combined_dict = {**self.model.config.to_dict(), **combined_dict}
            wandb.init(
                project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
            )
            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                wandb.watch(
                    self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
                )

505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
    def setup_comet(self):
        """
        Setup the optional Comet.ml integration.

        Environment:
            COMET_MODE:
                (Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
            COMET_PROJECT_NAME:
                (Optional): str - Comet.ml project name for experiments
            COMET_OFFLINE_DIRECTORY:
                (Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"

        For a number of configurable items in the environment,
        see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
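
        Example (illustrative only; values are placeholders)::

            import os

            os.environ["COMET_MODE"] = "OFFLINE"
            os.environ["COMET_OFFLINE_DIRECTORY"] = "./comet_logs"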
        """
        if self.is_world_master():
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            experiment = None
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**args)
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**args)
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            if experiment is not None:
                experiment._set_model_graph(self.model, framework="transformers")
                experiment._log_parameters(self.args, prefix="args/", framework="transformers")
                if isinstance(self.model, PreTrainedModel):
                    experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
        """
        return len(dataloader.dataset)

    def _setup_loggers(self):
        if self._loggers_initialized:
            return
        if is_wandb_available():
            self.setup_wandb()
        elif os.environ.get("WANDB_DISABLED") != "true":
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
        if is_comet_available():
            self.setup_comet()
        elif os.environ.get("COMET_MODE") != "DISABLED":
            logger.info(
                "To use comet_ml logging, run `pip/conda install comet_ml` "
                "see https://www.comet.ml/docs/python-sdk/huggingface/"
            )
        self._loggers_initialized = True

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """ HP search setup code """
        if self.hp_search_backend is None or trial is None:
            return
        params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
        for key, value in params.items():
            if not hasattr(self.args, key):
                raise AttributeError(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
                )
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info("Trial:", trial.params)

    def _report_to_hp_search(
        self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
    ):
        if self.hp_search_backend is None or trial is None:
            return
        self.objective = self.compute_objective(metrics)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            trial.report(self.objective, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            if self.global_step % self.args.save_steps == 0:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

    def _tune_save_checkpoint(self):
        if not self.use_tune_checkpoints:
            return
        with tune.checkpoint_dir(step=self.global_step) as checkpoint_dir:
            self.args.output_dir = checkpoint_dir
            output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
            self.save_model(output_dir)
            if self.is_world_master():
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
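
        Example (sketch; resumes from a hypothetical local checkpoint directory)::

            trainer.train(model_path="output_dir/checkpoint-500")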
        """
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            set_seed(self.args.seed)
            model = self.model_init()
            self.model = model.to(self.args.device)

            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
        num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                self.args.max_steps % num_update_steps_per_epoch > 0
            )
        else:
            t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs
            self.args.max_steps = t_total

        self.create_optimizer_and_scheduler(num_training_steps=t_total)
        self.state = TrainerState()

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            self.optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            with warnings.catch_warnings(record=True) as caught_warnings:
                self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
            reissue_pt_warnings(caught_warnings)

        # Check if a saved Trainer state exists
        if model_path is not None and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
            self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))

        model = self.model
        if self.args.fp16 and _use_apex:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=(
                    not getattr(model.config, "gradient_checkpointing", False)
                    if isinstance(model, PreTrainedModel)
                    else True
                ),
            )
        # find_unused_parameters breaks checkpointing as per
        # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])

                epochs_trained = self.global_step // num_update_steps_per_epoch
                steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = torch.tensor(0.0).to(self.args.device)
        self.total_flos = self.state.total_flos
        logging_loss_scalar = 0.0
        model.zero_grad()
        disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
        train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
        for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    epoch_pbar.update(1)
                    continue

                tr_loss += self.training_step(model, inputs)
                self.total_flos += self.floating_point_ops(inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)
                    elif self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    self.lr_scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs: Dict[str, float] = {}
                        tr_loss_scalar = tr_loss.item()
                        logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            self.lr_scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else self.lr_scheduler.get_lr()[0]
                        )
                        logging_loss_scalar = tr_loss_scalar

                        self.log(logs)

                    if (
                        self.args.evaluation_strategy == EvaluationStrategy.STEPS
                        and self.global_step % self.args.eval_steps == 0
                    ):
                        metrics = self.evaluate()
                        self._report_to_hp_search(trial, epoch, metrics)
                        if self.args.load_best_model_at_end:
                            self._save_training(model, trial, metrics=metrics)

                    if (
                        not self.args.load_best_model_at_end
                        and self.args.save_steps > 0
                        and self.global_step % self.args.save_steps == 0
                    ):
                        self._save_training(model, trial)

                epoch_pbar.update(1)
                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                    break
            epoch_pbar.close()
            train_pbar.update(1)

            if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
                metrics = self.evaluate()
                self._report_to_hp_search(trial, epoch, metrics)
                if self.args.load_best_model_at_end:
                    self._save_training(model, trial, metrics=metrics)

            if self.args.tpu_metrics_debug or self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                break

        train_pbar.close()
        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            logger.info(
                f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
            )
            if isinstance(model, PreTrainedModel):
                self.model = model.from_pretrained(self.state.best_model_checkpoint)
                self.model = self.model.to(self.args.device)
            else:
                state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
                self.model.load_state_dict(state_dict)

        return TrainOutput(self.global_step, tr_loss.item() / self.global_step)

    def _save_training(self, model, trial, metrics=None):
        # In all cases (even distributed/parallel), self.model is always a reference
        # to the model we want to save.
        if hasattr(model, "module"):
            assert model.module is self.model, f"Module {model.module} should be a reference to self.model"
        else:
            assert model is self.model, f"Model {model} should be a reference to self.model"
        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
        if self.hp_search_backend is not None and trial is not None:
            run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
            checkpoint_folder += f"-run-{run_id}"
        output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

        self.store_flos()
        self.save_model(output_dir)

        # Save optimizer and scheduler
        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                reissue_pt_warnings(caught_warnings)
        elif self.is_world_process_zero():
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
            reissue_pt_warnings(caught_warnings)

        # Determine the new best metric / best model checkpoint
        if metrics is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.is_world_process_zero():
            self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))

        # Maybe delete some older checkpoints.
        if self.is_world_process_zero():
            self._rotate_checkpoints(use_mtime=True)

    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        **kwargs
    ) -> BestRun:
        """
        Launch a hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
        :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no metric is
        provided, the sum of all metrics otherwise.

        .. warning::

            To use this method, you need to have provided a ``model_init`` when initializing your
            :class:`~transformers.Trainer`: we need to reinitialize the model at each new run. This is incompatible
            with the ``optimizers`` argument, so you need to subclass :class:`~transformers.Trainer` and override the
            method :meth:`~transformers.Trainer.create_optimizer_and_scheduler` for custom optimizer/scheduler.

        Args:
            hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
                A function that defines the hyperparameter search space. Will default to
                :func:`~transformers.trainer_utils.default_hp_space_optuna` or
                :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
            compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
                A function computing the objective to minimize or maximize from the metrics returned by the
                :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
            n_trials (:obj:`int`, `optional`, defaults to 20):
                The number of trial runs to test.
            direction(:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
                Whether to optimize for a greater or lower objective. Can be :obj:`"minimize"` or :obj:`"maximize"`, you should
                pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
                several metrics.
            backend(:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
                one is installed. If both are installed, will default to optuna.
            kwargs:
                Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
                more information see:

                - the documentation of `optuna.create_study <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
                - the documentation of `tune.run <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__

        Returns:
            :class:`transformers.trainer_utils.BestRun`: All the information about the best run.
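
        Example (a minimal sketch; ``model_init``, ``training_args`` and the datasets are assumed to exist)::

            def my_hp_space(trial):
                # Hypothetical optuna search space.
                return {"learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)}

            trainer = Trainer(
                model_init=model_init,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
            )
            best_run = trainer.hyperparameter_search(hp_space=my_hp_space, n_trials=10, direction="minimize")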
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray should be installed. "
                    "To install optuna run `pip install optuna`."
                    "To install ray run `pip install ray[tune]`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        self.hp_search_backend = backend

        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
        best_run = run_hp_search(self, n_trials, direction, **kwargs)

        self.hp_search_backend = None
        return best_run

    def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        """
        Log :obj:`logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (:obj:`Dict[str, float]`):
                The values to log.
            iterator (:obj:`tqdm`, `optional`):
                A potential tqdm progress bar to write the logs on.
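
        Example (an illustrative subclass sketch, not from the original documentation; it assumes CUDA is available
        and only attaches one extra value before delegating to the default logging)::

            class MemoryLoggingTrainer(Trainer):
                def log(self, logs, iterator=None):
                    # Record peak GPU memory (in GB) alongside the regular metrics.
                    logs["peak_gpu_mem_gb"] = torch.cuda.max_memory_allocated() / 1e9
                    super().log(logs, iterator=iterator)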
        """
        # Set up loggers like W&B or Comet ML
        self._setup_loggers()

        if hasattr(self, "_log"):
            warnings.warn(
                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
                FutureWarning,
            )
            return self._log(logs, iterator=iterator)

        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.total_flos is not None:
            if self.args.local_rank != -1:
                total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
            else:
                total_flos = self.total_flos
            if total_flos > 0:
                logs["total_flos"] = total_flos
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_process_zero():
                wandb.log(logs, step=self.global_step)
        if is_comet_available():
            if self.is_world_process_zero():
                experiment = comet_ml.config.get_global_experiment()
                if experiment is not None:
                    experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
        output = {**logs, **{"step": self.global_step}}
        if self.is_world_process_zero():
            self.log_history.append(output)
        if iterator is not None:
            iterator.write(output)
        else:
            print(output)

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.
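
        Example (an illustrative override, not part of the original documentation; it keeps the default behavior and
        only records the raw per-step loss on a hypothetical attribute)::

            class StepLossTrainer(Trainer):
                def training_step(self, model, inputs):
                    loss = super().training_step(model, inputs)
                    self.last_step_loss = loss.item()  # hypothetical bookkeeping attribute
                    return loss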
        """
        if hasattr(self, "_training_step"):
            warnings.warn(
                "The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
                FutureWarning,
            )
            return self._training_step(model, inputs, self.optimizer)

        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.args.fp16 and _use_native_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16 and _use_native_amp:
            self.scaler.scale(loss).backward()
        elif self.args.fp16 and _use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.detach()

    def compute_loss(self, model, inputs):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
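
        Example (a sketch of a weighted classification loss, not part of the original documentation;
        ``self.class_weights`` is a hypothetical tensor you would set yourself, and the model is assumed to return
        logits first when ``labels`` are not passed)::

            class WeightedLossTrainer(Trainer):
                def compute_loss(self, model, inputs):
                    labels = inputs.pop("labels")
                    logits = model(**inputs)[0]
                    loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
                    return loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))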
        """
        outputs = model(**inputs)
        # Save past state if it exists
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        return outputs[0]

    def is_local_master(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
        several machines) main process.

        .. warning::

            This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
        """
        warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning)
        return self.is_local_process_zero()

    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
        several machines) main process.
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

    def is_world_master(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).

        .. warning::

            This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
        """
        warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning)
        return self.is_world_process_zero()

    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).
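
        Example (illustrative; assumes ``metrics`` came from an earlier :obj:`evaluate` call and should only be
        printed once per distributed job)::

            if trainer.is_world_process_zero():
                print(metrics)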
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        Will only save from the world_master process (unless in TPUs).
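
        Example (illustrative, not from the original documentation; reloading assumes the saved model is a
        :class:`~transformers.PreTrainedModel`, and ``AutoModelForSequenceClassification`` is just one possible
        class to reload it with)::

            trainer.save_model("./my-finetuned-model")
            model = AutoModelForSequenceClassification.from_pretrained("./my-finetuned-model")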
        """

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_process_zero():
            self._save(output_dir)

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
            json.dump(
                self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
            )

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if not isinstance(self.model, PreTrainedModel):
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            state_dict = self.model.state_dict()
            xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            state_dict = self.model.state_dict()
            torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        json.dump(
            self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
        )

    def store_flos(self):
        # Storing the number of floating-point operations that went into the model
        if self.total_flos is not None:
            if self.args.local_rank != -1:
                self.state.total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
            else:
                self.state.total_flos = self.total_flos

    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        # Make sure we don't delete the best model.
        if self.state.best_model_checkpoint is not None:
            best_model_index = checkpoints_sorted.index(self.state.best_model_checkpoint)
            checkpoints_sorted[best_model_index], checkpoints_sorted[-1] = (
                checkpoints_sorted[-1],
                checkpoints_sorted[best_model_index],
            )
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Run evaluation and return the metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is a :obj:`datasets.Dataset`,
                columns not accepted by the ``model.forward()`` method are automatically removed.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
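
        Example (illustrative; assumes the trainer was built with an ``eval_dataset`` and, optionally, a
        ``compute_metrics`` function)::

            metrics = trainer.evaluate()
            print(metrics["eval_loss"])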
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self.prediction_loop(eval_dataloader, description="Evaluation")

        self.log(output.metrics)

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset to run the predictions on. If it is a :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed.

        Returns:
            `NamedTuple`:
            predictions (:obj:`np.ndarray`):
                The predictions on :obj:`test_dataset`.
            label_ids (:obj:`np.ndarray`, `optional`):
                The labels (if the dataset contained some).
            metrics (:obj:`Dict[str, float]`, `optional`):
                The potential dictionary of metrics (if the dataset contained labels).
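
        Example (illustrative; ``test_dataset`` is assumed to be preprocessed the same way as the training data, and
        the argmax is only meaningful for classification models)::

            output = trainer.predict(test_dataset)
            predicted_labels = np.argmax(output.predictions, axis=-1)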
        """
        test_dataloader = self.get_test_dataloader(test_dataset)

        return self.prediction_loop(test_dataloader, description="Prediction")

    def prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if hasattr(self, "_prediction_loop"):
            warnings.warn(
                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
                FutureWarning,
            )
            return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)

        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: torch.Tensor = None
        label_ids: torch.Tensor = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
        for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
            batch_size = inputs[list(inputs.keys())[0]].shape[0]
            if loss is not None:
                eval_losses.extend([loss] * batch_size)
            if logits is not None:
                preds = logits if preds is None else nested_concat(preds, logits, dim=0)
            if labels is not None:
                label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = nested_xla_mesh_reduce(preds, "eval_preds")
            if label_ids is not None:
                label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
            if eval_losses is not None:
                eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = nested_numpify(preds)
        if label_ids is not None:
            label_ids = nested_numpify(label_ids)

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            if self.args.local_rank != -1:
                metrics["eval_loss"] = (
                    distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
                    .mean()
                    .item()
                )
            else:
                metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            A tuple with the loss, logits and labels (each being optional).
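
        Example (an illustrative override, not part of the original documentation; it keeps only the first logits
        tensor when the model returns several outputs)::

            class MainLogitsTrainer(Trainer):
                def prediction_step(self, model, inputs, prediction_loss_only):
                    loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only)
                    if isinstance(logits, tuple):
                        logits = logits[0]
                    return loss, logits, labels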
        """
        has_labels = all(inputs.get(k) is not None for k in self.args.label_names)
        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            outputs = model(**inputs)
            if has_labels:
                # The .mean() is to reduce in case of distributed training
                loss = outputs[0].mean().item()
                logits = outputs[1:]
            else:
                loss = None
                # Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
                logits = outputs[:]
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = tuple(logit.detach() for logit in logits)
        if len(logits) == 1:
            logits = logits[0]

        if has_labels:
            labels = tuple(inputs.get(name).detach() for name in self.args.label_names)
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        return (loss, logits, labels)

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        For models that inherit from :class:`~transformers.PreTrainedModel`, uses the model's own
        :obj:`floating_point_ops` method to compute the number of floating point operations for every backward +
        forward pass. If using another model, either implement such a method in the model or subclass and override
        this method.

        Args:
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            :obj:`int`: The number of floating-point operations.
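
        Example (a sketch of providing the hook on a custom model, not part of the original documentation; the
        ``6 * parameters * tokens`` estimate is a rough approximation, not an exact count)::

            class MyModel(nn.Module):
                def floating_point_ops(self, inputs):
                    num_params = sum(p.numel() for p in self.parameters())
                    return 6 * num_params * inputs["input_ids"].numel()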
        """

        model = self._actual_model(self.model)

        if hasattr(model, "floating_point_ops"):
            return model.floating_point_ops(inputs)
        else:
            return 0

    @staticmethod
    def _actual_model(
        model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
    ) -> torch.nn.modules.Module:
        """

        Args:
            model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
                Model object used during training

        Returns:
            :obj:`torch.nn.modules.Module`: unwrapped module
        """
        if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            model = model.module
        return model