trainer.py 99 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 馃 Transformers from scratch or finetune it on a new task.
"""

19
import collections
20
import inspect
21
import math
Julien Chaumond's avatar
Julien Chaumond committed
22
23
24
import os
import re
import shutil
25
import sys
26
import time
27
import warnings
28
from logging import StreamHandler
Julien Chaumond's avatar
Julien Chaumond committed
29
from pathlib import Path
30
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
Julien Chaumond's avatar
Julien Chaumond committed
31
32


33
34
# Integrations must be imported before ML frameworks:
from .integrations import (  # isort: split
35
    default_hp_search_backend,
36
    get_reporting_integration_callbacks,
37
    hp_params,
38
    is_fairscale_available,
39
    is_optuna_available,
40
    is_ray_tune_available,
41
42
    run_hp_search_optuna,
    run_hp_search_ray,
43
44
    deepspeed_init,
    is_deepspeed_zero3_enabled,
45
)
46
47
48
49
50
51
52
53
54
55
56

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler

from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
57
from .dependency_versions_check import dep_version_check
Sylvain Gugger's avatar
Sylvain Gugger committed
58
59
60
61
62
from .file_utils import (
    WEIGHTS_NAME,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
Sylvain Gugger's avatar
Sylvain Gugger committed
63
64
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
Sylvain Gugger's avatar
Sylvain Gugger committed
65
    is_torch_tpu_available,
66
    is_training_run_on_sagemaker,
Sylvain Gugger's avatar
Sylvain Gugger committed
67
)
68
from .modeling_utils import PreTrainedModel, unwrap_model
Sylvain Gugger's avatar
Sylvain Gugger committed
69
from .optimization import Adafactor, AdamW, get_scheduler
70
from .tokenization_utils_base import PreTrainedTokenizerBase
Sylvain Gugger's avatar
Sylvain Gugger committed
71
72
73
74
75
76
77
78
79
80
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
81
    DistributedLengthGroupedSampler,
82
    DistributedSamplerWithLoop,
83
    DistributedTensorGatherer,
84
    IterableDatasetShard,
Sylvain Gugger's avatar
Sylvain Gugger committed
85
    LabelSmoother,
86
    LengthGroupedSampler,
Sylvain Gugger's avatar
Sylvain Gugger committed
87
88
89
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
    distributed_concat,
90
    get_parameter_names,
Sylvain Gugger's avatar
Sylvain Gugger committed
91
92
93
94
95
96
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
97
98
99
100
101
102
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalPrediction,
    HPSearchBackend,
    PredictionOutput,
103
    ShardedDDPOption,
104
    TrainerMemoryTracker,
105
106
107
    TrainOutput,
    default_compute_objective,
    default_hp_space,
108
    denumpify_detensorize,
109
    get_last_checkpoint,
110
    set_seed,
111
    speed_metrics,
112
)
113
from .training_args import ParallelMode, TrainingArguments
Lysandre Debut's avatar
Lysandre Debut committed
114
from .utils import logging
115
from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
Julien Chaumond's avatar
Julien Chaumond committed
116
117


118
# Flag flipped to True below when the installed torch (>= 1.6) ships native AMP (torch.cuda.amp).
_is_native_amp_available = False

DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

# Inside a notebook, swap the console progress callback for the notebook-aware one.
if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

if is_apex_available():
    from apex import amp

if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_native_amp_available = True
    from torch.cuda.amp import autocast

if is_datasets_available():
    import datasets

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_fairscale_available():
    dep_version_check("fairscale")
    import fairscale
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.nn.wrap import auto_wrap
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

# On SageMaker data-parallel runs, `dist` points at the smdistributed wrapper instead of torch.distributed.
if is_sagemaker_dp_enabled():
    import smdistributed.dataparallel.torch.distributed as dist
    from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
else:
    import torch.distributed as dist

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat

# SageMaker training jobs collect stdout, so mirror the library logs there.
if is_training_run_on_sagemaker():
    logging.add_handler(StreamHandler(sys.stdout))

# optuna is only needed for type annotations (hyperparameter-search trial objects).
if TYPE_CHECKING:
    import optuna

logger = logging.get_logger(__name__)
Julien Chaumond's avatar
Julien Chaumond committed
171
172
173
174


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.

    Args:
        model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`):
            The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.

            .. note::

                :class:`~transformers.Trainer` is optimized to work with the :class:`~transformers.PreTrainedModel`
                provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as
                they work the same way as the 🤗 Transformers models.
        args (:class:`~transformers.TrainingArguments`, `optional`):
            The arguments to tweak for training. Will default to a basic instance of
            :class:`~transformers.TrainingArguments` with the ``output_dir`` set to a directory named `tmp_trainer` in
            the current directory if not provided.
        data_collator (:obj:`DataCollator`, `optional`):
            The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
            Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
            :func:`~transformers.DataCollatorWithPadding` otherwise.
        train_dataset (:obj:`torch.utils.data.dataset.Dataset` or :obj:`torch.utils.data.dataset.IterableDataset`, `optional`):
            The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.

            Note that if it's a :obj:`torch.utils.data.dataset.IterableDataset` with some randomization and you are
            training in a distributed fashion, your iterable dataset should either use a internal attribute
            :obj:`generator` that is a :obj:`torch.Generator` for the randomization that must be identical on all
            processes (and the Trainer will manually set the seed of this :obj:`generator` at each epoch) or have a
            :obj:`set_epoch()` method that internally sets the seed of the RNGs used.
        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
        tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the
            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
            interrupted training or reuse the fine-tuned model.
        model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
            A function that instantiates the model to be used. If provided, each call to
            :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.

            The function may have zero argument, or a single one containing the optuna/Ray Tune trial object, to be
            able to choose different architectures according to hyper parameters (such as layer count, sizes of inner
            layers, dropout probabilities etc).
        compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take a
            :class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
        callbacks (List of :obj:`~transformers.TrainerCallback`, `optional`):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed in :doc:`here <callback>`.

            If you want to remove one of the default callbacks used, use the :meth:`Trainer.remove_callback` method.
        optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`): A tuple
            containing the optimizer and the scheduler to use. Will default to an instance of
            :class:`~transformers.AdamW` on your model and a scheduler given by
            :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.

    Important attributes:

        - **model** -- Always points to the core model. If using a transformers model, it will be a
          :class:`~transformers.PreTrainedModel` subclass.
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
          original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``,
          the inner model is wrapped in ``DeepSpeed`` and then again in ``torch.nn.DistributedDataParallel``. If the
          inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
          to :obj:`False` if model parallel or deepspeed is used, or if the default
          ``TrainingArguments.place_model_on_device`` is overridden to return :obj:`False` .
        - **is_in_train** -- Whether or not a model is currently running ``train`` (e.g. when ``evaluate`` is called
          while in ``train``)
    """

    # Metric/state helpers shared with other trainers live in trainer_pt_utils; importing them at class
    # level exposes them as ordinary methods on Trainer and its subclasses.
    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
