import inspect
import math
import os
import re
import shutil
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange

from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
from .integrations import (
    default_hp_search_backend,
    is_comet_available,
    is_optuna_available,
    is_ray_available,
    is_tensorboard_available,
    is_wandb_available,
    run_hp_search_optuna,
    run_hp_search_ray,
)
from .modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalPrediction,
    EvaluationStrategy,
    HPSearchBackend,
    PredictionOutput,
    TrainerState,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    distributed_broadcast_scalars,
    distributed_concat,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_xla_mesh_reduce,
    set_seed,
)
from .training_args import TrainingArguments
from .utils import logging


_use_native_amp = False
_use_apex = False

PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimizer when saving or loading the scheduler."

# Check if PyTorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
    from .file_utils import is_apex_available

    if is_apex_available():
        from apex import amp
    _use_apex = True
else:
    _use_native_amp = True
    from torch.cuda.amp import autocast

if is_datasets_available():
    import datasets

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_tensorboard_available():
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        from tensorboardX import SummaryWriter

if is_wandb_available():
    import wandb

if is_comet_available():
    import comet_ml

if is_optuna_available():
    import optuna

if is_ray_available():
    from ray import tune

logger = logging.get_logger(__name__)


def reissue_pt_warnings(caught_warnings):
    # Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
    if len(caught_warnings) > 1:
        for w in caught_warnings:
            if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
                warnings.warn(w.message, w.category)


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager to make all processes in distributed training wait for the local master to do something first.

    Args:
        local_rank (:obj:`int`): The rank of the local process.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()
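
# Illustrative usage sketch for the context manager above (``args`` and
# ``load_and_cache_examples`` are hypothetical placeholders): the local master
# prepares/caches the data while the other ranks wait, then the other ranks read
# it from the cache once the barrier is released.
#
#     with torch_distributed_zero_first(args.local_rank):
#         dataset = load_and_cache_examples(args)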


class SequentialDistributedSampler(Sampler):
    """
    Distributed Sampler that subsamples indices sequentially, making it easier to collate
    all results at the end.

    Even though we only use this sampler for eval and predict (no training), which means
    that the model params won't have to be synced (i.e. they will not hang waiting for
    synchronization even with a varying number of forward passes), we still add extra
    samples to the sampler to make it evenly divisible (like in `DistributedSampler`) so
    that it is easy to `gather` or `reduce` the resulting tensors at the end of the loop.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert (
            len(indices) == self.total_size
        ), f"Indices length {len(indices)} and total size {self.total_size} mismatched"

        # subsample
        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
        assert (
            len(indices) == self.num_samples
        ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"

        return iter(indices)

    def __len__(self):
        return self.num_samples
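
# Worked example (illustrative): with len(dataset) == 10 and num_replicas == 4,
# num_samples == ceil(10 / 4) == 3 and total_size == 12. The index list [0..9] is
# padded by wrapping around to [0, 1, ..., 9, 0, 1]; rank 0 then iterates [0, 1, 2],
# rank 1 [3, 4, 5], rank 2 [6, 7, 8] and rank 3 [9, 0, 1]. The duplicated tail
# entries are meant to be truncated after gathering results across ranks.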


def get_tpu_sampler(dataset: Dataset):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for 🤗 Transformers.

    Args:
        model (:class:`~transformers.PreTrainedModel`, `optional`):
            The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
        args (:class:`~transformers.TrainingArguments`, `optional`):
            The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
            with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
        data_collator (:obj:`DataCollator`, `optional`):
            The function to use to form a batch from a list of elements of :obj:`train_dataset` or
            :obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is
            provided, an instance of :func:`~transformers.DataCollatorWithPadding` otherwise.
        train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
        tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to
            the maximum length when batching inputs, and it will be saved along with the model to make it easier to
            rerun an interrupted training or reuse the fine-tuned model.
        model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
            A function that instantiates the model to be used. If provided, each call to
            :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
        compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take a
            :class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
        tb_writer (:obj:`SummaryWriter`, `optional`):
            Object to write to TensorBoard.
        optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, `optional`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of
            :class:`~transformers.AdamW` on your model and a scheduler given by
            :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
        kwargs:
            Deprecated keyword arguments.
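
    Example::

        # A minimal usage sketch (illustrative only; ``my_model``, ``training_args``,
        # ``train_ds`` and ``eval_ds`` are placeholders you must define yourself).
        trainer = Trainer(
            model=my_model,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=eval_ds,
        )
        trainer.train()
        metrics = trainer.evaluate()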
    """

    def __init__(
        self,
        model: PreTrainedModel = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        **kwargs,
    ):
        if args is None:
            logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.")
            args = TrainingArguments("tmp_trainer")
        self.args = args
        # Seed must be set before instantiating the model when using model_init.
        set_seed(self.args.seed)
        assert (
            model is not None or model_init is not None
        ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
        if model is None and model_init is not None:
            model = model_init()
        self.model = model.to(args.device) if model is not None else None
        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer
        self.model_init = model_init
        self.compute_metrics = compute_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument."
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        self.tb_writer = tb_writer
        if "prediction_loss_only" in kwargs:
            warnings.warn(
                "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a "
                + "future version. Use `args.prediction_loss_only` instead. Setting "
                + f"`args.prediction_loss_only={kwargs['prediction_loss_only']}",
                FutureWarning,
            )
            self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create output directory if needed
        if self.is_world_process_zero():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel):
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant way and not need to do this in the future.
            self.model.config.xla_device = True
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            self.data_collator = self.data_collator.collate_batch
            warnings.warn(
                (
                    "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
                    + "with a `collate_batch` are deprecated and won't be supported in a future version."
                ),
                FutureWarning,
            )

        if is_datasets_available():
            if isinstance(train_dataset, datasets.Dataset):
                self._remove_unused_columns(self.train_dataset, description="training")
            if isinstance(eval_dataset, datasets.Dataset):
                self._remove_unused_columns(self.eval_dataset, description="evaluation")

        self.state = TrainerState()
        # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
        # state at each call to self.log.
        self._total_flos = None
        if self.args.fp16 and _use_native_amp:
            self.scaler = torch.cuda.amp.GradScaler()
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
        default_label_names = (
            ["start_positions, end_positions"]
            if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
            else ["labels"]
        )
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names

    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return
        # Inspect model forward signature to keep only the arguments it accepts.
        signature = inspect.signature(self.model.forward)
        signature_columns = list(signature.parameters.keys())
        # Labels may be named label or label_ids, the default data collator handles that.
        signature_columns += ["label", "label_ids"]
        columns = [k for k in signature_columns if k in dataset.column_names]
        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        dset_description = "" if description is None else f"in the {description} set "
        logger.info(
            f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
        )
        dataset.set_format(type=dataset.format["type"], columns=columns)
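
    # Illustrative example of the filtering above (the column names are hypothetical):
    # for a dataset with columns ["input_ids", "attention_mask", "label", "idx"] and a
    # model whose ``forward`` accepts ``input_ids``/``attention_mask``/``labels``, only
    # "idx" has no matching argument and is therefore dropped from the dataset format.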

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return get_tpu_sampler(self.train_dataset)
        else:
            return (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
        (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = self._get_train_sampler()

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
        )

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            return SequentialDistributedSampler(eval_dataset)
        else:
            return SequentialSampler(eval_dataset)

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
        sampler (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
                accepted by the ``model.forward()`` method are automatically removed.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        elif eval_dataset is not None and is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            self._remove_unused_columns(eval_dataset, description="evaluation")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
        sampler (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (:obj:`torch.utils.data.dataset.Dataset`):
                The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed.
        """
        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            self._remove_unused_columns(test_dataset, description="test")
        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            self.optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
            )
        if self.lr_scheduler is None:
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
            )
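
    # Illustrative alternative (a sketch, not the only option): instead of relying on
    # these defaults, a custom optimizer/scheduler pair can be passed at init time via
    # the ``optimizers`` argument (``model``, ``args`` and the step counts are placeholders):
    #
    #     optimizer = AdamW(model.parameters(), lr=5e-5)
    #     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=1000)
    #     trainer = Trainer(model=model, args=args, optimizers=(optimizer, scheduler))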

    def setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
                or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
        """
        if hasattr(self, "_setup_wandb"):
            warnings.warn(
                "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
                FutureWarning,
            )
            return self._setup_wandb()

        if self.is_world_process_zero():
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**self.args.to_sanitized_dict()}
            if isinstance(self.model, PreTrainedModel):
                combined_dict = {**self.model.config.to_dict(), **combined_dict}
            wandb.init(
                project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
            )
            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                wandb.watch(
                    self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
                )
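
    # Illustrative: the integration is typically configured through the environment
    # variables documented above, set before the Trainer is created, e.g.:
    #
    #     os.environ["WANDB_PROJECT"] = "my-project"  # hypothetical project name
    #     os.environ["WANDB_WATCH"] = "false"         # disable gradient logging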

    def setup_comet(self):
        """
        Setup the optional Comet.ml integration.

        Environment:
            COMET_MODE:
                (Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
            COMET_PROJECT_NAME:
                (Optional): str - Comet.ml project name for experiments
            COMET_OFFLINE_DIRECTORY:
                (Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"

        For a number of configurable items in the environment,
        see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
        """
        if self.is_world_master():
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            experiment = None
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**args)
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**args)
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            if experiment is not None:
                experiment._set_model_graph(self.model, framework="transformers")
                experiment._log_parameters(self.args, prefix="args/", framework="transformers")
                if isinstance(self.model, PreTrainedModel):
                    experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
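
    # Illustrative: as with W&B, Comet.ml is configured through the environment
    # variables documented above, e.g.:
    #
    #     os.environ["COMET_MODE"] = "OFFLINE"
    #     os.environ["COMET_OFFLINE_DIRECTORY"] = "./comet-logs"  # hypothetical path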

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
        """
        return len(dataloader.dataset)

    def _setup_loggers(self):
        if self._loggers_initialized:
            return
        if is_wandb_available():
            self.setup_wandb()
        elif os.environ.get("WANDB_DISABLED") != "true":
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
        if is_comet_available():
            self.setup_comet()
        elif os.environ.get("COMET_MODE") != "DISABLED":
            logger.info(
                "To use comet_ml logging, run `pip/conda install comet_ml` "
                "see https://www.comet.ml/docs/python-sdk/huggingface/"
            )
        self._loggers_initialized = True

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """ HP search setup code """
        if self.hp_search_backend is None or trial is None:
            return
        params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
        for key, value in params.items():
            if not hasattr(self.args, key):
                raise AttributeError(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
                )
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info("Trial:", trial.params)

    def _report_to_hp_search(
        self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
    ):
        if self.hp_search_backend is None or trial is None:
            return
        self.objective = self.compute_objective(metrics)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            trial.report(self.objective, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            if self.state.global_step % self.args.save_steps == 0:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

    def _tune_save_checkpoint(self):
        if not self.use_tune_checkpoints:
            return
        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
            self.args.output_dir = checkpoint_dir
            output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
            self.save_model(output_dir)
            if self.is_world_master():
                self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            set_seed(self.args.seed)
            model = self.model_init()
            self.model = model.to(self.args.device)

            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
        num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        if self.args.max_steps > 0:
            max_steps = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                self.args.max_steps % num_update_steps_per_epoch > 0
            )
        else:
            max_steps = int(num_update_steps_per_epoch * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs
        num_train_epochs = int(np.ceil(num_train_epochs))

        self.create_optimizer_and_scheduler(num_training_steps=max_steps)
        self.state = TrainerState()

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            self.optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            with warnings.catch_warnings(record=True) as caught_warnings:
                self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
            reissue_pt_warnings(caught_warnings)

        # Mixed precision training with apex (torch < 1.6)
        model = self.model
        if self.args.fp16 and _use_apex:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=(
                    not getattr(model.config, "gradient_checkpointing", False)
                    if isinstance(model, PreTrainedModel)
                    else True
                ),
            )
        # find_unused_parameters breaks checkpointing as per
        # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", max_steps)

        self.state.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        # Check if continuing training from a checkpoint
        if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
            self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", self.state.global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)

        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs

        tr_loss = torch.tensor(0.0).to(self.args.device)
        self._total_flos = self.state.total_flos
        logging_loss_scalar = 0.0
        model.zero_grad()
        disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
        train_pbar = trange(epochs_trained, num_train_epochs, desc="Epoch", disable=disable_tqdm)
        for epoch in range(epochs_trained, num_train_epochs):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    epoch_pbar.update(1)
                    continue

                tr_loss += self.training_step(model, inputs)
                self._total_flos += self.floating_point_ops(inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)
                    elif self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    self.lr_scheduler.step()
                    model.zero_grad()
                    self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
                        self.state.global_step == 1 and self.args.logging_first_step
                    ):
                        logs: Dict[str, float] = {}
                        tr_loss_scalar = tr_loss.item()
                        logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            self.lr_scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else self.lr_scheduler.get_lr()[0]
                        )
                        logging_loss_scalar = tr_loss_scalar

                        self.log(logs)

                    if (
                        self.args.evaluation_strategy == EvaluationStrategy.STEPS
                        and self.state.global_step % self.args.eval_steps == 0
                    ):
                        metrics = self.evaluate()
                        self._report_to_hp_search(trial, epoch, metrics)
                        if self.args.load_best_model_at_end:
                            self._save_training(model, trial, metrics=metrics)

                    if (
                        not self.args.load_best_model_at_end
                        and self.args.save_steps > 0
                        and self.state.global_step % self.args.save_steps == 0
                    ):
                        self._save_training(model, trial)

                epoch_pbar.update(1)
                if self.state.global_step >= max_steps:
                    break
            epoch_pbar.close()
            train_pbar.update(1)

            if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
                metrics = self.evaluate()
                self._report_to_hp_search(trial, epoch, metrics)
                if self.args.load_best_model_at_end:
                    self._save_training(model, trial, metrics=metrics)

            if self.args.tpu_metrics_debug or self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.state.global_step >= max_steps:
                break

        train_pbar.close()
        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            logger.info(
                f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
            )
            if isinstance(model, PreTrainedModel):
                self.model = model.from_pretrained(self.state.best_model_checkpoint)
                self.model = self.model.to(self.args.device)
            else:
                state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
                self.model.load_state_dict(state_dict)

        return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
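
    # Illustrative: to resume an interrupted run, point ``model_path`` at a checkpoint
    # directory containing the saved optimizer/scheduler states and trainer_state.json
    # (the path below is a placeholder):
    #
    #     trainer.train(model_path="output_dir/checkpoint-500")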

    def _save_training(self, model, trial, metrics=None):
        # In all cases (even distributed/parallel), self.model is always a reference
        # to the model we want to save.
        if hasattr(model, "module"):
            assert model.module is self.model, f"Module {model.module} should be a reference to self.model"
        else:
            assert model is self.model, f"Model {model} should be a reference to self.model"
        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        if self.hp_search_backend is not None and trial is not None:
            run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
            checkpoint_folder += f"-run-{run_id}"
        output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

        self.store_flos()
        self.save_model(output_dir)

        # Save optimizer and scheduler
        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                reissue_pt_warnings(caught_warnings)
        elif self.is_world_process_zero():
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
            reissue_pt_warnings(caught_warnings)

        # Determine the new best metric / best model checkpoint
        if metrics is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.is_world_process_zero():
            self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))

        # Maybe delete some older checkpoints.
        if self.is_world_process_zero():
            self._rotate_checkpoints(use_mtime=True)

    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        **kwargs
    ) -> BestRun:
        """
        Launch a hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
        :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no metric is
        provided, the sum of all metrics otherwise.

        .. warning::

            To use this method, you need to have provided a ``model_init`` when initializing your
            :class:`~transformers.Trainer`: we need to reinitialize the model at each new run. This is incompatible
            with the ``optimizers`` argument, so you need to subclass :class:`~transformers.Trainer` and override the
            method :meth:`~transformers.Trainer.create_optimizer_and_scheduler` for custom optimizer/scheduler.

        Args:
            hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
                A function that defines the hyperparameter search space. Will default to
                :func:`~transformers.trainer_utils.default_hp_space_optuna` or
                :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
            compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
                A function computing the objective to minimize or maximize from the metrics returned by the
                :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
            n_trials (:obj:`int`, `optional`, defaults to 20):
                The number of trial runs to test.
            direction (:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
                Whether to optimize greater or lower objectives. Can be :obj:`"minimize"` or :obj:`"maximize"`, you should
                pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
                several metrics.
            backend(:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
                one is installed. If both are installed, will default to optuna.
            kwargs:
                Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
                more information see:

                - the documentation of `optuna.create_study <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
                - the documentation of `tune.run <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__

        Returns:
            :class:`transformers.trainer_utils.BestRun`: All the information about the best run.
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray should be installed. "
                    "To install optuna run `pip install optuna`."
                    "To install ray run `pip install ray[tune]`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        self.hp_search_backend = backend

        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
        best_run = run_hp_search(self, n_trials, direction, **kwargs)

        self.hp_search_backend = None
        return best_run

    def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        """
        Log :obj:`logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (:obj:`Dict[str, float]`):
                The values to log.
            iterator (:obj:`tqdm`, `optional`):
                A potential tqdm progress bar to write the logs on.
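
        Example (a minimal sketch of a subclass forwarding the same logs to a custom sink; ``MyTrainer`` and
        ``my_sink`` are hypothetical names)::

            class MyTrainer(Trainer):
                def log(self, logs, iterator=None):
                    super().log(logs, iterator=iterator)
                    my_sink.record(logs)  # hypothetical custom logger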
        """
        # Set up loggers like W&B or Comet ML
        self._setup_loggers()

        if hasattr(self, "_log"):
            warnings.warn(
                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
                FutureWarning,
            )
            return self._log(logs, iterator=iterator)

        if self.state.epoch is not None:
            logs["epoch"] = self.state.epoch
        if self._total_flos is not None:
            self.store_flos()
            logs["total_flos"] = self.state.total_flos
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_process_zero():
                wandb.log(logs, step=self.state.global_step)
        if is_comet_available():
            if self.is_world_process_zero():
                experiment = comet_ml.config.get_global_experiment()
                if experiment is not None:
                    experiment._log_metrics(
                        logs, step=self.state.global_step, epoch=self.state.epoch, framework="transformers"
                    )
        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)
        if iterator is not None:
            iterator.write(str(output))  # tqdm.write expects a string
        else:
            print(output)

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare :obj:`inputs` before feeding them to the model, moving tensors to the right device and handling
        potential state (such as the cached ``mems`` when :obj:`args.past_index` is used).
        """
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.
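
        Example (a minimal sketch of an override that reuses the default behavior; ``MyTrainer`` is a hypothetical
        subclass)::

            class MyTrainer(Trainer):
                def training_step(self, model, inputs):
                    loss = super().training_step(model, inputs)
                    # the returned tensor is already detached; inspect or log it here if needed
                    return loss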
        """
        if hasattr(self, "_training_step"):
            warnings.warn(
                "The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
                FutureWarning,
            )
            return self._training_step(model, inputs, self.optimizer)

        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.args.fp16 and _use_native_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16 and _use_native_amp:
            self.scaler.scale(loss).backward()
        elif self.args.fp16 and _use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.detach()

    def compute_loss(self, model, inputs):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
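
        Example (a minimal sketch of an override computing a plain cross entropy instead of the model's built-in
        loss; assumes a classification model that returns logits as its first output when no labels are passed)::

            class MyTrainer(Trainer):
                def compute_loss(self, model, inputs):
                    labels = inputs.pop("labels")
                    logits = model(**inputs)[0]
                    loss_fct = torch.nn.CrossEntropyLoss()
                    return loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))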
        """
        outputs = model(**inputs)
        # Save past state if it exists
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        return outputs[0]

    def is_local_master(self) -> bool:
        """
        Whether or not this process is the local main process (e.g., the main process on its machine when training
        in a distributed fashion on several machines).

        .. warning::

            This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
        """
        warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning)
        return self.is_local_process_zero()

    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local main process (e.g., the main process on its machine when training
        in a distributed fashion on several machines).
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

    def is_world_master(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).

        .. warning::

            This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
        """
        warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning)
        return self.is_world_process_zero()

    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).
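
        Example (a minimal sketch guarding a side effect so it only runs once per training job; ``trainer`` is an
        existing :class:`~transformers.Trainer` instance)::

            if trainer.is_world_process_zero():
                print("only printed by the main process")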
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        Will only save from the main process (world process zero), unless running on TPUs.
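
        Example (a minimal sketch, assuming ``trainer.model`` is a :class:`~transformers.PreTrainedModel` and
        ``./my_model`` is an arbitrary output directory)::

            trainer.save_model("./my_model")
            # later, reload with the same model class
            model = type(trainer.model).from_pretrained("./my_model")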
        """

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_process_zero():
            self._save(output_dir)

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if not isinstance(self.model, PreTrainedModel):
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            state_dict = self.model.state_dict()
            xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            state_dict = self.model.state_dict()
            torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def store_flos(self):
        # Storing the number of floating-point operations that went into the model
        if self._total_flos is not None:
            if self.args.local_rank != -1:
                self.state.total_flos = distributed_broadcast_scalars([self._total_flos]).sum().item()
            else:
                self.state.total_flos = self._total_flos

    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        # Make sure we don't delete the best model.
        if self.state.best_model_checkpoint is not None:
            best_model_index = checkpoints_sorted.index(self.state.best_model_checkpoint)
            checkpoints_sorted[best_model_index], checkpoints_sorted[-1] = (
                checkpoints_sorted[-1],
                checkpoints_sorted[best_model_index],
            )
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is a :obj:`datasets.Dataset`,
                columns not accepted by the ``model.forward()`` method are automatically removed.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
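
        Example (a minimal sketch; ``trainer`` is an existing :class:`~transformers.Trainer` instance and, when the
        evaluation dataset provides labels, the returned dictionary contains an ``eval_loss`` entry)::

            metrics = trainer.evaluate()
            print(metrics["eval_loss"])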
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self.prediction_loop(eval_dataloader, description="Evaluation")

        self.log(output.metrics)

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset to run the predictions on. If it is a :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed.

        Returns:
            `NamedTuple`:
            predictions (:obj:`np.ndarray`):
                The predictions on :obj:`test_dataset`.
            label_ids (:obj:`np.ndarray`, `optional`):
                The labels (if the dataset contained some).
            metrics (:obj:`Dict[str, float]`, `optional`):
                The potential dictionary of metrics (if the dataset contained labels).
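
        Example (a minimal sketch; ``trainer`` is an existing :class:`~transformers.Trainer` instance and
        ``test_dataset`` is any dataset accepted by the trainer)::

            output = trainer.predict(test_dataset)
            predictions = output.predictions
            if output.metrics is not None:
                print(output.metrics)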
        """
        test_dataloader = self.get_test_dataloader(test_dataset)

        return self.prediction_loop(test_dataloader, description="Prediction")

    def prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if hasattr(self, "_prediction_loop"):
            warnings.warn(
                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
                FutureWarning,
            )
            return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)

        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: Optional[torch.Tensor] = None
        label_ids: Optional[torch.Tensor] = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
        for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
            batch_size = inputs[list(inputs.keys())[0]].shape[0]
            if loss is not None:
                eval_losses.extend([loss] * batch_size)
            if logits is not None:
                preds = logits if preds is None else nested_concat(preds, logits, dim=0)
            if labels is not None:
                label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = nested_xla_mesh_reduce(preds, "eval_preds")
            if label_ids is not None:
                label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
            if eval_losses is not None:
                eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = nested_numpify(preds)
        if label_ids is not None:
            label_ids = nested_numpify(label_ids)

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            if self.args.local_rank != -1:
                metrics["eval_loss"] = (
                    distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
                    .mean()
                    .item()
                )
            else:
                metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            A tuple with the loss, logits and labels (each being optional).
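
        Example (a minimal sketch of an override that moves logits to the CPU right after they are computed;
        ``MyTrainer`` is a hypothetical subclass)::

            class MyTrainer(Trainer):
                def prediction_step(self, model, inputs, prediction_loss_only):
                    loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only)
                    if isinstance(logits, torch.Tensor):
                        logits = logits.cpu()
                    elif logits is not None:
                        logits = tuple(l.cpu() for l in logits)
                    return loss, logits, labels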
        """
        has_labels = all(inputs.get(k) is not None for k in self.label_names)
        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            outputs = model(**inputs)
            if has_labels:
                # The .mean() is to reduce in case of distributed training
                loss = outputs[0].mean().item()
                logits = outputs[1:]
            else:
                loss = None
                # Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
                logits = outputs[:]
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
                # Remove the past from the logits.
                logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        if has_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        return (loss, logits, labels)

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        For models that inherit from :class:`~transformers.PreTrainedModel`, uses the model's own
        :obj:`floating_point_ops` method to compute the number of floating point operations for every backward +
        forward pass. If using another model, either implement such a method in the model or subclass and override
        this method.

        Args:
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            :obj:`int`: The number of floating-point operations.
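
        Example (a minimal sketch of an override for a model without its own :obj:`floating_point_ops` method; the
        ``6 * num_parameters * num_tokens`` heuristic and the presence of ``input_ids`` in the inputs are
        assumptions)::

            class MyTrainer(Trainer):
                def floating_point_ops(self, inputs):
                    num_parameters = sum(p.numel() for p in self.model.parameters())
                    return 6 * num_parameters * inputs["input_ids"].numel()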
        """

        model = self._actual_model(self.model)

        if hasattr(model, "floating_point_ops"):
            return model.floating_point_ops(inputs)
        else:
            return 0

    @staticmethod
    def _actual_model(
        model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
    ) -> torch.nn.modules.Module:
        """

        Args:
            model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
                Model object used during training

        Returns:
            :obj:`torch.nn.modules.Module`: unwrapped module
        """
        if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            model = model.module
        return model