# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or fine-tune it on a new task.
"""

import contextlib
import functools
import glob
import inspect
import math
import os
import random
import re
import shutil
import sys
import time
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

from tqdm.auto import tqdm

# Integrations must be imported before ML frameworks:
from .integrations import (  # isort: split
    default_hp_search_backend,
    get_reporting_integration_callbacks,
    hp_params,
    is_fairscale_available,
    is_optuna_available,
    is_ray_tune_available,
    is_sigopt_available,
    is_wandb_available,
    run_hp_search_optuna,
    run_hp_search_ray,
    run_hp_search_sigopt,
    run_hp_search_wandb,
)

import numpy as np
import torch
import torch.distributed as dist
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from huggingface_hub import Repository

from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .optimization import Adafactor, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
    DistributedLengthGroupedSampler,
    DistributedSamplerWithLoop,
    DistributedTensorGatherer,
    IterableDatasetShard,
    LabelSmoother,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    ShardSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    find_batch_size,
    get_parameter_names,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_truncate,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalLoopOutput,
    EvalPrediction,
    FSDPOption,
    HPSearchBackend,
    HubStrategy,
    IntervalStrategy,
    PredictionOutput,
    RemoveColumnsCollator,
    ShardedDDPOption,
    TrainerMemoryTracker,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    denumpify_detensorize,
    enable_full_determinism,
    find_executable_batch_size,
    get_last_checkpoint,
    has_length,
    number_of_arguments,
    seed_worker,
    set_seed,
    speed_metrics,
)
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import (
    CONFIG_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    find_labels,
    get_full_repo_name,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
    is_ipex_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_torch_tpu_available,
    is_torchdynamo_available,
    logging,
)
from .utils.generic import ContextManagers


_is_torch_generator_available = False
_is_native_cuda_amp_available = False
_is_native_cpu_amp_available = False

DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

if is_apex_available():
    from apex import amp

if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_torch_generator_available = True
    _is_native_cuda_amp_available = True

if version.parse(torch.__version__) >= version.parse("1.10"):
    _is_native_cpu_amp_available = True

if is_datasets_available():
    import datasets

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_fairscale_available():
    dep_version_check("fairscale")
    import fairscale
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.nn.wrap import auto_wrap
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler


if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat


if TYPE_CHECKING:
    import optuna

logger = logging.get_logger(__name__)


# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"

class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.

    Args:
        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.

            <Tip>

            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
            models.

            </Tip>

        args ([`TrainingArguments`], *optional*):
            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
        data_collator (`DataCollator`, *optional*):
            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
            [`DataCollatorWithPadding`] otherwise.
        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
            The dataset to use for training. If it is a `datasets.Dataset`, columns not accepted by the
            `model.forward()` method are automatically removed.

            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
            distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
            sets the seed of the RNGs used.
        eval_dataset (`torch.utils.data.Dataset`, *optional*):
            The dataset to use for evaluation. If it is a `datasets.Dataset`, columns not accepted by the
            `model.forward()` method are automatically removed.
        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
            The tokenizer used to preprocess the data. If provided, it will be used to automatically pad the inputs to
            the maximum length when batching inputs, and it will be saved along with the model to make it easier to
            rerun an interrupted training or reuse the fine-tuned model.
        model_init (`Callable[[], PreTrainedModel]`, *optional*):
            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
            from a new instance of the model as given by this function.

            The function may have zero arguments, or a single one containing the optuna/Ray Tune/SigOpt trial object,
            to be able to choose different architectures according to hyperparameters (such as layer count, sizes of
            inner layers, dropout probabilities, etc.).
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
            The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`] and
            return a dictionary mapping metric names to metric values.
        callbacks (List of [`TrainerCallback`], *optional*):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed [here](callback).

            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications
            made by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.

    Important attributes:

        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
          subclass.
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
          to `False` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False`.
        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
          in `train`)

    """

    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    ):
        if args is None:
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
        self.args = args
        # Seed must be set before instantiating the model when using model_init.
        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
        self.hp_name = None
        self.deepspeed = None
        self.is_in_train = False

        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

        # set the correct log level depending on the node
        log_level = args.get_process_log_level()
        logging.set_verbosity(log_level)

        # force device and distributed setup init explicitly
        args._setup_devices

        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
                    " overwrite your model when calling the `train` method. This will become a fatal error in the next"
                    " release.",
                    FutureWarning,
                )
            self.model_init = model_init

        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = False

        # Setup Sharded DDP training
        self.sharded_ddp = None
        if len(args.sharded_ddp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if len(args.fsdp) > 0:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
                )

            if args.local_rank == -1:
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
                raise ImportError(
                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
                )
            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.SIMPLE
            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

        self.fsdp = None
        if len(args.fsdp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if args.local_rank == -1:
                raise ValueError("Using fsdp only works in distributed training.")

            #  dep_version_check("torch>=1.12.0.dev20220418+cu113")
            # Would have to update setup.py with torch>=1.12.0.dev20220418+cu113
            # which isn't ideal given that it's a dev version
            # and it will force people not using FSDP to also use torch>=1.12.0.dev20220418+cu113
            # below is the current alternative.
            if version.parse(torch.__version__) < version.parse("1.12.0.dev20220418+cu113"):
                raise ValueError("FSDP requires PyTorch >= 1.12.0.dev20220418+cu113")

            from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy

            if FSDPOption.FULL_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.FULL_SHARD
            elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
                self.fsdp = ShardingStrategy.SHARD_GRAD_OP

        # one place to sort out whether to place the model on device or not
        # postpone switching model to cuda when:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        #    and we only use deepspeed for training at the moment
        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
        # 4. Sharded DDP - same as MP
        # 5. FSDP - same as MP
        self.place_model_on_device = args.place_model_on_device
        if (
            self.is_model_parallel
            or args.deepspeed
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
            or (self.fsdp is not None)
        ):
            self.place_model_on_device = False

        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer

        if self.place_model_on_device:
            self._move_model_to_device(model, args.device)

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
        if self.is_model_parallel:
            self.args._n_gpu = 1

        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
        self.model_wrapped = model
        self.model = model

        self.compute_metrics = compute_metrics
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
            self.optimizer is not None or self.lr_scheduler is not None
        ):
            raise RuntimeError(
                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create a clone of the remote repo and the output directory if needed
        if self.args.push_to_hub:
            self.init_git_repo(at_init=True)
            # In case of pull, we need to make sure every process has the latest.
            if is_torch_tpu_available():
                xm.rendezvous("init git repo")
            elif args.local_rank != -1:
                dist.barrier()

        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)

        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")

        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
            raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")

        if (
            train_dataset is not None
            and isinstance(train_dataset, torch.utils.data.IterableDataset)
            and args.group_by_length
        ):
            raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset`")

        self._signature_columns = None

        # Mixed precision setup
        self.use_apex = False
        self.use_cuda_amp = False
        self.use_cpu_amp = False

        # Mixed precision setup for SageMaker Model Parallel
        if is_sagemaker_mp_enabled():
            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
            if args.bf16:
                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead.")
            # When there's a mismatch between the SMP config and the trainer argument, use the SMP config as truth
            if args.fp16 != smp.state.cfg.fp16:
                logger.warning(
                    f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                    f"but FP16 provided in trainer argument is {args.fp16}, "
                    f"setting to {smp.state.cfg.fp16}"
                )
                args.fp16 = smp.state.cfg.fp16

        if args.fp16 or args.bf16:
            if self.fsdp is not None:
                raise ValueError(
                    "Mixed precision is currently not supported for FSDP. "
                    "Please do not set arguments related to `mixed_precision`."
                )
            if args.half_precision_backend == "auto":
                if args.device == torch.device("cpu"):
                    if args.fp16:
                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                    elif _is_native_cpu_amp_available:
                        args.half_precision_backend = "cpu_amp"
                    else:
                        raise ValueError("Tried to use cpu amp but native cpu amp is not available")
                else:
                    if _is_native_cuda_amp_available:
                        args.half_precision_backend = "cuda_amp"
                    elif args.bf16:
                        raise ValueError("Tried to use `bf16` but native amp is not available")
                    else:
                        args.half_precision_backend = "apex"

            logger.info(f"Using {args.half_precision_backend} half precision backend")

        self.do_grad_scaling = False
        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
            # deepspeed and SageMaker Model Parallel manage their own half precision
            if args.half_precision_backend == "cuda_amp":
                self.use_cuda_amp = True
                self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
                self.do_grad_scaling = True
                if self.sharded_ddp is not None:
                    self.scaler = ShardedGradScaler()
                elif is_torch_tpu_available():
                    from torch_xla.amp import GradScaler

                    self.scaler = GradScaler()
                else:
                    self.scaler = torch.cuda.amp.GradScaler()
            elif args.half_precision_backend == "cpu_amp":
                self.use_cpu_amp = True
                self.amp_dtype = torch.bfloat16
            else:
                if not is_apex_available():
                    raise ImportError(
                        "Using FP16 with APEX but APEX is not installed, please refer to"
                        " https://www.github.com/nvidia/apex."
                    )
                self.use_apex = True

        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
        if (
            is_sagemaker_mp_enabled()
            and self.use_cuda_amp
            and args.max_grad_norm is not None
            and args.max_grad_norm > 0
        ):
            raise ValueError(
                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
                "along 'max_grad_norm': 0 in your hyperparameters."
            )

        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

        self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
        )

        self.control = TrainerControl()
        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
        default_label_names = find_labels(self.model.__class__)
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

        # Internal variables to keep track of the original batch size
        self._train_batch_size = args.train_batch_size

        # very last
        self._memory_tracker.stop_and_update_metrics()

    def add_callback(self, callback):
        """
        Add a callback to the current list of [`~transformers.TrainerCallback`].

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
                first case, will instantiate a member of that class.
        """
        self.callback_handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it.

        If the callback is not found, returns `None` (and no error is raised).

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
                first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            [`~transformers.TrainerCallback`]: The callback removed, if found.
        """
        return self.callback_handler.pop_callback(callback)

    def remove_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`].

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
                first case, will remove the first member of that class found in the list of callbacks.
        """
        self.callback_handler.remove_callback(callback)
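
    # Editor's note: a minimal sketch of the callback API above (not part of the original file); `trainer` is assumed
    # to be an existing `Trainer` instance:
    #
    #     from transformers import TrainerCallback
    #
    #     class MyCallback(TrainerCallback):
    #         def on_epoch_end(self, args, state, control, **kwargs):
    #             print(f"finished epoch {state.epoch}")
    #
    #     trainer.add_callback(MyCallback)    # pass the class, the handler instantiates it
    #     trainer.add_callback(MyCallback())  # or pass an instance directly
    #     removed = trainer.pop_callback(MyCallback)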

    def _move_model_to_device(self, model, device):
        model = model.to(device)
        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
            model.tie_weights()

    def _set_signature_columns_if_needed(self):
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
            # Labels may be named label or label_ids, the default data collator handles that.
            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return dataset
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        if len(ignored_columns) > 0:
            dset_description = "" if description is None else f"in the {description} set"
            logger.info(
                f"The following columns {dset_description} don't have a corresponding argument in "
                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
                f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
                "you can safely ignore this message."
            )

        columns = [k for k in signature_columns if k in dataset.column_names]

        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            dataset.set_format(
                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
            )
            return dataset
        else:
            return dataset.remove_columns(ignored_columns)

    def _get_collator_with_removed_columns(
        self, data_collator: Callable, description: Optional[str] = None
    ) -> Callable:
        """Wrap the data collator in a callable removing unused columns."""
        if not self.args.remove_unused_columns:
            return data_collator
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        remove_columns_collator = RemoveColumnsCollator(
            data_collator=data_collator,
            signature_columns=signature_columns,
            logger=logger,
            description=description,
            model_name=self.model.__class__.__name__,
        )
        return remove_columns_collator

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        generator = None
        if self.args.world_size <= 1 and _is_torch_generator_available:
            generator = torch.Generator()
            # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
            # `args.seed`) if data_seed isn't provided.
            # Further on in this method, we default to `args.seed` instead.
            if self.args.data_seed is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
            else:
                seed = self.args.data_seed
            generator.manual_seed(seed)

        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed

        # Build the sampler.
        if self.args.group_by_length:
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            if self.args.world_size <= 1:
                return LengthGroupedSampler(
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
                    dataset=self.train_dataset,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
                )
            else:
                return DistributedLengthGroupedSampler(
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
                    dataset=self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    seed=seed,
                )

        else:
            if self.args.world_size <= 1:
                if _is_torch_generator_available:
                    return RandomSampler(self.train_dataset, generator=generator)
                return RandomSampler(self.train_dataset)
            elif (
                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
            ):
                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    batch_size=self.args.per_device_train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=seed,
                )
            else:
                return DistributedSampler(
                    self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=seed,
                )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        if isinstance(train_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                train_dataset = IterableDatasetShard(
                    train_dataset,
                    batch_size=self._train_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )

            return DataLoader(
                train_dataset,
                batch_size=self.args.per_device_train_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        train_sampler = self._get_train_sampler()

        return DataLoader(
            train_dataset,
            batch_size=self._train_batch_size,
            sampler=train_sampler,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            worker_init_fn=seed_worker,
        )
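
    # Editor's note: a minimal sketch (not part of the original file) of overriding this method in a subclass to use a
    # weighted sampler; the `sample_weight` column is an illustrative assumption about the training dataset:
    #
    #     class WeightedTrainer(Trainer):
    #         def get_train_dataloader(self) -> DataLoader:
    #             sampler = torch.utils.data.WeightedRandomSampler(
    #                 weights=self.train_dataset["sample_weight"], num_samples=len(self.train_dataset)
    #             )
    #             return DataLoader(
    #                 self.train_dataset,
    #                 batch_size=self._train_batch_size,
    #                 sampler=sampler,
    #                 collate_fn=self.data_collator,
    #                 drop_last=self.args.dataloader_drop_last,
    #             )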

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
        # Deprecated code
        if self.args.use_legacy_prediction_loop:
            if is_torch_tpu_available():
                return SequentialDistributedSampler(
                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
                )
            elif is_sagemaker_mp_enabled():
                return SequentialDistributedSampler(
                    eval_dataset,
                    num_replicas=smp.dp_size(),
                    rank=smp.dp_rank(),
                    batch_size=self.args.per_device_eval_batch_size,
                )
            elif self.args.local_rank != -1:
                return SequentialDistributedSampler(eval_dataset)
            else:
                return SequentialSampler(eval_dataset)

        if self.args.world_size <= 1:
            return SequentialSampler(eval_dataset)
        else:
            return ShardSampler(
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
            )

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a `datasets.Dataset`, columns not accepted by
                the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                eval_dataset = IterableDatasetShard(
                    eval_dataset,
                    batch_size=self.args.per_device_eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                eval_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`, *optional*):
                The test dataset to use. If it is a `datasets.Dataset`, columns not accepted by the `model.forward()`
                method are automatically removed. It must implement `__len__`.
        """
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description="test")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")

        if isinstance(test_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                test_dataset = IterableDatasetShard(
                    test_dataset,
                    batch_size=self.args.eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                test_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
        the Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
        `create_scheduler`).
        """
        self.create_optimizer()
        self.create_scheduler(
            num_training_steps=num_training_steps,
            optimizer=self.optimizer.optimizer if is_sagemaker_mp_enabled() and smp.state.cfg.fp16 else self.optimizer,
        )

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
        the Trainer's init through `optimizers`, or subclass and override this method.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, [nn.LayerNorm])
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in opt_model.named_parameters() if n in decay_parameters],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in opt_model.named_parameters() if n not in decay_parameters],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer
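
    # Editor's note: a minimal sketch (not part of the original file) of supplying a custom optimizer/scheduler pair
    # instead of the defaults; the hyperparameter values are illustrative assumptions:
    #
    #     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
    #     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: max(0.0, 1 - step / 10_000))
    #     trainer = Trainer(model=model, args=args, train_dataset=train_ds, optimizers=(optimizer, scheduler))
    #
    # Note that `__init__` above rejects `optimizers` when DeepSpeed, Fairscale sharded DDP or PyTorch FSDP is
    # enabled; in those cases, subclass `Trainer` and override `create_optimizer` / `create_scheduler` instead.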

    @staticmethod
    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
        """
        Returns the optimizer class and optimizer parameters based on the training arguments.

        Args:
            args (`transformers.training_args.TrainingArguments`):
                The training arguments for the training session.
        """
        optimizer_kwargs = {"lr": args.learning_rate}
        adam_kwargs = {
            "betas": (args.adam_beta1, args.adam_beta2),
            "eps": args.adam_epsilon,
        }
        if args.optim == OptimizerNames.ADAFACTOR:
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif args.optim == OptimizerNames.ADAMW_HF:
            from .optimization import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim == OptimizerNames.ADAMW_TORCH:
            from torch.optim import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
            try:
                from torch_xla.amp.syncfree import AdamW

                optimizer_cls = AdamW
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
            try:
                from apex.optimizers import FusedAdam

                optimizer_cls = FusedAdam
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
        elif args.optim == OptimizerNames.ADAMW_BNB:
            try:
                from bitsandbytes.optim import Adam8bit

                optimizer_cls = Adam8bit
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
        elif args.optim == OptimizerNames.SGD:
            optimizer_cls = torch.optim.SGD
        elif args.optim == OptimizerNames.ADAGRAD:
            optimizer_cls = torch.optim.Adagrad
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs
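
    # Editor's note: a minimal sketch (not part of the original file) showing how the helper above can be reused to
    # build the same optimizer outside of `Trainer`; `model` and `args` are assumed to already exist:
    #
    #     optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
    #     optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)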

1068
    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called
        or passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
        """
        if self.lr_scheduler is None:
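            # For example, with `args.lr_scheduler_type="linear"` (the default) `get_scheduler` returns a
            # schedule that warms the learning rate up for `args.get_warmup_steps(num_training_steps)`
            # steps and then decays it linearly to 0 over the remaining steps.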
            self.lr_scheduler = get_scheduler(
                self.args.lr_scheduler_type,
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )
        return self.lr_scheduler

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get the number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
        dataloader.dataset does not exist or has no length, estimates as best it can.
        """
        try:
            return len(dataloader.dataset)
        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
            return len(dataloader) * self.args.per_device_train_batch_size
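            # e.g. a dataloader of 1,000 batches with no sized dataset and per_device_train_batch_size=8
            # is estimated as 1,000 * 8 = 8,000 examples.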

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """HP search setup code"""
        self._trial = trial
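        # `trial` is backend specific: an `optuna.Trial` for Optuna, a plain dict of sampled values for
        # Ray Tune and W&B sweeps (e.g. {"learning_rate": 3e-5, "seed": 17} -- illustrative values only),
        # and an object exposing `.assignments` for SigOpt; each selected hyperparameter is copied onto
        # `self.args` below.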

        if self.hp_search_backend is None or trial is None:
            return
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            params = self.hp_space(trial)
        elif self.hp_search_backend == HPSearchBackend.RAY:
            params = trial
            params.pop("wandb", None)
        elif self.hp_search_backend == HPSearchBackend.SIGOPT:
            params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
        elif self.hp_search_backend == HPSearchBackend.WANDB:
            params = trial

        for key, value in params.items():
            if not hasattr(self.args, key):
                logger.warning(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
                    " `TrainingArguments`."
                )
                continue
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info(f"Trial: {trial.params}")
        if self.hp_search_backend == HPSearchBackend.SIGOPT:
            logger.info(f"SigOpt Assignments: {trial.assignments}")
        if self.hp_search_backend == HPSearchBackend.WANDB:
            logger.info(f"W&B Sweep parameters: {trial}")
        if self.args.deepspeed:
            # Rebuild the deepspeed config to reflect the updated training parameters
            from transformers.deepspeed import HfTrainerDeepSpeedConfig

            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
            self.args.hf_deepspeed_config.trainer_config_process(self.args)

    def _report_to_hp_search(
        self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
    ):
        if self.hp_search_backend is None or trial is None:
            return
        self.objective = self.compute_objective(metrics.copy())
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            import optuna

            trial.report(self.objective, epoch)
            if trial.should_prune():
                self.callback_handler.on_train_end(self.args, self.state, self.control)
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            from ray import tune

            if self.control.should_save:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

    def _tune_save_checkpoint(self):
        from ray import tune

        if not self.use_tune_checkpoints:
            return
        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
            self.save_model(output_dir, _internal_call=True)
            if self.args.should_save:
                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

    def call_model_init(self, trial=None):
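        # `model_init` may take either no argument or a single `trial` argument (useful for hyperparameter
        # search, where the trial can drive the model configuration); anything else is rejected below.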
        model_init_argcount = number_of_arguments(self.model_init)
        if model_init_argcount == 0:
            model = self.model_init()
        elif model_init_argcount == 1:
            model = self.model_init(trial)
        else:
            raise RuntimeError("model_init should have 0 or 1 argument.")

        if model is None:
            raise RuntimeError("model_init should not return None.")

        return model

    def torch_jit_model_eval(self, model, dataloader, training=False):
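        # Sketch of the jit_mode_eval path: trace the eval model once with a dummy batch shaped like the
        # first batch of `dataloader` (under autocast/no_grad), freeze the traced module, and fall back
        # to the original model if tracing fails.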
        if not training:
            if dataloader is None:
                logger.warning("failed to use PyTorch jit mode because the current dataloader is None.")
                return model
            jit_inputs = []
            example_batch = next(iter(dataloader))
            for key in example_batch:
                example_tensor = torch.ones_like(example_batch[key])
                jit_inputs.append(example_tensor)
            jit_inputs = tuple(jit_inputs)
            try:
                jit_model = model.eval()
                with ContextManagers([self.autocast_smart_context_manager(), torch.no_grad()]):
                    jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
                jit_model = torch.jit.freeze(jit_model)
                jit_model(**example_batch)
                model = jit_model
            except (RuntimeError, TypeError) as e:
                logger.warning(f"failed to use PyTorch jit mode due to: {e}.")

        return model

    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
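        # intel_extension_for_pytorch rewrites the model (and, in training, the optimizer as well) with
        # CPU-optimized kernels; the caller passes dtype=torch.bfloat16 when CPU AMP in bf16 is enabled.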
        if not is_ipex_available():
            raise ImportError(
                "Using IPEX but IPEX is not installed, please refer to"
                " https://github.com/intel/intel-extension-for-pytorch."
            )

        import intel_extension_for_pytorch as ipex

        if not training:
            model.eval()
            model = ipex.optimize(model, dtype=dtype, level="O1")
        else:
            if not model.training:
                model.train()
            model, self.optimizer = ipex.optimize(model, dtype=dtype, optimizer=self.optimizer, level="O1")

        return model

    def _wrap_model(self, model, training=True, dataloader=None):
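        # Summary of the branches below (descriptive only): optional IPEX / JIT preparation, early returns
        # for SageMaker MP and DeepSpeed (already wrapped), apex AMP initialization and DataParallel for
        # multi-GPU, then -- for training only -- one of fairscale ShardedDDP / FullyShardedDDP, PyTorch
        # FSDP, SageMaker DP or plain DistributedDataParallel.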
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if self.args.jit_mode_eval:
            model = self.torch_jit_model_eval(model, dataloader, training)

        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

1239
1240
        # already initialized its own DDP and AMP
        if self.deepspeed:
1241
            return self.deepspeed
1242

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
1253
            model = nn.DataParallel(model)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
1266
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
1270
1271
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
1272
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
1277
1278
                ).to(self.args.device)

        # Distributed training using PyTorch FSDP
        if self.fsdp is not None:
            # PyTorch FSDP!
            from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
            from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
            from torch.distributed.fsdp.wrap import default_auto_wrap_policy

            if FSDPOption.OFFLOAD in self.args.fsdp:
                cpu_offload = CPUOffload(offload_params=True)
            else:
                cpu_offload = CPUOffload(offload_params=False)

            auto_wrap_policy = None
            if FSDPOption.AUTO_WRAP in self.args.fsdp:
                if self.args.fsdp_min_num_params > 0:
                    auto_wrap_policy = functools.partial(
                        default_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
                    )

            if type(model) != FSDP:
                # XXX: Breaking the self.model convention but I see no way around it for now.
                self.model = model = FSDP(
                    model, sharding_strategy=self.fsdp, cpu_offload=cpu_offload, auto_wrap_policy=auto_wrap_policy
                )
                if FSDPOption.OFFLOAD not in self.args.fsdp:
                    model.to(self.args.device)

        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
1310
        elif self.args.local_rank != -1:
1311
            kwargs = {}
1312
            if self.args.ddp_find_unused_parameters is not None:
1313
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
1317
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
1318
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
1323
            model = nn.parallel.DistributedDataParallel(
1324
                model,
1325
1326
                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
1327
                **kwargs,
            )

        return model

1332
1333
    def train(
        self,
1334
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
1335
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
1336
        ignore_keys_for_eval: Optional[List[str]] = None,
1337
        **kwargs,
1338
    ):
        """
        Main training entry point.

        Args:
            resume_from_checkpoint (`str` or `bool`, *optional*):
                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
            ignore_keys_for_eval (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions for evaluation during the training.
            kwargs:
                Additional keyword arguments used to hide deprecated arguments
        """
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None
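        # Typical calls (illustrative): `trainer.train()` for a fresh run, or
        # `trainer.train(resume_from_checkpoint=True)` to continue from the last checkpoint saved in
        # `args.output_dir`.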

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

1361
1362
        args = self.args

1363
1364
        self.is_in_train = True

1365
1366
        # do_train is not a reliable argument, as it might not be set and .train() still called, so
        # the following is a workaround:
1367
        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
            self._move_model_to_device(self.model, args.device)
1369

        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
                "instead.",
                FutureWarning,
            )
        if len(kwargs) > 0:
            raise TypeError(f"train() got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

1382
        # Model re-init
1383
        model_reloaded = False
1384
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
1386
            enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
1387
1388
            self.model = self.call_model_init(trial)
            model_reloaded = True
            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None
1391

1392
        # Load potential model checkpoint
1393
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
1394
            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
1395
            if resume_from_checkpoint is None:
1396
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
1397

1398
        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled():
1399
            self._load_from_checkpoint(resume_from_checkpoint)
1400

1401
1402
        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
1403
            if self.place_model_on_device:
                self._move_model_to_device(self.model, args.device)
1405
1406
            self.model_wrapped = self.model

        inner_training_loop = find_executable_batch_size(
            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
        )
        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        self._train_batch_size = batch_size
        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
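        # e.g. train_batch_size=8, gradient_accumulation_steps=4 and world_size=2 give an effective batch
        # of 8 * 4 * 2 = 64 samples per optimizer step.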

        len_dataloader = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
1434
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
1435
            num_examples = self.num_examples(train_dataloader)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
1440
                )
1441
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
1442
1443
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
1444
            else:
1445
1446
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
1447
1448
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
1449
            max_steps = args.max_steps
1450
1451
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
1452
            num_update_steps_per_epoch = max_steps
1453
            num_examples = total_train_batch_size * args.max_steps
1454
            num_train_samples = args.max_steps * total_train_batch_size
1455
1456
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
1459
            )

1461
        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torch.distributed.launch)."
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

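        # The optimizer must be created after the model is wrapped for fairscale sharded DDP (other than
        # the SIMPLE variant), SageMaker model parallelism and PyTorch FSDP, since those wrappers shard or
        # replace the parameters the optimizer has to reference.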
        delay_optimizer_creation = (
            self.sharded_ddp is not None
            and self.sharded_ddp != ShardedDDPOption.SIMPLE
            or is_sagemaker_mp_enabled()
            or self.fsdp is not None
1477
        )
1478
        if args.deepspeed:
1479
            deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
1480
1481
                self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
1485
1486
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler
1487
        elif not delay_optimizer_creation:
1488
1489
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

1490
        self.state = TrainerState()
1491
        self.state.is_hyper_param_search = trial is not None

        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

1497
        model = self._wrap_model(self.model_wrapped)

        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint, model)

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

1512
1513
        # important: at this point:
        # self.model         is the Transformers Model
1514
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
1515

        # Train!
        logger.info("***** Running training *****")
1518
1519
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
1520
        logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
1521
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
1522
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1523
        logger.info(f"  Total optimization steps = {max_steps}")

1525
        self.state.epoch = 0
1526
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
1529
        steps_trained_progress_bar = None
1530

        # Check if continuing training from a checkpoint
1532
        if resume_from_checkpoint is not None and os.path.isfile(
1533
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
1534
        ):
1535
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
1536
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
1537
            if not args.ignore_data_skip:
1538
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
1539
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
1540
1541
            else:
                steps_trained_in_current_epoch = 0
1542
1543

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
1544
1545
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
1546
            if not args.ignore_data_skip:
1547
1548
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
1549
1550
                    "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
                    "flag to your launch command, but you will resume the training on data already seen by your model."
1551
                )
                if self.is_local_process_zero() and not args.disable_tqdm:
                    steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                    steps_trained_progress_bar.set_description("Skipping the first batches")
1555

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
1561
        self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

1574
        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
1575
        tr_loss = torch.tensor(0.0).to(args.device)
1576
1577
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
1578
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()

1581
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

1583
        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
1584
        if not args.ignore_data_skip:
1585
            for epoch in range(epochs_trained):
                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
                    train_dataloader.sampler, RandomSampler
                )
                if version.parse(torch.__version__) < version.parse("1.11") or not is_random_sampler:
                    # We just need to begin an iteration to create the randomization of the sampler.
                    # That was before PyTorch 1.11 however...
                    for _ in train_dataloader:
                        break
                else:
                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
                    # AT THE VERY END!
                    _ = list(train_dataloader.sampler)
1598

1599
        for epoch in range(epochs_trained, num_train_epochs):
1600
1601
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
1602
            elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
1603
                train_dataloader.dataset.set_epoch(epoch)
1604

1605
            if is_torch_tpu_available():
1606
                parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
1607
                epoch_iterator = parallel_loader
1608
            else:
1609
                epoch_iterator = train_dataloader
1610

1611
            # Reset the past mems state at the beginning of each epoch if necessary.
1612
            if args.past_index >= 0:
1613
1614
                self._past = None

1615
            steps_in_epoch = (
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
1619
            )
1620
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

1625
            step = -1
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
1631
1632
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
1633
1634
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
                    continue
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None

1640
1641
                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

1643
                if (
                    ((step + 1) % args.gradient_accumulation_steps != 0)
                    and args.local_rank != -1
                    and args._no_sync_in_gradient_accumulation
1647
                ):
1648
                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
1649
                    with model.no_sync():
1650
                        tr_loss_step = self.training_step(model, inputs)
1651
                else:
1652
1653
                    tr_loss_step = self.training_step(model, inputs)

                if (
                    args.logging_nan_inf_filter
                    and not is_torch_tpu_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                else:
                    tr_loss += tr_loss_step

1664
                self.current_flos += float(self.floating_point_ops(inputs))

                # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
                if self.deepspeed:
                    self.deepspeed.step()

1670
                if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
1672
                    steps_in_epoch <= args.gradient_accumulation_steps
1673
                    and (step + 1) == steps_in_epoch
                ):
1675
                    # Gradient clipping
1676
                    if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
1677
1678
                        # deepspeed does its own clipping

1679
                        if self.do_grad_scaling:
                            # Reduce gradients first for XLA
                            if is_torch_tpu_available():
                                gradients = xm._fetch_gradients(self.optimizer)
                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)

                        if is_sagemaker_mp_enabled() and args.fp16:
                            self.optimizer.clip_master_grads(args.max_grad_norm)
                        elif hasattr(self.optimizer, "clip_grad_norm"):
1690
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
1691
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
1692
1693
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
1694
                            model.clip_grad_norm_(args.max_grad_norm)
1695
1696
                        else:
                            # Revert to normal clipping otherwise, handling Apex or full precision
1697
                            nn.utils.clip_grad_norm_(
1698
                                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
1699
                                args.max_grad_norm,
                            )

                    # Optimizer step
1703
                    optimizer_was_run = True
                    if self.deepspeed:
1705
                        pass  # called outside the loop
                    elif is_torch_tpu_available():
                        if self.do_grad_scaling:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            xm.optimizer_step(self.optimizer)
1712
                    elif self.do_grad_scaling:
1713
                        scale_before = self.scaler.get_scale()
1714
                        self.scaler.step(self.optimizer)
1715
                        self.scaler.update()
1716
1717
                        scale_after = self.scaler.get_scale()
                        optimizer_was_run = scale_before <= scale_after
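                        # When the GradScaler finds inf/NaN gradients it skips the optimizer step and
                        # lowers the scale, so an unchanged (or increased) scale means the step really ran.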
                    else:
1719
                        self.optimizer.step()

1721
                    if optimizer_was_run and not self.deepspeed:
1722
1723
                        self.lr_scheduler.step()

                    model.zero_grad()
1725
                    self.state.global_step += 1
1726
                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
1727
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)

1729
                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break
1735
1736
            if step < 0:
                logger.warning(
                    "There does not seem to be a single sample in your epoch_iterator, stopping training at step"
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True
1742

1743
            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
1744
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
1745

1746
            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
1756
                break

1758
        if args.past_index and hasattr(self, "_past"):
1759
1760
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
1763
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
1767
            elif args.local_rank != -1:
1768
                dist.barrier()
1769
1770
            elif is_sagemaker_mp_enabled():
                smp.barrier()
1771

1772
            self._load_best_model()
1773

        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        train_loss = self._total_loss_scalar / self.state.global_step

1778
        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
1779
1780
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
1781
        metrics["train_loss"] = train_loss

1783
        self.is_in_train = False
1784

1785
1786
        self._memory_tracker.stop_and_update_metrics(metrics)

        self.log(metrics)

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        return TrainOutput(self.state.global_step, train_loss, metrics)
1792

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):

        if model is None:
            model = self.model
        strict_load = is_sagemaker_mp_enabled()

        if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
            os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
        ):
1802
1803
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

1804
        logger.info(f"Loading model from {resume_from_checkpoint}.")

        if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
            config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
            checkpoint_version = config.transformers_version
            if checkpoint_version is not None and checkpoint_version != __version__:
                logger.warning(
                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                    f"Transformers but your current version is {__version__}. This is not recommended and could "
                    "lead to errors or unwanted behaviors."
                )

        if self.args.deepspeed:
            # will be resumed in deepspeed_init
            pass
1819
        elif os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
            # We load the model state dict on the CPU to avoid an OOM error.
            state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
            # If the model is on the GPU, it still works!
            load_result = model.load_state_dict(state_dict, strict=strict_load)
            if not strict_load:
                self._issue_warnings_after_load(load_result)
1826
1827
            # release memory
            del state_dict
1828
1829
        else:
            # We load the sharded checkpoint
            load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=strict_load)
            if not strict_load:
                self._issue_warnings_after_load(load_result)

    def _load_best_model(self):
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
1837
1838
        strict_load = is_sagemaker_mp_enabled()
        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
1839
1840
        if os.path.exists(best_model_path):
            if self.deepspeed:

                if self.model_wrapped is not None:
                    # this removes the pre-hooks from the previous engine
                    self.model_wrapped.destroy()
                    self.model_wrapped = None

1847
                # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
                deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
                    self,
                    num_training_steps=self.args.max_steps,
                    resume_from_checkpoint=self.state.best_model_checkpoint,
                )
                self.model = deepspeed_engine.module
                self.model_wrapped = deepspeed_engine
                self.deepspeed = deepspeed_engine
                self.optimizer = optimizer
                self.lr_scheduler = lr_scheduler
            else:
                # We load the model state dict on the CPU to avoid an OOM error.
                state_dict = torch.load(best_model_path, map_location="cpu")
                # If the model is on the GPU, it still works!
                load_result = model.load_state_dict(state_dict, strict=strict_load)
                if not strict_load:
                    self._issue_warnings_after_load(load_result)
1865
        elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
            load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=strict_load)
            if not strict_load:
                self._issue_warnings_after_load(load_result)
        else:
            logger.warning(
                f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
                "on multiple nodes, you should activate `--save_on_each_node`."
            )

1875
    def _issue_warnings_after_load(self, load_result):
1876
1877

        if len(load_result.missing_keys) != 0:
            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                self.model._keys_to_ignore_on_save
            ):
1881
1882
                self.model.tie_weights()
            else:
1883
                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
1884
        if len(load_result.unexpected_keys) != 0:
            logger.warning(
                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
            )
1888

1889
    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
        if self.control.should_log:
            if is_torch_tpu_available():
                xm.mark_step()

            logs: Dict[str, float] = {}

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # reset tr_loss to zero
            tr_loss -= tr_loss

1902
            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
1903
            logs["learning_rate"] = self._get_learning_rate()
1904

1905
            self._total_loss_scalar += tr_loss_scalar
1906
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
1913
            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
            self._report_to_hp_search(trial, epoch, metrics)
1915

        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def _load_rng_state(self, checkpoint):
        # Load RNG states from `checkpoint`
        if checkpoint is None:
            return
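        # Distributed runs store one `rng_state_{process_index}.pth` per process, while single-process
        # runs store a single `rng_state.pth`; both are written by `_save_checkpoint` further below.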

        if self.args.world_size > 1:
            process_index = self.args.process_index
            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
            if not os.path.isfile(rng_file):
                logger.info(
                    f"Didn't find an RNG file for process {process_index}. If you are resuming a training that "
                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
                )
                return
        else:
            rng_file = os.path.join(checkpoint, "rng_state.pth")
            if not os.path.isfile(rng_file):
                logger.info(
                    "Didn't find an RNG file. If you are resuming a training that was launched in a distributed "
                    "fashion, reproducibility is not guaranteed."
                )
                return

        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        if torch.cuda.is_available():
            if self.args.local_rank != -1:
                torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
            else:
                try:
                    torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
                except Exception as e:
1954
                    logger.info(
                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
        if is_torch_tpu_available():
            xm.set_rng_state(checkpoint_rng_state["xla"])

    def _save_checkpoint(self, model, trial, metrics=None):
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
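        # e.g. "checkpoint-500" when global_step is 500; the folder is created under `run_dir` below and
        # receives the model weights plus optimizer, scheduler, trainer state and RNG state files.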

        if self.hp_search_backend is not None and trial is not None:
1970
1971
            if self.hp_search_backend == HPSearchBackend.OPTUNA:
                run_id = trial.number
1972
            elif self.hp_search_backend == HPSearchBackend.RAY:
                from ray import tune

                run_id = tune.get_trial_id()
1976
1977
            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
                run_id = trial.id
            elif self.hp_search_backend == HPSearchBackend.WANDB:
                import wandb

                run_id = wandb.run.id
1982
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
1983
            run_dir = os.path.join(self.args.output_dir, run_name)
1984
        else:
1985
            run_dir = self.args.output_dir
1986
            self.store_flos()
1987

1988
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir, _internal_call=True)
1990
        if self.deepspeed:
1991
            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
1992
            # config `stage3_gather_16bit_weights_on_model_save` is True
1993
            self.deepspeed.save_checkpoint(output_dir)
1994
1995

        # Save optimizer and scheduler
1996
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
1997
            self.optimizer.consolidate_state_dict()
1998

1999
2000
        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
2001
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
2002
            with warnings.catch_warnings(record=True) as caught_warnings:
2003
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
2004
                reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
            if self.args.should_save:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling:
                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
2021
        elif self.args.should_save and not self.deepspeed:
2022
            # deepspeed.save_checkpoint above saves model/optim/sched
2023
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
2024
            with warnings.catch_warnings(record=True) as caught_warnings:
2025
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
2026
            reissue_pt_warnings(caught_warnings)
2027
            if self.do_grad_scaling:
2028
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
2029
2030

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
2047
        if self.args.should_save:
2048
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
2049

        # Save RNG state in non-distributed training
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
            if self.args.local_rank == -1:
                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                rng_states["cuda"] = torch.cuda.random.get_rng_state()

        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()

        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)

        if self.args.world_size <= 1:
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

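    # NOTE (editor): a minimal usage sketch for resuming from the files written above, assuming a
    # `trainer` built in the usual way and an existing checkpoint folder in `output_dir` (the
    # folder name below is illustrative):
    #
    #     trainer.train(resume_from_checkpoint="output_dir/checkpoint-500")
    #     # or let the Trainer pick the most recent checkpoint in `output_dir`:
    #     trainer.train(resume_from_checkpoint=True)
    #
    # Resuming restores the optimizer/scheduler/scaler state and the RNG states saved here.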
    def _load_optimizer_and_scheduler(self, checkpoint):
        """If optimizer and scheduler states exist, load them."""
        if checkpoint is None:
            return

        if self.deepspeed:
            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
            return

        checkpoint_file_exists = (
            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
            if is_sagemaker_mp_enabled()
            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
        )
        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
            # Load in optimizer and scheduler states
            if is_torch_tpu_available():
                # On TPU we have to take some extra precautions to properly load the states on the right device.
                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                with warnings.catch_warnings(record=True) as caught_warnings:
                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
                reissue_pt_warnings(caught_warnings)

                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

                self.optimizer.load_state_dict(optimizer_state)
                self.lr_scheduler.load_state_dict(lr_scheduler_state)
            else:
                map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
                if is_sagemaker_mp_enabled():

                    def opt_load_hook(mod, opt):
                        opt.load_state_dict(
                            smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True), gather_if_shard=False
                        )

                    self.model_wrapped.register_post_step_hook(opt_load_hook)
                else:
                    self.optimizer.load_state_dict(
                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
                    )
                with warnings.catch_warnings(record=True) as caught_warnings:
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
                    self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))

    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
        **kwargs,
    ) -> BestRun:
        """
        Launch a hyperparameter search using `optuna`, `Ray Tune`, or `SigOpt`. The optimized quantity is determined
        by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
        the sum of all metrics otherwise.

        <Tip warning={true}>

        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
        optimizer/scheduler.

        </Tip>

        Args:
            hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
                A function that defines the hyperparameter search space. Will default to
                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
                method. Will default to [`~trainer_utils.default_compute_objective`].
            n_trials (`int`, *optional*, defaults to 20):
                The number of trial runs to test.
            direction (`str`, *optional*, defaults to `"minimize"`):
                Whether to optimize greater or lower objectives. Can be `"minimize"` or `"maximize"`; you should pick
                `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
            backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
                on which one is installed. If all are installed, will default to optuna.
            kwargs:
                Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more
                information see:

                - the documentation of
                  [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
                - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
                - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)

        Returns:
            [`trainer_utils.BestRun`]: All the information about the best run.
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray should be installed. "
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`. "
                    "To install sigopt run `pip install sigopt`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_tune_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
            raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
        if backend == HPSearchBackend.WANDB and not is_wandb_available():
            raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
        self.hp_search_backend = backend
        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.hp_name = hp_name
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        backend_dict = {
            HPSearchBackend.OPTUNA: run_hp_search_optuna,
            HPSearchBackend.RAY: run_hp_search_ray,
            HPSearchBackend.SIGOPT: run_hp_search_sigopt,
            HPSearchBackend.WANDB: run_hp_search_wandb,
        }
        best_run = backend_dict[backend](self, n_trials, direction, **kwargs)

        self.hp_search_backend = None
        return best_run

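    # NOTE (editor): a hedged usage sketch for `hyperparameter_search`, assuming optuna is installed
    # and the Trainer was created with a `model_init` callable; the search space below is purely
    # illustrative:
    #
    #     def my_hp_space(trial):
    #         return {
    #             "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-4, log=True),
    #             "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
    #         }
    #
    #     best_run = trainer.hyperparameter_search(
    #         hp_space=my_hp_space, n_trials=20, direction="minimize", backend="optuna"
    #     )
    #     print(best_run.hyperparameters)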
    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)

        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)

    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
        """
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            kwargs = dict(device=self.args.device)
            if self.deepspeed and data.dtype != torch.int64:
                # NLP models inputs are int64 and those get adjusted to the right dtype of the
                # embedding. Other models such as wav2vec2's inputs are already float and thus
                # may need special handling to match the dtypes of the model
                kwargs.update(dict(dtype=self.args.hf_deepspeed_config.dtype()))
            return data.to(**kwargs)
        return data

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

    def compute_loss_context_manager(self):
        """
        A helper wrapper to group together context managers.
        """
        return ContextManagers(
            [
                self.torchdynamo_smart_context_manager(),
                self.autocast_smart_context_manager(),
            ]
        )

    def torchdynamo_smart_context_manager(self):
        """
        A helper wrapper that creates an appropriate context manager for `torchdynamo`.
        """
        ctx_manager = contextlib.nullcontext()
        if is_torchdynamo_available():
            import torchdynamo
            from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

            if self.args.torchdynamo == "eager":
                ctx_manager = torchdynamo.optimize("eager")
            elif self.args.torchdynamo == "nvfuser":
                ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
        return ctx_manager

    def autocast_smart_context_manager(self):
        """
        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
        arguments, depending on the situation.
        """
        if self.use_cuda_amp or self.use_cpu_amp:
            if version.parse(torch.__version__) >= version.parse("1.10"):
                ctx_manager = (
                    torch.cpu.amp.autocast(dtype=self.amp_dtype)
                    if self.use_cpu_amp
                    else torch.cuda.amp.autocast(dtype=self.amp_dtype)
                )
            else:
                ctx_manager = torch.cuda.amp.autocast()
        else:
            ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()

        return ctx_manager

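    # NOTE (editor): the three context managers above are driven by `TrainingArguments`; a hedged
    # configuration sketch (flag availability depends on the installed torch/torchdynamo versions):
    #
    #     args = TrainingArguments(
    #         output_dir="out",
    #         bf16=True,              # autocast_smart_context_manager then autocasts in bfloat16
    #         torchdynamo="nvfuser",  # selects the aot_autograd_speedup_strategy branch above
    #     )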
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
            loss = loss / self.args.gradient_accumulation_steps

        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            # loss gets scaled under gradient_accumulation_steps in deepspeed
            loss = self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()

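    # NOTE (editor): `training_step` is the intended hook for step-level customization. A minimal
    # sketch of a subclass that reuses the default logic and only adds periodic logging (the class
    # name is illustrative):
    #
    #     class VerboseTrainer(Trainer):
    #         def training_step(self, model, inputs):
    #             loss = super().training_step(model, inputs)
    #             if self.state.global_step % 100 == 0:
    #                 self.log({"raw_step_loss": loss.item()})
    #             return loss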
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = self.label_smoother(outputs, labels)
        else:
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

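    # NOTE (editor): `compute_loss` is the usual override point for custom objectives. A hedged
    # sketch of a class-weighted cross-entropy variant (the weights and the two-class assumption
    # are illustrative, not part of this file):
    #
    #     class WeightedLossTrainer(Trainer):
    #         def compute_loss(self, model, inputs, return_outputs=False):
    #             labels = inputs.pop("labels")
    #             outputs = model(**inputs)
    #             logits = outputs.logits
    #             weight = torch.tensor([1.0, 5.0], device=logits.device)  # assumed 2 classes
    #             loss = nn.CrossEntropyLoss(weight=weight)(
    #                 logits.view(-1, model.config.num_labels), labels.view(-1)
    #             )
    #             return (loss, outputs) if return_outputs else loss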
    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
        machines) main process.
        """
        return self.args.local_process_index == 0

    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be `True` for one process).
        """
        # Special case for SageMaker ModelParallel since its process_index is dp_process_index, not the global
        # process index.
        if is_sagemaker_mp_enabled():
            return smp.rank() == 0
        else:
            return self.args.process_index == 0

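    # NOTE (editor): a small usage sketch: in distributed runs, guard one-off side effects (prints,
    # extra file writes) with the helpers above so they happen exactly once:
    #
    #     if trainer.is_world_process_zero():
    #         print("only the global main process prints this")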
    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Will save the model, so you can reload it using `from_pretrained()`.

        Will only save from the main process.
        """

        if output_dir is None:
            output_dir = self.args.output_dir

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif is_sagemaker_mp_enabled():
            # Calling the state_dict needs to be done on the wrapped model and on all processes.
            state_dict = self.model_wrapped.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
        elif (
            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
            or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
            or self.fsdp is not None
        ):
            state_dict = self.model.state_dict()

            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
        elif self.deepspeed:

            # this takes care of everything as long as we aren't under zero3
            if self.args.should_save:
                self._save(output_dir)

            if is_deepspeed_zero3_enabled():
                # It's too complicated to try to override different places where the weights dump gets
                # saved, so since under zero3 the file is bogus, simply delete it. The user should
                # either use the deepspeed checkpoint to resume or, to recover full weights, use
                # zero_to_fp32.py stored in the checkpoint.
                if self.args.should_save:
                    file = os.path.join(output_dir, WEIGHTS_NAME)
                    if os.path.isfile(file):
                        # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
                        os.remove(file)

                # now save the real model if stage3_gather_16bit_weights_on_model_save=True
                # if false it will not be saved.
                # This must be called on all ranks
                if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME):
                    logger.warning(
                        "deepspeed.save_16bit_model didn't save the model, since"
                        " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                        " zero_to_fp32.py to recover weights"
                    )
                    self.deepspeed.save_checkpoint(output_dir)

        elif self.args.should_save:
            self._save(output_dir)

        # Push to the Hub when `save_model` is called by the user.
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")

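    # NOTE (editor): a minimal save/reload sketch, assuming the model is a `PreTrainedModel`
    # subclass such as `AutoModelForSequenceClassification` (the directory name is illustrative):
    #
    #     trainer.save_model("my_model_dir")
    #     from transformers import AutoModelForSequenceClassification
    #     reloaded = AutoModelForSequenceClassification.from_pretrained("my_model_dir")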
    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info(f"Saving model checkpoint to {output_dir}")

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if not isinstance(self.model, PreTrainedModel):
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                unwrap_model(self.model).save_pretrained(
                    output_dir,
                    is_main_process=self.args.should_save,
                    state_dict=self.model.state_dict(),
                    save_function=xm.save,
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                state_dict = self.model.state_dict()
                xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
        if self.tokenizer is not None and self.args.should_save:
            self.tokenizer.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                if state_dict is None:
                    state_dict = self.model.state_dict()
                unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if state_dict is None:
                    state_dict = self.model.state_dict()
                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir, state_dict=state_dict)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def store_flos(self):
        # Storing the number of floating-point operations that went into the model
        if self.args.local_rank != -1:
            self.state.total_flos += (
                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
            )
            self.current_flos = 0
        else:
            self.state.total_flos += self.current_flos
            self.current_flos = 0

    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match is not None and regex_match.groups() is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        # Make sure we don't delete the best model.
        if self.state.best_model_checkpoint is not None:
            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
            for i in range(best_model_index, len(checkpoints_sorted) - 2):
                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
        # we don't do to allow resuming.
        save_total_limit = self.args.save_total_limit
        if (
            self.state.best_model_checkpoint is not None
            and self.args.save_total_limit == 1
            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
        ):
            save_total_limit = 2

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
            shutil.rmtree(checkpoint)

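    # NOTE (editor): checkpoint rotation is configured from `TrainingArguments`; a hedged sketch
    # matching the logic above (values are illustrative):
    #
    #     args = TrainingArguments(
    #         output_dir="out",
    #         save_total_limit=3,           # keep at most 3 checkpoints, older ones are deleted
    #         load_best_model_at_end=True,  # the best checkpoint is protected by _sorted_checkpoints
    #         metric_for_best_model="eval_loss",
    #         greater_is_better=False,
    #     )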
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is a `datasets.Dataset`, columns not
                accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default)

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        start_time = time.time()

        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.log(output.metrics)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics

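    # NOTE (editor): a minimal usage sketch for `evaluate`, assuming the Trainer was built with an
    # `eval_dataset` and, optionally, a `compute_metrics` function:
    #
    #     metrics = trainer.evaluate()
    #     print(metrics["eval_loss"], metrics.get("eval_accuracy"))
    #     # keys carry the `metric_key_prefix` ("eval" by default) plus the speed metrics added
    #     # above, e.g. `eval_runtime` and `eval_samples_per_second`.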
    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Run prediction and returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the
                `model.forward()` method are automatically removed. Has to implement the method `__len__`
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "test_bleu" if the prefix is "test" (default)

        <Tip>

        If your predictions or labels have different sequence length (for instance because you're doing dynamic padding
        in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
        one array. The padding index is -100.

        </Tip>

        Returns: *NamedTuple* A namedtuple with the following keys:

            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        test_dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()

        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )
        total_batch_size = self.args.eval_batch_size * self.args.world_size
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)

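    # NOTE (editor): a minimal usage sketch for `predict` on a classification task (the dataset
    # variable is illustrative); logits are reduced to class ids with numpy:
    #
    #     output = trainer.predict(test_dataset)
    #     predicted_ids = np.argmax(output.predictions, axis=-1)
    #     print(output.metrics)  # e.g. test_loss and test_runtime when labels are present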
    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """
        args = self.args

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train init deepspeed here
        if args.deepspeed and not self.deepspeed:

            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(
                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        inputs_host = None

        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        all_inputs = None
        # Will be useful when we have an iterable dataset so we don't know its length.

        observed_num_examples = 0
        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None

            if is_torch_tpu_available():
                xm.mark_step()

            # Update containers on host
            if loss is not None:
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if labels is not None:
                labels = self._pad_across_processes(labels)
                labels = self._nested_gather(labels)
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_decode = self._pad_across_processes(inputs_decode)
                inputs_decode = self._nested_gather(inputs_decode)
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            if logits is not None:
                logits = self._pad_across_processes(logits)
                logits = self._nested_gather(logits)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                if preds_host is not None:
                    logits = nested_numpify(preds_host)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
                if inputs_host is not None:
                    inputs_decode = nested_numpify(inputs_host)
                    all_inputs = (
                        inputs_decode
                        if all_inputs is None
                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                    )
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                    )

                # Set back to None to begin a new accumulation
                losses_host, preds_host, inputs_host, labels_host = None, None, None, None

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
        if inputs_host is not None:
            inputs_decode = nested_numpify(inputs_host)
            all_inputs = (
                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
            )
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
        # samples has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
        if all_inputs is not None:
            all_inputs = nested_truncate(all_inputs, num_samples)

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)

    def _nested_gather(self, tensors, name=None):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            if name is None:
                name = "nested_gather"
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif self.args.local_rank != -1:
            tensors = distributed_concat(tensors)
        return tensors

    # Copied from Accelerate.
    def _pad_across_processes(self, tensor, pad_index=-100):
        """
        Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
        they can safely be gathered.
        """
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )

        if len(tensor.shape) < 2:
            return tensor
        # Gather all sizes
        size = torch.tensor(tensor.shape, device=tensor.device)[None]
        sizes = self._nested_gather(size).cpu()

        max_size = max(s[1] for s in sizes)
        if tensor.shape[1] == max_size:
            return tensor

        # Then pad to the maximum size
        old_size = tensor.shape
        new_size = list(old_size)
        new_size[1] = max_size
        new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
        new_tensor[:, : old_size[1]] = tensor
        return new_tensor

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        has_labels = all(inputs.get(k) is not None for k in self.label_names)
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels:
                    with self.compute_loss_context_manager():
                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                    loss = loss.mean().detach()

                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        logits = outputs[1:]
                else:
                    loss = None
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)
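
    # Illustrative usage sketch (not part of the original implementation; assumes a `Trainer` instance
    # named `trainer` built with an `eval_dataset`):
    #
    #     batch = next(iter(trainer.get_eval_dataloader()))
    #     loss, logits, labels = trainer.prediction_step(
    #         trainer.model, batch, prediction_loss_only=False
    #     )
    #     # `loss` is a detached scalar tensor (or None); `logits` and `labels` are detached tensors
    #     # or tuples of tensors, ready for metric computation.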

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point
        operations for every backward + forward pass. If using another model, either implement such a method in the
        model or subclass and override this method.

        Args:
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            `int`: The number of floating-point operations.
        """
        if hasattr(self.model, "floating_point_ops"):
            return self.model.floating_point_ops(inputs)
        else:
            return 0
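
    # Illustrative sketch (not part of the original implementation): a model that does not expose
    # `floating_point_ops` can still report FLOs by overriding this method in a `Trainer` subclass.
    # The 6 * tokens * parameters estimate below follows the same order-of-magnitude heuristic as the
    # default `PreTrainedModel` implementation.
    #
    #     class FlopCountingTrainer(Trainer):
    #         def floating_point_ops(self, inputs):
    #             if "input_ids" not in inputs:
    #                 return 0
    #             num_tokens = inputs["input_ids"].numel()
    #             num_params = sum(p.numel() for p in self.model.parameters())
    #             return 6 * num_tokens * num_params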

    def init_git_repo(self, at_init: bool = False):
        """
        Initializes a git repo in `self.args.hub_model_id`.

        Args:
            at_init (`bool`, *optional*, defaults to `False`):
                Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
                `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
                out.
        """
        if not self.is_world_process_zero():
            return
        use_auth_token = True if self.args.hub_token is None else self.args.hub_token
        if self.args.hub_model_id is None:
            repo_name = Path(self.args.output_dir).absolute().name
        else:
            repo_name = self.args.hub_model_id
        if "/" not in repo_name:
            repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)

        try:
            self.repo = Repository(
                self.args.output_dir,
                clone_from=repo_name,
                use_auth_token=use_auth_token,
                private=self.args.hub_private_repo,
            )
        except EnvironmentError:
            if self.args.overwrite_output_dir and at_init:
                # Try again after wiping output_dir
                shutil.rmtree(self.args.output_dir)
                self.repo = Repository(
                    self.args.output_dir,
                    clone_from=repo_name,
                    use_auth_token=use_auth_token,
                )
            else:
                raise

        self.repo.git_pull()

        # By default, ignore the checkpoint folders
        if (
            not os.path.exists(os.path.join(self.args.output_dir, ".gitignore"))
            and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS
        ):
            with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
                writer.writelines(["checkpoint-*/"])

        self.push_in_progress = None
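
    # Illustrative configuration sketch (placeholder names, not part of the original implementation).
    # With `push_to_hub=True`, the `Trainer` calls `init_git_repo(at_init=True)` itself, cloning (or
    # creating) the repo into `output_dir` and git-ignoring `checkpoint-*/` folders unless
    # `hub_strategy="all_checkpoints"` is used.
    #
    #     args = TrainingArguments(
    #         output_dir="my-finetuned-model",
    #         push_to_hub=True,
    #         hub_model_id="my-user/my-finetuned-model",  # optional, defaults to the output_dir name
    #         hub_strategy="every_save",
    #     )
    #     trainer = Trainer(model=model, args=args, train_dataset=train_ds)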

    def create_model_card(
        self,
        language: Optional[str] = None,
        license: Optional[str] = None,
        tags: Optional[str] = None,
        model_name: Optional[str] = None,
        finetuned_from: Optional[str] = None,
        tasks: Optional[str] = None,
        dataset_tags: Optional[Union[str, List[str]]] = None,
        dataset: Optional[Union[str, List[str]]] = None,
        dataset_args: Optional[Union[str, List[str]]] = None,
    ):
        if not self.is_world_process_zero():
            return

        training_summary = TrainingSummary.from_trainer(
            self,
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
        )
        model_card = training_summary.to_model_card()
        with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
            f.write(model_card)
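
    # Illustrative usage sketch (placeholder values, not part of the original implementation):
    #
    #     trainer.create_model_card(
    #         language="en",
    #         license="apache-2.0",
    #         tags="text-classification",
    #         finetuned_from="bert-base-uncased",
    #         tasks="text-classification",
    #         dataset="imdb",
    #     )
    #     # Writes a draft model card to `<output_dir>/README.md` on the main process only.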

    def _push_from_checkpoint(self, checkpoint_folder):
        # Only push from one node.
        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
            return
        # If we haven't finished the last push, we don't do this one.
        if self.push_in_progress is not None and not self.push_in_progress.is_done:
            return

        output_dir = self.args.output_dir
        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
        modeling_files = [CONFIG_NAME, WEIGHTS_NAME]
        for modeling_file in modeling_files:
            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Same for the training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        try:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Temporarily move the checkpoint just saved for the push
                tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
                # subfolder.
                if os.path.isdir(tmp_checkpoint):
                    shutil.rmtree(tmp_checkpoint)
                shutil.move(checkpoint_folder, tmp_checkpoint)

            if self.args.save_strategy == IntervalStrategy.STEPS:
                commit_message = f"Training in progress, step {self.state.global_step}"
            else:
                commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
            _, self.push_in_progress = self.repo.push_to_hub(
                commit_message=commit_message, blocking=False, auto_lfs_prune=True
            )
        finally:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Move back the checkpoint to its place
                shutil.move(tmp_checkpoint, checkpoint_folder)
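
    # Summary of the push logic above (descriptive note, no new behaviour):
    #   - hub_strategy="end":             nothing is pushed from checkpoints; only `push_to_hub()` pushes.
    #   - hub_strategy="every_save":      config/weights are copied into `output_dir` and pushed asynchronously.
    #   - hub_strategy="checkpoint":      the latest checkpoint is additionally pushed under `last-checkpoint/`.
    #   - hub_strategy="all_checkpoints": `checkpoint-*/` is not git-ignored, so checkpoints sync with the repo.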

    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.

        Parameters:
            commit_message (`str`, *optional*, defaults to `"End of training"`):
                Message to commit while pushing.
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the `git push` has finished.
            kwargs:
                Additional keyword arguments passed along to [`~Trainer.create_model_card`].

        Returns:
            The url of the commit of your model in the given repository if `blocking=True`, or a tuple with the url
            of the commit and an object tracking the progress of the push if `blocking=False`.
        """
        # If a user manually calls `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo, but
        # it might fail.
        if not hasattr(self, "repo"):
            self.init_git_repo()

        if self.args.should_save:
            if self.args.hub_model_id is None:
                model_name = Path(self.args.output_dir).name
            else:
                model_name = self.args.hub_model_id.split("/")[-1]

        # Needs to be executed on all processes for TPU training, but will only save on the process determined by
        # self.args.should_save.
        self.save_model(_internal_call=True)

        # Only push from one node.
        if not self.is_world_process_zero():
            return

        # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
            self.push_in_progress._process.kill()
            self.push_in_progress = None

        git_head_commit_url = self.repo.push_to_hub(
            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
        )
        # Push the model card separately so it is independent from the rest of the model
        if self.args.should_save:
            self.create_model_card(model_name=model_name, **kwargs)
            try:
                self.repo.push_to_hub(
                    commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
                )
            except EnvironmentError as exc:
                logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")

        return git_head_commit_url
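
    # Illustrative usage sketch (not part of the original implementation):
    #
    #     trainer.train()
    #     url = trainer.push_to_hub(commit_message="End of training", blocking=True)
    #     # With blocking=True the call waits for `git push` to finish and returns the commit url;
    #     # with blocking=False the push continues in the background.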

    #
    # Deprecated code
    #

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """
        args = self.args

        if not has_length(dataloader):
            raise ValueError("dataloader must implement a working __len__")

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train, init deepspeed here
        if args.deepspeed and not self.deepspeed:
            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
            # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
            # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
            # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
            deepspeed_engine.optimizer.optimizer = None
            deepspeed_engine.lr_scheduler = None

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info(f"***** Running {description} *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Batch size = {batch_size}")
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
        inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = max(1, args.world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
            # a batch size to the sampler)
            make_multiple_of = None
            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
                make_multiple_of = dataloader.sampler.batch_size
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None

            if loss is not None:
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                if not prediction_loss_only:
                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
                    inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host, inputs_host = None, None, None, None

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
            inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
        inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
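
    # Illustrative note (sketch; assumes the `use_legacy_prediction_loop` training argument, which routes
    # `evaluate`/`predict` through this deprecated loop instead of the newer `evaluation_loop`):
    #
    #     args = TrainingArguments(output_dir="out", use_legacy_prediction_loop=True)
    #     trainer = Trainer(model=model, args=args, eval_dataset=eval_ds)
    #     metrics = trainer.evaluate()  # dispatches to `prediction_loop`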

    def _gather_and_numpify(self, tensors, name):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif self.args.local_rank != -1:
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)