249

Julien Chaumond's avatar
Julien Chaumond committed
250
251
    def __init__(
        self,
        model: Union[PreTrainedModel, torch.nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    ):
        """
        Set up trainer state: resolve ``args``/``model``, sharded-DDP mode, device placement, data collator,
        callbacks, mixed-precision backend and bookkeeping attributes. See the class docstring for the
        meaning of each argument. The statement order below is significant (seed before model creation,
        memory tracker first, placement decided before ``.to()``).
        """
        if args is None:
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
        self.args = args
        # Seed must be set before instantiating the model when using model
        set_seed(self.args.seed)
        self.hp_name = None
        self.deepspeed = None
        self.is_in_train = False

        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

        # force device and distributed setup init explicitly
        args._setup_devices

        # Exactly one of `model` / `model_init` should be supplied; `model_init` wins for HP search reruns.
        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. "
                    "`model_init` will overwrite your model when calling the `train` method. This will become a fatal error in the next release.",
                    FutureWarning,
                )
            self.model_init = model_init

        # Detect models already split across GPUs (naive model parallelism, e.g. GPT-2 `parallelize()`).
        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = False

        # Setup Sharded DDP training
        self.sharded_ddp = None
        if len(args.sharded_ddp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )

            if args.local_rank == -1:
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
                raise ImportError(
                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
                )
            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.SIMPLE
            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

        # one place to sort out whether to place the model on device or not
        # postpone switching model to cuda when:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        #    and we only use deepspeed for training at the moment
        # 3. full fp16 eval - since the model needs to be half'ed first
        # 4. Sharded DDP - same as MP
        self.place_model_on_device = args.place_model_on_device
        if (
            self.is_model_parallel
            or (args.deepspeed and args.do_train)
            or (args.fp16_full_eval and not args.do_train)
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
        ):
            self.place_model_on_device = False

        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer

        if self.place_model_on_device:
            model = model.to(args.device)

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
        if self.is_model_parallel:
            self.args._n_gpu = 1

        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
        self.model_wrapped = model
        self.model = model

        self.compute_metrics = compute_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument."
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create output directory if needed
        if self.is_world_process_zero():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")

        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

        # Enforce rules on using datasets with no __len__
        if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
            raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")

        # Cache of the model.forward argument names, filled lazily by _remove_unused_columns.
        self._signature_columns = None
        if is_datasets_available():
            if isinstance(train_dataset, datasets.Dataset):
                self._remove_unused_columns(self.train_dataset, description="training")
            if isinstance(eval_dataset, datasets.Dataset):
                self._remove_unused_columns(self.eval_dataset, description="evaluation")

        # Mixed precision setup
        self.use_apex = False
        self.use_amp = False
        self.fp16_backend = None

        if args.fp16:
            if args.fp16_backend == "auto":
                self.fp16_backend = "amp" if _is_native_amp_available else "apex"
            else:
                self.fp16_backend = args.fp16_backend
            logger.info(f"Using {self.fp16_backend} fp16 backend")

        if args.fp16 and not args.deepspeed:  # deepspeed manages its own fp16
            if self.fp16_backend == "amp":
                self.use_amp = True
                # Sharded DDP needs the fairscale scaler; plain AMP uses the torch one.
                self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler()
            else:
                if not is_apex_available():
                    raise ImportError(
                        "Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex."
                    )
                self.use_apex = True

        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

        self.state = TrainerState()
        self.control = TrainerControl()
        # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
        # state at each call to self.log.
        self._total_flos = None
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
        # Question-answering heads predict span boundaries instead of a single `labels` tensor.
        default_label_names = (
            ["start_positions", "end_positions"]
            if type(self.model).__name__ in MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.values()
            else ["labels"]
        )
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

        # very last
        self._memory_tracker.stop_and_update_metrics()

Sylvain Gugger's avatar
Sylvain Gugger committed
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
    def add_callback(self, callback):
        """
        Add a callback to the current list of :class:`~transformer.TrainerCallback`.

        Args:
           callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
               A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
               In the first case, will instantiate a member of that class.
        """
        self.callback_handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Remove a callback from the current list of :class:`~transformer.TrainerCallback` and returns it.

        If the callback is not found, returns :obj:`None` (and no error is raised).

        Args:
           callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
               A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
               In the first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            :class:`~transformer.TrainerCallback`: The callback removed, if found.
        """
        return self.callback_handler.pop_callback(callback)

    def remove_callback(self, callback):
        """
        Remove a callback from the current list of :class:`~transformer.TrainerCallback`.

        Args:
           callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
               A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
               In the first case, will remove the first member of that class found in the list of callbacks.
        """
        self.callback_handler.remove_callback(callback)
Julien Chaumond's avatar
Julien Chaumond committed
481

482
    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
483
484
        if not self.args.remove_unused_columns:
            return
485
486
487
488
489
490
491
492
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
            # Labels may be named label or label_ids, the default data collator handles that.
            self._signature_columns += ["label", "label_ids"]
        columns = [k for k in self._signature_columns if k in dataset.column_names]
        ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
493
494
495
496
497
498
        if len(ignored_columns) > 0:
            dset_description = "" if description is None else f"in the {description} set "
            logger.info(
                f"The following columns {dset_description} don't have a corresponding argument in "
                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
            )
499
500

        dataset.set_format(type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"])
501

502
    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
503
        if not isinstance(self.train_dataset, collections.abc.Sized):
504
            return None
505
506
507

        # Build the sampler.
        if self.args.group_by_length:
508
509
510
511
512
513
514
515
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
516
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
517
            if self.args.world_size <= 1:
518
                return LengthGroupedSampler(
519
                    self.train_dataset, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name
520
                )
521
522
            else:
                return DistributedLengthGroupedSampler(
523
524
                    self.train_dataset,
                    self.args.train_batch_size,
525
526
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
527
                    lengths=lengths,
528
                    model_input_name=model_input_name,
529
530
531
                )

        else:
532
            if self.args.world_size <= 1:
533
                return RandomSampler(self.train_dataset)
Sylvain Gugger's avatar
Sylvain Gugger committed
534
535
536
537
            elif (
                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
            ):
538
539
540
541
542
543
544
                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    batch_size=self.args.per_device_train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                )
545
            else:
546
547
548
                return DistributedSampler(
                    self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index
                )
549
550
551
552
553

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler (adapted
        to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        # Iterable datasets cannot use a sampler; shard them manually across processes instead.
        if isinstance(self.train_dataset, torch.utils.data.dataset.IterableDataset):
            train_dataset = self.train_dataset
            if self.args.world_size > 1:
                train_dataset = IterableDatasetShard(
                    train_dataset,
                    batch_size=self.args.train_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                train_dataset,
                batch_size=self.args.train_batch_size,
                collate_fn=self.data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        # Sized datasets go through a (possibly distributed) sampler.
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=self._get_train_sampler(),
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        """Pick the sequential sampler matching the current distributed setup."""
        # TPU: one sequential shard per XLA core.
        if is_torch_tpu_available():
            return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        # SageMaker model parallelism: shard over the data-parallel group.
        if is_sagemaker_mp_enabled():
            return SequentialDistributedSampler(
                eval_dataset,
                num_replicas=smp.dp_size(),
                rank=smp.dp_rank(),
                batch_size=self.args.per_device_eval_batch_size,
            )
        # Plain torch.distributed run.
        if self.args.local_rank != -1:
            return SequentialDistributedSampler(eval_dataset)
        # Single-process evaluation.
        return SequentialSampler(eval_dataset)

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation :class:`~torch.utils.data.DataLoader`.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
                accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
        """
        # Validate the override (if any) before falling back to the dataset given at init.
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        elif eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")
        elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            self._remove_unused_columns(eval_dataset, description="evaluation")
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test :class:`~torch.utils.data.DataLoader`.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
        """
        if not isinstance(test_dataset, collections.abc.Sized):
            raise ValueError("test_dataset must implement __len__")
        elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            self._remove_unused_columns(test_dataset, description="test")
        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            # Fix: honor `dataloader_num_workers` here too, consistent with the train/eval dataloaders
            # (previously the test dataloader silently ignored this argument).
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method (or :obj:`create_optimizer`
        and/or :obj:`create_scheduler`) in a subclass.

        Args:
            num_training_steps (int): The total number of optimization steps, forwarded to the scheduler setup.
        """
        # The optimizer must exist before the scheduler can be built on top of it.
        self.create_optimizer()
        self.create_scheduler(num_training_steps)

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
        """
        if self.optimizer is None:
            # Apply weight decay to everything except biases and LayerNorm weights.
            decay_parameters = get_parameter_names(self.model, [torch.nn.LayerNorm])
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
                    "weight_decay": 0.0,
                },
            ]
            # Fix: removed the dead `optimizer_cls = Adafactor if self.args.adafactor else AdamW`
            # assignment that was immediately overwritten by the branch below.
            if self.args.adafactor:
                optimizer_cls = Adafactor
                optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
            else:
                optimizer_cls = AdamW
                optimizer_kwargs = {
                    "betas": (self.args.adam_beta1, self.args.adam_beta2),
                    "eps": self.args.adam_epsilon,
                }
            optimizer_kwargs["lr"] = self.args.learning_rate
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                # Sharded DDP (fairscale) wraps the optimizer class in OSS.
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        # SageMaker model parallelism needs its own distributed optimizer wrapper,
        # even around an optimizer supplied by the caller.
        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

    def create_scheduler(self, num_training_steps: int):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up before this method is called.

        Args:
            num_training_steps (int): The number of training steps to do.
        """
        # Respect an externally supplied scheduler (passed through `optimizers` at init).
        if self.lr_scheduler is not None:
            return

        # An explicit `warmup_steps` wins; otherwise derive the warmup length from `warmup_ratio`.
        if self.args.warmup_steps > 0:
            warmup_steps = self.args.warmup_steps
        else:
            warmup_steps = math.ceil(num_training_steps * self.args.warmup_ratio)

        self.lr_scheduler = get_scheduler(
            self.args.lr_scheduler_type,
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
        )

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.

        Will raise an exception if the underlying dataset does not implement method :obj:`__len__`
        """
        dataset = dataloader.dataset
        return len(dataset)

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """ HP search setup code """
750
751
        self._trial = trial

752
753
        if self.hp_search_backend is None or trial is None:
            return
754
755
756
757
758
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            params = self.hp_space(trial)
        elif self.hp_search_backend == HPSearchBackend.RAY:
            params = trial
            params.pop("wandb", None)
759

760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
        for key, value in params.items():
            if not hasattr(self.args, key):
                raise AttributeError(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
                )
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info("Trial:", trial.params)

    def _report_to_hp_search(
        self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
    ):
        if self.hp_search_backend is None or trial is None:
            return
778
        self.objective = self.compute_objective(metrics.copy())
779
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
780
781
            import optuna

782
783
784
785
            trial.report(self.objective, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
786
787
            from ray import tune

788
            if self.control.should_save:
789
                self._tune_save_checkpoint()
790
791
            tune.report(objective=self.objective, **metrics)

792
    def _tune_save_checkpoint(self):
        """Save model, trainer state, optimizer and scheduler into a Ray Tune checkpoint directory."""
        from ray import tune

        if not self.use_tune_checkpoints:
            return
        with tune.checkpoint_dir(step=self.state.global_step) as ckpt_root:
            step_dir = os.path.join(ckpt_root, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
            self.save_model(step_dir)
            # Only the main process persists trainer/optimizer/scheduler state.
            if self.is_world_process_zero():
                self.state.save_to_json(os.path.join(step_dir, "trainer_state.json"))
                torch.save(self.optimizer.state_dict(), os.path.join(step_dir, "optimizer.pt"))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(step_dir, "scheduler.pt"))

    def call_model_init(self, trial=None):
        """
        Instantiate a fresh model via ``self.model_init``, passing ``trial`` only when the
        factory accepts one argument.

        Raises:
            RuntimeError: If ``model_init`` takes more than one argument or returns :obj:`None`.
        """
        argcount = len(inspect.signature(self.model_init).parameters)
        if argcount not in (0, 1):
            raise RuntimeError("model_init should have 0 or 1 argument.")
        model = self.model_init() if argcount == 0 else self.model_init(trial)
        if model is None:
            raise RuntimeError("model_init should not return None.")

        return model

    def _wrap_model(self, model, training=True):
        """
        Wrap ``model`` for the current hardware / distributed configuration.

        Checked in priority order: SageMaker model parallelism and DeepSpeed return their own
        wrappers immediately; otherwise the model may go through apex AMP, ``DataParallel`` and
        one of the DDP variants (fairscale sharded DDP, SageMaker data-parallel DDP, or plain
        ``torch.nn.parallel.DistributedDataParallel``). Note that this method can mutate
        ``self.optimizer`` (apex initialization) and ``self.model`` (FullyShardedDDP).

        Args:
            model: The model to wrap.
            training (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether the wrapped model will be used for training. When :obj:`False`, the
                distributed wrappers are skipped (evaluation runs under ``no_grad``).

        Returns:
            The (possibly) wrapped model.
        """
        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # already initialized its own DDP and AMP
        if self.deepspeed:
            return self.deepspeed

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)

        elif is_sagemaker_dp_enabled():
            model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
        elif self.args.local_rank != -1:
            if self.args.ddp_find_unused_parameters is not None:
                find_unused_parameters = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False)
            else:
                find_unused_parameters = True
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=find_unused_parameters,
            )

        return model

886
887
    def train(
        self,
888
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
889
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
890
        **kwargs,
891
    ):
Julien Chaumond's avatar
Julien Chaumond committed
892
893
894
895
        """
        Main training entry point.

        Args:
896
897
898
899
900
            resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`):
                If a :obj:`str`, local path to a saved checkpoint as saved by a previous instance of
                :class:`~transformers.Trainer`. If a :obj:`bool` and equals `True`, load the last checkpoint in
                `args.output_dir` as saved by a previous instance of :class:`~transformers.Trainer`. If present,
                training will resume from the model/optimizer/scheduler states loaded here.
901
902
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
903
904
            kwargs:
                Additional keyword arguments used to hide deprecated arguments
Julien Chaumond's avatar
Julien Chaumond committed
905
        """
906
907
908
909

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

910
911
        self.is_in_train = True

912
913
914
915
916
917
918
919
920
        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
                "instead.",
                FutureWarning,
            )
        if len(kwargs) > 0:
            raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
Sylvain Gugger's avatar
Sylvain Gugger committed
921
922
923
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

924
        # Model re-init
925
        model_reloaded = False
926
        if self.model_init is not None:
Sylvain Gugger's avatar
Sylvain Gugger committed
927
928
            # Seed must be set before instantiating the model when using model_init.
            set_seed(self.args.seed)
929
930
            self.model = self.call_model_init(trial)
            model_reloaded = True
Sylvain Gugger's avatar
Sylvain Gugger committed
931
932
            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None
933

934
        # Load potential model checkpoint
935
936
937
938
939
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")

940
941
942
943
        if resume_from_checkpoint is not None:
            if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
                raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

944
            logger.info(f"Loading model from {resume_from_checkpoint}).")
945
946

            if self.deepspeed:
947
                # will be resumed in deepspeed_init
948
949
                pass
            elif isinstance(self.model, PreTrainedModel):
950
                self.model = self.model.from_pretrained(resume_from_checkpoint)
951
952
                model_reloaded = True
            else:
953
                state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
954
955
956
957
                self.model.load_state_dict(state_dict)

        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
958
            if self.place_model_on_device:
959
960
961
                self.model = self.model.to(self.args.device)
            self.model_wrapped = self.model

962
963
964
        # Keeping track whether we can can len() on the dataset or not
        train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)

965
        # Data loader and number of training steps
Julien Chaumond's avatar
Julien Chaumond committed
966
        train_dataloader = self.get_train_dataloader()
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        if train_dataset_is_sized:
            num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            if self.args.max_steps > 0:
                max_steps = self.args.max_steps
                num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                    self.args.max_steps % num_update_steps_per_epoch > 0
                )
            else:
                max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(self.args.num_train_epochs)
Julien Chaumond's avatar
Julien Chaumond committed
983
        else:
984
985
986
987
            # see __init__. max_steps is set when the dataset has no __len__
            max_steps = self.args.max_steps
            num_train_epochs = 1
            num_update_steps_per_epoch = max_steps
Julien Chaumond's avatar
Julien Chaumond committed
988

989
        delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
990
        if self.args.deepspeed:
991
            deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
992
993
                self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
            )
994
995
996
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
997
998
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler
999
        elif not delay_optimizer_creation:
1000
1001
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

1002
        self.state = TrainerState()
1003
        self.state.is_hyper_param_search = trial is not None
Julien Chaumond's avatar
Julien Chaumond committed
1004

1005
        model = self._wrap_model(self.model_wrapped)
Julien Chaumond's avatar
Julien Chaumond committed
1006

1007
1008
1009
1010
        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

1011
1012
1013
        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

1014
1015
1016
        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

1017
1018
        # important: at this point:
        # self.model         is the Transformers Model
1019
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
1020

Julien Chaumond's avatar
Julien Chaumond committed
1021
        # Train!
1022
        if is_torch_tpu_available():
Sylvain Gugger's avatar
Sylvain Gugger committed
1023
1024
1025
            world_size = xm.xrt_world_size()
        elif self.args.local_rank != -1:
            world_size = dist.get_world_size()
Lysandre Debut's avatar
Lysandre Debut committed
1026
        else:
Sylvain Gugger's avatar
Sylvain Gugger committed
1027
            world_size = 1
1028

Sylvain Gugger's avatar
Sylvain Gugger committed
1029
        total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps * world_size
1030
1031
1032
1033
1034
1035
        num_examples = (
            self.num_examples(train_dataloader)
            if train_dataset_is_sized
            else total_train_batch_size * self.args.max_steps
        )

Julien Chaumond's avatar
Julien Chaumond committed
1036
        logger.info("***** Running training *****")
1037
1038
1039
1040
1041
1042
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
        logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps}")
Julien Chaumond's avatar
Julien Chaumond committed
1043

1044
        self.state.epoch = 0
1045
        start_time = time.time()
Julien Chaumond's avatar
Julien Chaumond committed
1046
1047
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
1048

Julien Chaumond's avatar
Julien Chaumond committed
1049
        # Check if continuing training from a checkpoint
1050
1051
1052
1053
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, "trainer_state.json")
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json"))
1054
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
1055
1056
1057
1058
1059
            if not self.args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0
1060
1061

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
1062
1063
1064
1065
1066
1067
1068
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not self.args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
                    "batches in the first epoch."
                )
1069

Sylvain Gugger's avatar
Sylvain Gugger committed
1070
1071
1072
1073
1074
        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
1075
1076
        self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
        self.state.trial_params = hp_params(trial) if trial is not None else None
1077
1078
1079
1080
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
Sylvain Gugger's avatar
Sylvain Gugger committed
1081
1082
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()
Julien Chaumond's avatar
Julien Chaumond committed
1083

1084
        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
1085
        tr_loss = torch.tensor(0.0).to(self.args.device)
1086
1087
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
1088
        self._globalstep_last_logged = self.state.global_step
1089
        self._total_flos = self.state.total_flos
Julien Chaumond's avatar
Julien Chaumond committed
1090
        model.zero_grad()
Sylvain Gugger's avatar
Sylvain Gugger committed
1091
1092
1093

        self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)

1094
1095
1096
1097
1098
1099
1100
        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not self.args.ignore_data_skip:
            for epoch in range(epochs_trained):
                # We just need to begin an iteration to create the randomization of the sampler.
                for _ in train_dataloader:
                    break

1101
        for epoch in range(epochs_trained, num_train_epochs):
1102
1103
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
1104
1105
            elif isinstance(train_dataloader.dataset, IterableDatasetShard):
                train_dataloader.dataset.set_epoch(epoch)
1106

1107
            if is_torch_tpu_available():
1108
1109
1110
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
1111
                epoch_iterator = parallel_loader
1112
            else:
1113
                epoch_iterator = train_dataloader
1114

1115
1116
1117
1118
            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

1119
1120
1121
1122
1123
            steps_in_epoch = (
                len(epoch_iterator)
                if train_dataset_is_sized
                else self.args.max_steps * self.args.gradient_accumulation_steps
            )
Sylvain Gugger's avatar
Sylvain Gugger committed
1124
1125
            self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)

Julien Chaumond's avatar
Julien Chaumond committed
1126
1127
1128
1129
1130
1131
1132
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

1133
                if step % self.args.gradient_accumulation_steps == 0:
Sylvain Gugger's avatar
Sylvain Gugger committed
1134
1135
                    self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)

1136
1137
1138
                if (
                    ((step + 1) % self.args.gradient_accumulation_steps != 0)
                    and self.args.local_rank != -1
Sylvain Gugger's avatar
Sylvain Gugger committed
1139
                    and self.args._no_sync_in_gradient_accumulation
1140
                ):
1141
                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
1142
1143
1144
1145
                    with model.no_sync():
                        tr_loss += self.training_step(model, inputs)
                else:
                    tr_loss += self.training_step(model, inputs)
1146
                self._total_flos += float(self.floating_point_ops(inputs))
Julien Chaumond's avatar
Julien Chaumond committed
1147

1148
1149
1150
1151
                # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
                if self.deepspeed:
                    self.deepspeed.step()

Julien Chaumond's avatar
Julien Chaumond committed
1152
1153
                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
1154
1155
                    steps_in_epoch <= self.args.gradient_accumulation_steps
                    and (step + 1) == steps_in_epoch
Julien Chaumond's avatar
Julien Chaumond committed
1156
                ):
1157
                    # Gradient clipping
1158
1159
1160
                    if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0 and not self.deepspeed:
                        # deepspeed does its own clipping

1161
1162
1163
1164
1165
1166
1167
                        if self.use_amp:
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)

                        if hasattr(self.optimizer, "clip_grad_norm"):
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                            self.optimizer.clip_grad_norm(self.args.max_grad_norm)
1168
1169
1170
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
                            model.clip_grad_norm_(self.args.max_grad_norm)
1171
1172
1173
1174
1175
1176
1177
1178
                        else:
                            # Revert to normal clipping otherwise, handling Apex or full precision
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
                                self.args.max_grad_norm,
                            )

                    # Optimizer step
1179
                    optimizer_was_run = True
Stas Bekman's avatar
Stas Bekman committed
1180
                    if self.deepspeed:
1181
                        pass  # called outside the loop
Stas Bekman's avatar
Stas Bekman committed
1182
                    elif is_torch_tpu_available():
1183
                        xm.optimizer_step(self.optimizer)
1184
                    elif self.use_amp:
1185
                        scale_before = self.scaler.get_scale()
1186
                        self.scaler.step(self.optimizer)
1187
                        self.scaler.update()
1188
1189
                        scale_after = self.scaler.get_scale()
                        optimizer_was_run = scale_before <= scale_after
Lysandre Debut's avatar
Lysandre Debut committed
1190
                    else:
1191
                        self.optimizer.step()
Lysandre Debut's avatar
Lysandre Debut committed
1192

1193
                    if optimizer_was_run and not self.deepspeed:
1194
1195
                        self.lr_scheduler.step()

Julien Chaumond's avatar
Julien Chaumond committed
1196
                    model.zero_grad()
1197
                    self.state.global_step += 1
1198
                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
Sylvain Gugger's avatar
Sylvain Gugger committed
1199
1200
                    self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)

1201
                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
Julien Chaumond's avatar
Julien Chaumond committed
1202

Sylvain Gugger's avatar
Sylvain Gugger committed
1203
                if self.control.should_epoch_stop or self.control.should_training_stop:
Julien Chaumond's avatar
Julien Chaumond committed
1204
                    break
1205

Sylvain Gugger's avatar
Sylvain Gugger committed
1206
            self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
1207
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
1208

1209
            if self.args.tpu_metrics_debug or self.args.debug:
1210
1211
1212
1213
1214
1215
1216
1217
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
Sylvain Gugger's avatar
Sylvain Gugger committed
1218
            if self.control.should_training_stop:
1219
                break
Julien Chaumond's avatar
Julien Chaumond committed
1220

1221
1222
1223
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")
Julien Chaumond's avatar
Julien Chaumond committed
1224
1225

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
1226
        if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
1227
1228
1229
1230
1231
1232
            # Wait for everyone to get here so we are sur the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
            elif self.args.local_rank != -1:
                dist.barrier()

1233
1234
1235
            logger.info(
                f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
            )
1236
1237
            if isinstance(self.model, PreTrainedModel):
                self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
1238
                if self.place_model_on_device:
1239
                    self.model = self.model.to(self.args.device)
1240
1241
1242
1243
            else:
                state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
                self.model.load_state_dict(state_dict)

1244
1245
1246
1247
1248
            if self.deepspeed:
                self.deepspeed.load_checkpoint(
                    self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
                )

1249
        metrics = speed_metrics("train", start_time, self.state.max_steps)
1250
1251
        if self._total_flos is not None:
            self.store_flos()
1252
1253
            metrics["total_flos"] = self.state.total_flos
        self.log(metrics)
1254

Sylvain Gugger's avatar
Sylvain Gugger committed
1255
        self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
1256
1257
        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
Sylvain Gugger's avatar
Sylvain Gugger committed
1258

1259
        self.is_in_train = False
1260

1261
1262
        self._memory_tracker.stop_and_update_metrics(metrics)

1263
        return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)
1264

1265
    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
Sylvain Gugger's avatar
Sylvain Gugger committed
1266
1267
1268
        if self.control.should_log:
            logs: Dict[str, float] = {}
            tr_loss_scalar = tr_loss.item()
1269
1270
1271
            # reset tr_loss to zero
            tr_loss -= tr_loss

1272
            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
1273
            logs["learning_rate"] = self._get_learning_rate()
1274

1275
            self._total_loss_scalar += tr_loss_scalar
1276
            self._globalstep_last_logged = self.state.global_step
Sylvain Gugger's avatar
Sylvain Gugger committed
1277
1278
1279
1280
1281
1282
1283

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            metrics = self.evaluate()
            self._report_to_hp_search(trial, epoch, metrics)
1284

Sylvain Gugger's avatar
Sylvain Gugger committed
1285
1286
1287
1288
1289
        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def _save_checkpoint(self, model, trial, metrics=None):
        """
        Save a full training checkpoint (model, optimizer, scheduler and Trainer state) under a
        ``checkpoint-<global_step>`` folder, and update the best-model bookkeeping.

        Args:
            model: The (possibly wrapped) model being trained; passed through to sharded code paths.
            trial: The current hyperparameter-search trial, if any (used to pick the run directory).
            metrics (`optional`): Latest evaluation metrics; when ``metric_for_best_model`` is set
                they are used to decide whether this checkpoint is the new best one.
        """
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        # During a hyperparameter search, each trial saves under its own run directory.
        if self.hp_search_backend is not None and trial is not None:
            if self.hp_search_backend == HPSearchBackend.OPTUNA:
                run_id = trial.number
            else:
                from ray import tune

                run_id = tune.get_trial_id()
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            run_dir = os.path.join(self.args.output_dir, run_name)
        else:
            run_dir = self.args.output_dir
            self.store_flos()

        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir)
        if self.deepspeed:
            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
            # config `stage3_gather_fp16_weights_on_model_save` is True
            self.deepspeed.save_checkpoint(output_dir)

        # Save optimizer and scheduler
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            # The sharded optimizer must gather its state on one process before serialization.
            self.optimizer.consolidate_state_dict()

        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            # Consolidate the state dict on all processes of dp_rank 0
            opt_state_dict = self.optimizer.state_dict()
            # Save it and the scheduler on the main process
            if self.is_world_process_zero():
                torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                reissue_pt_warnings(caught_warnings)
        elif self.is_world_process_zero() and not self.deepspeed:
            # deepspeed.save_checkpoint above saves model/optim/sched
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
            reissue_pt_warnings(caught_warnings)

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.is_world_process_zero():
            self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))

        # Maybe delete some older checkpoints.
        if self.is_world_process_zero():
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
    def _load_optimizer_and_scheduler(self, checkpoint):
        """If optimizer and scheduler states exist, load them.

        ``checkpoint`` is the path of the checkpoint folder to resume from; ``None`` means a
        fresh run and is a no-op. Under deepspeed the states are loaded together with the model
        elsewhere, so this is also a no-op.
        """
        if checkpoint is None:
            return

        if self.deepspeed:
            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
            return

        # Only resume when both files are present; a partial checkpoint is silently ignored.
        if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
            os.path.join(checkpoint, "scheduler.pt")
        ):
            # Load in optimizer and scheduler states
            if is_torch_tpu_available():
                # On TPU we have to take some extra precautions to properly load the states on the right device.
                optimizer_state = torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
                with warnings.catch_warnings(record=True) as caught_warnings:
                    lr_scheduler_state = torch.load(os.path.join(checkpoint, "scheduler.pt"), map_location="cpu")
                reissue_pt_warnings(caught_warnings)

                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

                self.optimizer.load_state_dict(optimizer_state)
                self.lr_scheduler.load_state_dict(lr_scheduler_state)
            else:
                map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
                self.optimizer.load_state_dict(
                    torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=map_location)
                )
                # NOTE(review): the scheduler load has no map_location, unlike the optimizer load
                # above — presumably scheduler states are device-agnostic; confirm.
                with warnings.catch_warnings(record=True) as caught_warnings:
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
                reissue_pt_warnings(caught_warnings)
    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
        **kwargs,
    ) -> BestRun:
        """
        Launch a hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
        :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no metric is
        provided, the sum of all metrics otherwise.

        .. warning::

            To use this method, you need to have provided a ``model_init`` when initializing your
            :class:`~transformers.Trainer`: we need to reinitialize the model at each new run. This is incompatible
            with the ``optimizers`` argument, so you need to subclass :class:`~transformers.Trainer` and override the
            method :meth:`~transformers.Trainer.create_optimizer_and_scheduler` for custom optimizer/scheduler.

        Args:
            hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
                A function that defines the hyperparameter search space. Will default to
                :func:`~transformers.trainer_utils.default_hp_space_optuna` or
                :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
            compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
                A function computing the objective to minimize or maximize from the metrics returned by the
                :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
            n_trials (:obj:`int`, `optional`, defaults to 20):
                The number of trial runs to test.
            direction(:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
                Whether to optimize greater or lower objects. Can be :obj:`"minimize"` or :obj:`"maximize"`, you should
                pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
                several metrics.
            backend(:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
                one is installed. If both are installed, will default to optuna.
            hp_name (:obj:`Callable[["optuna.Trial"], str]`, `optional`):
                A function that defines the run/trial name. Will default to ``run-<run_id>``.
            kwargs:
                Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
                more information see:

                - the documentation of `optuna.create_study
                  <https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html>`__
                - the documentation of `tune.run
                  <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__

        Returns:
            :class:`transformers.trainer_utils.BestRun`: All the information about the best run.
        """
        # Pick a backend automatically when none was requested, then validate the choice.
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray should be installed. "
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_tune_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        self.hp_search_backend = backend
        # The model is re-created for every trial, so a model_init factory is mandatory.
        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.hp_name = hp_name
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
        best_run = run_hp_search(self, n_trials, direction, **kwargs)

        # Reset so subsequent train() calls behave as normal (non-search) training.
        self.hp_search_backend = None
        return best_run
    def log(self, logs: Dict[str, float]) -> None:
        """
        Log :obj:`logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (:obj:`Dict[str, float]`):
                The values to log.
        """
        # Stamp the entry with the current epoch (rounded for readability) when one is set.
        epoch = self.state.epoch
        if epoch is not None:
            logs["epoch"] = round(epoch, 2)

        # Record a copy enriched with the global step in the state's history, then let the
        # registered callbacks react to the log event.
        entry = dict(logs, step=self.state.global_step)
        self.state.log_history.append(entry)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
1501
1502
1503
1504
        """
        Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
Julien Chaumond's avatar
Julien Chaumond committed
1505
        for k, v in inputs.items():
1506
1507
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)
Julien Chaumond's avatar
Julien Chaumond committed
1508

1509
1510
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past
1511

1512
1513
        return inputs

1514
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        # SageMaker model parallel runs forward AND backward itself and returns the reduced loss.
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        # Mixed-precision forward pass when native AMP is enabled.
        if self.use_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
            loss = loss / self.args.gradient_accumulation_steps

        # Backward pass: each mixed-precision/engine backend has its own entry point.
        if self.use_amp:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            # loss gets scaled under gradient_accumulation_steps in deepspeed
            loss = self.deepspeed.backward(loss)
        else:
            loss.backward()

        # Detach so callers accumulate a value, not a graph.
        return loss.detach()
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        # When label smoothing is active, pull the labels out so the model's own loss is
        # bypassed and the smoothed loss is computed instead.
        labels = None
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")

        outputs = model(**inputs)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is None:
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        else:
            loss = self.label_smoother(outputs, labels)

        if return_outputs:
            return loss, outputs
        return loss
    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
        machines) main process.
        """
        # Each backend exposes its own notion of "local main process".
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        if is_sagemaker_mp_enabled():
            return smp.local_rank() == 0
        # Plain PyTorch (possibly DDP): local_rank -1 means no distributed training at all.
        return self.args.local_rank in (-1, 0)
    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be :obj:`True` for one process).
        """
        # Each backend exposes its own notion of "global main process".
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        if is_sagemaker_mp_enabled():
            return smp.rank() == 0
        return self.args.process_index == 0
    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        Will only save from the main process.
        """

        if output_dir is None:
            output_dir = self.args.output_dir

        # Dispatch to the backend-specific save path; the order of these branches matters
        # (TPU and SageMaker MP take precedence over sharded DDP and deepspeed).
        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif is_sagemaker_mp_enabled():
            # Calling the state_dict needs to be done on the wrapped model and on all processes.
            state_dict = self.model_wrapped.state_dict()
            if self.is_world_process_zero():
                self._save(output_dir, state_dict=state_dict)
        elif (
            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
        ):
            state_dict = self.model.state_dict()

            if self.is_world_process_zero():
                self._save(output_dir, state_dict=state_dict)
        elif self.deepspeed:

            # this takes care of everything as long as we aren't under zero3
            if self.is_world_process_zero():
                self._save(output_dir)

            if is_deepspeed_zero3_enabled():
                # It's too complicated to try to override different places where the weights dump gets
                # saved, so since under zero3 the file is bogus, simply delete it. The user should
                # either use deepspeed checkpoint to resume or to recover full weights use
                # zero_to_fp32.py stored in the checkpoint.
                if self.is_world_process_zero():
                    file = os.path.join(output_dir, WEIGHTS_NAME)
                    if os.path.isfile(file):
                        # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
                        os.remove(file)

                # now save the real model if stage3_gather_fp16_weights_on_model_save=True
                # if false it will not be saved.
                # This must be called on all ranks
                self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)

        elif self.is_world_process_zero():
            self._save(output_dir)
    def _save_tpu(self, output_dir: Optional[str] = None):
        """
        TPU-specific model save: writes the training args on the master ordinal, then saves the
        model weights through ``xm.save`` after a rendezvous so all cores are synchronized.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info(f"Saving model checkpoint to {output_dir}")

        # Only the master ordinal creates the directory and stores the training arguments.
        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if not isinstance(self.model, PreTrainedModel):
            # The model may be wrapped (e.g. DDP); unwrap to reach the `save_pretrained` API.
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                unwrap_model(self.model).save_pretrained(
                    output_dir,
                    save_config=self.is_world_process_zero(),
                    state_dict=self.model.state_dict(),
                    save_function=xm.save,
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                state_dict = self.model.state_dict()
                xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir, save_config=self.is_world_process_zero(), save_function=xm.save)
        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model, tokenizer and training args to ``output_dir``.

        Only ever executed on process zero; the caller is responsible for that
        check, so it is not repeated here.
        """
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if isinstance(self.model, PreTrainedModel):
            self.model.save_pretrained(output_dir, state_dict=state_dict)
        else:
            inner_model = unwrap_model(self.model)
            if isinstance(inner_model, PreTrainedModel):
                # Wrapped model (e.g. DDP) whose inner module is a `PreTrainedModel`:
                # save through the inner module but with the outer state dict.
                if state_dict is None:
                    state_dict = self.model.state_dict()
                inner_model.save_pretrained(output_dir, state_dict=state_dict)
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if state_dict is None:
                    state_dict = self.model.state_dict()
                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
    def store_flos(self):
        """Record the accumulated floating-point operation count on the trainer state.

        In distributed mode the per-process counts are gathered and summed;
        otherwise the local counter is stored as-is. No-op when no count has
        been accumulated yet.
        """
        # Storing the number of floating-point operations that went into the model
        if self._total_flos is None:
            return
        if self.args.local_rank == -1:
            self.state.total_flos = self._total_flos
        else:
            # Sum the counters from all processes so the state holds the global total.
            self.state.total_flos = distributed_broadcast_scalars([self._total_flos]).sum().item()
    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        """Return the checkpoint directories under ``output_dir``, oldest first.

        Ordering is by file modification time when ``use_mtime`` is True,
        otherwise by the step number embedded in the folder name. The current
        best model checkpoint (if tracked) is swapped to the end of the list so
        checkpoint rotation never deletes it.
        """
        candidates = [str(p) for p in Path(output_dir).glob(f"{checkpoint_prefix}-*")]

        keyed = []
        for candidate in candidates:
            if use_mtime:
                keyed.append((os.path.getmtime(candidate), candidate))
                continue
            match = re.match(f".*{checkpoint_prefix}-([0-9]+)", candidate)
            if match and match.groups():
                keyed.append((int(match.groups()[0]), candidate))

        ordered = [path for _, path in sorted(keyed)]

        # Make sure we don't delete the best model: move it to the last slot,
        # which is the one removed last during rotation.
        if self.state.best_model_checkpoint is not None:
            best_index = ordered.index(str(Path(self.state.best_model_checkpoint)))
            ordered[best_index], ordered[-1] = ordered[-1], ordered[best_index]
        return ordered
    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
1750
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
Julien Chaumond's avatar
Julien Chaumond committed
1751
1752
1753
            return

        # Check if we should delete older checkpoint(s)
1754
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
Julien Chaumond's avatar
Julien Chaumond committed
1755
1756
1757
1758
1759
1760
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
1761
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
Julien Chaumond's avatar
Julien Chaumond committed
1762
1763
            shutil.rmtree(checkpoint)

1764
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
                columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
                :obj:`__len__` method.
            ignore_keys (:obj:`Lst[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default)

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """
        # Memory metrics - must be set up as early as possible.
        self._memory_tracker.start()

        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")

        dataloader = self.get_eval_dataloader(eval_dataset)
        eval_start = time.time()

        # No point gathering the predictions if there are no metrics; otherwise
        # defer to self.args.prediction_loss_only.
        loss_only = True if self.compute_metrics is None else None
        output = self.prediction_loop(
            dataloader,
            description="Evaluation",
            prediction_loss_only=loss_only,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        dataset_for_len = eval_dataset if eval_dataset is not None else self.eval_dataset
        output.metrics.update(speed_metrics(metric_key_prefix, eval_start, len(dataset_for_len)))
        self.log(output.metrics)

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics
    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Run prediction and returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
            ignore_keys (:obj:`Lst[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"test"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "test_bleu" if the prefix is "test" (default)

        .. note::

            If your predictions or labels have different sequence length (for instance because you're doing dynamic
            padding in a token classification task) the predictions will be padded (on the right) to allow for
            concatenation into one array. The padding index is -100.

        Returns: `NamedTuple` A namedtuple with the following keys:

            - predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
            - label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some).
            - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
              contained labels).
        """
        # Memory metrics - must be set up as early as possible.
        self._memory_tracker.start()

        if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized):
            raise ValueError("test_dataset must implement __len__")

        dataloader = self.get_test_dataloader(test_dataset)
        predict_start = time.time()

        output = self.prediction_loop(
            dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )
        output.metrics.update(speed_metrics(metric_key_prefix, predict_start, len(test_dataset)))

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output
    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.

        Args:
            dataloader (:obj:`DataLoader`):
                The dataloader to iterate; its dataset must implement :obj:`__len__`.
            description (:obj:`str`):
                Label used in the log header (e.g. ``"Evaluation"``).
            prediction_loss_only (:obj:`bool`, `optional`):
                If :obj:`None`, falls back to :obj:`self.args.prediction_loss_only`.
            ignore_keys (:obj:`List[str]`, `optional`):
                Output-dict keys to skip when gathering predictions.
            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
                Prefix applied to every metric key in the returned output.
        """
        if not isinstance(dataloader.dataset, collections.abc.Sized):
            raise ValueError("dataset must implement __len__")
        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        # if eval is called w/o train init deepspeed here
        if self.args.deepspeed and not self.deepspeed:

            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
            # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
            # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
            # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
            deepspeed_engine.optimizer.optimizer = None
            deepspeed_engine.lr_scheduler = None

        model = self._wrap_model(self.model, training=False)

        # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
        # ``train`` is running, half it first and then put on device
        if not self.is_in_train and self.args.fp16_full_eval:
            model = model.half().to(self.args.device)

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info(f"***** Running {description} *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Batch size = {batch_size}")
        # Per-device accumulation buffers; periodically flushed into the
        # gatherers (and to CPU/numpy) to bound device memory usage.
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = max(1, self.args.world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
            # a batch size to the sampler)
            make_multiple_of = None
            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
                make_multiple_of = dataloader.sampler.batch_size
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            if loss is not None:
                # Repeat the (scalar) loss per sample so the gatherer sees one value per example.
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                if not prediction_loss_only:
                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host = None, None, None

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
    def _gather_and_numpify(self, tensors, name):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
Sylvain Gugger's avatar
Sylvain Gugger committed
2012
2013
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
2014
2015
2016
2017
2018
        elif self.args.local_rank != -1:
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)

2019
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.
            ignore_keys (:obj:`Lst[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        # A step "has labels" only if every configured label key is present and non-None.
        has_labels = all(inputs.get(k) is not None for k in self.label_names)
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            # Default to the model config's inference-time ignore list when available.
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                # Single label tensor: unwrap it from the tuple.
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                # SageMaker model parallel: forward returns microbatched outputs that
                # must be reduced/concatenated across microbatches.
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels:
                    # Loss path: compute_loss also returns the model outputs.
                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                    loss = loss.mean().detach()
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        # Tuple outputs: position 0 is the loss, the rest are logits.
                        logits = outputs[1:]
                else:
                    loss = None
                    if self.use_amp:
                        with autocast():
                            outputs = model(**inputs)
                    else:
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            # Single logits tensor: unwrap it from the tuple.
            logits = logits[0]

        return (loss, logits, labels)
    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        For models that inherit from :class:`~transformers.PreTrainedModel`, uses that method to compute the number of
        floating point operations for every backward + forward pass. If using another model, either implement such a
        method in the model or subclass and override this method.

        Args:
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            :obj:`int`: The number of floating-point operations.
        """
        model = self.model
        # Models without a `floating_point_ops` method contribute nothing to the count.
        if not hasattr(model, "floating_point_ops"):
            return 0
        return model.floating_point_ops(inputs)