# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
"""

import contextlib
import functools
import glob
import inspect
import math
import os
import random
import re
import shutil
import sys
import time
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

from tqdm.auto import tqdm


# Integrations must be imported before ML frameworks:
from .integrations import (  # isort: split
    default_hp_search_backend,
    get_reporting_integration_callbacks,
    hp_params,
    is_fairscale_available,
    is_optuna_available,
    is_ray_tune_available,
    is_sigopt_available,
    is_wandb_available,
    run_hp_search_optuna,
    run_hp_search_ray,
    run_hp_search_sigopt,
    run_hp_search_wandb,
)

import numpy as np
import torch
import torch.distributed as dist
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from huggingface_hub import Repository

from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from .optimization import Adafactor, get_scheduler
from .pytorch_utils import ALL_LAYERNORM_LAYERS
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
    DistributedLengthGroupedSampler,
    DistributedSamplerWithLoop,
    DistributedTensorGatherer,
    IterableDatasetShard,
    LabelSmoother,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    ShardSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    find_batch_size,
    get_parameter_names,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_truncate,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalLoopOutput,
    EvalPrediction,
    FSDPOption,
    HPSearchBackend,
    HubStrategy,
    IntervalStrategy,
    PredictionOutput,
    RemoveColumnsCollator,
    ShardedDDPOption,
    TrainerMemoryTracker,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    denumpify_detensorize,
    enable_full_determinism,
    find_executable_batch_size,
    get_last_checkpoint,
    has_length,
    number_of_arguments,
    seed_worker,
    set_seed,
    speed_metrics,
)
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import (
    CONFIG_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    find_labels,
    get_full_repo_name,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
    is_ipex_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_torch_tpu_available,
    is_torchdynamo_available,
    logging,
)
from .utils.generic import ContextManagers


_is_torch_generator_available = False
_is_native_cuda_amp_available = False
_is_native_cpu_amp_available = False

DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

if is_apex_available():
    from apex import amp

if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_torch_generator_available = True
    _is_native_cuda_amp_available = True

if version.parse(torch.__version__) >= version.parse("1.10"):
    _is_native_cpu_amp_available = True

if is_datasets_available():
    import datasets

if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_fairscale_available():
    dep_version_check("fairscale")
    import fairscale
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.nn.wrap import auto_wrap
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat


if TYPE_CHECKING:
    import optuna

logger = logging.get_logger(__name__)


# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.

    Args:
        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.

            <Tip>

            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
            models.

            </Tip>

        args ([`TrainingArguments`], *optional*):
            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
        data_collator (`DataCollator`, *optional*):
            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
            [`DataCollatorWithPadding`] otherwise.
        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
            The dataset to use for training. If it is a `datasets.Dataset`, columns not accepted by the
            `model.forward()` method are automatically removed.

            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
            distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
            sets the seed of the RNGs used.
        eval_dataset (`torch.utils.data.Dataset`, *optional*):
            The dataset to use for evaluation. If it is a `datasets.Dataset`, columns not accepted by the
            `model.forward()` method are automatically removed.
        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to
            the maximum length when batching inputs, and it will be saved along the model to make it easier to rerun
            an interrupted training or reuse the fine-tuned model.
        model_init (`Callable[[], PreTrainedModel]`, *optional*):
            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
            from a new instance of the model as given by this function.

            The function may have zero arguments, or a single one containing the optuna/Ray Tune/SigOpt trial object,
            to be able to choose different architectures according to hyperparameters (such as layer count, sizes of
            inner layers, dropout probabilities, etc.).
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
            The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`] and
            return a dictionary mapping metric names to metric values.
        callbacks (List of [`TrainerCallback`], *optional*):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed [here](callback).

            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications
            made by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.

    Important attributes:

        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
          subclass.
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
          to `False` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False`.
        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
          in `train`)
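
    Example:

    A minimal usage sketch (the `train_ds`/`eval_ds` datasets and the checkpoint name below are placeholders rather
    than objects defined in this module):

    ```python
    from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    args = TrainingArguments(output_dir="tmp_trainer", num_train_epochs=1)

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,  # any map-style dataset of tokenized examples
        eval_dataset=eval_ds,
        tokenizer=tokenizer,
    )
    trainer.train()
    metrics = trainer.evaluate()
    ```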
    """

    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    ):
        if args is None:
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
        self.args = args
        # Seed must be set before instantiating the model when using model_init.
        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
        self.hp_name = None
        self.deepspeed = None
        self.is_in_train = False

        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

        # set the correct log level depending on the node
        log_level = args.get_process_log_level()
        logging.set_verbosity(log_level)

        # force device and distributed setup init explicitly
        args._setup_devices

        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
                    " overwrite your model when calling the `train` method. This will become a fatal error in the next"
                    " release.",
                    FutureWarning,
                )
            self.model_init = model_init

        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = False

        # Setup Sharded DDP training
        self.sharded_ddp = None
        if len(args.sharded_ddp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if len(args.fsdp) > 0:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
                )

            if args.local_rank == -1:
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
                raise ImportError(
                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
                )
            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.SIMPLE
            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

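        # Setup FSDP (PyTorch fully sharded data parallel) training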
        self.fsdp = None
        if len(args.fsdp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if args.local_rank == -1:
                raise ValueError("Using fsdp only works in distributed training.")

            # dep_version_check("torch>=1.12.0")
            # Would have to update setup.py with torch>=1.12.0
            # which isn't ideal given that it will force people not using FSDP to also use torch>=1.12.0
            # below is the current alternative.
            if version.parse(torch.__version__) < version.parse("1.12.0"):
                raise ValueError("FSDP requires PyTorch >= 1.12.0")

            from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy

            if FSDPOption.FULL_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.FULL_SHARD
            elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
                self.fsdp = ShardingStrategy.SHARD_GRAD_OP

        # one place to sort out whether to place the model on device or not
        # postpone switching model to cuda when:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        #    and we only use deepspeed for training at the moment
        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
        # 4. Sharded DDP - same as MP
        # 5. FSDP - same as MP
        self.place_model_on_device = args.place_model_on_device
        if (
            self.is_model_parallel
            or args.deepspeed
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
            or (self.fsdp is not None)
        ):
            self.place_model_on_device = False

        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer

        if self.place_model_on_device:
            self._move_model_to_device(model, args.device)

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
        if self.is_model_parallel:
            self.args._n_gpu = 1

        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
        self.model_wrapped = model
        self.model = model

        self.compute_metrics = compute_metrics
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
            self.optimizer is not None or self.lr_scheduler is not None
        ):
            raise RuntimeError(
                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create clone of distant repo and output directory if needed
        if self.args.push_to_hub:
            self.init_git_repo(at_init=True)
            # In case of pull, we need to make sure every process has the latest.
            if is_torch_tpu_available():
                xm.rendezvous("init git repo")
            elif args.local_rank != -1:
                dist.barrier()

        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)

        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")

        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
            raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")

        if (
            train_dataset is not None
            and isinstance(train_dataset, torch.utils.data.IterableDataset)
            and args.group_by_length
        ):
            raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset`")

        self._signature_columns = None

        # Mixed precision setup
        self.use_apex = False
        self.use_cuda_amp = False
        self.use_cpu_amp = False

        # Mixed precision setup for SageMaker Model Parallel
        if is_sagemaker_mp_enabled():
            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
            if args.bf16:
                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead.")
            # When there's a mismatch between the SMP config and the trainer argument, use the SMP config as truth
            if args.fp16 != smp.state.cfg.fp16:
                logger.warning(
                    f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                    f"but FP16 provided in trainer argument is {args.fp16}, "
                    f"setting to {smp.state.cfg.fp16}"
                )
                args.fp16 = smp.state.cfg.fp16

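        # Resolve which half-precision backend to use ("cuda_amp", "cpu_amp" or "apex") when fp16/bf16 is requested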
        if args.fp16 or args.bf16:
            if self.fsdp is not None:
                raise ValueError(
                    "Mixed precision is currently not supported for FSDP. "
                    "Please do not set arguments related to `mixed_precision`."
                )
            if args.half_precision_backend == "auto":
                if args.device == torch.device("cpu"):
                    if args.fp16:
                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                    elif _is_native_cpu_amp_available:
                        args.half_precision_backend = "cpu_amp"
                    else:
                        raise ValueError("Tried to use cpu amp but native cpu amp is not available")
                else:
                    if _is_native_cuda_amp_available:
                        args.half_precision_backend = "cuda_amp"
                    elif args.bf16:
                        raise ValueError("Tried to use `bf16` but native amp is not available")
                    else:
                        args.half_precision_backend = "apex"

            logger.info(f"Using {args.half_precision_backend} half precision backend")

        self.do_grad_scaling = False
        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
            # deepspeed and SageMaker Model Parallel manage their own half precision
            if args.half_precision_backend == "cuda_amp":
                self.use_cuda_amp = True
                self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
                self.do_grad_scaling = True
                if self.sharded_ddp is not None:
                    self.scaler = ShardedGradScaler()
                elif is_torch_tpu_available():
                    from torch_xla.amp import GradScaler

                    self.scaler = GradScaler()
                else:
                    self.scaler = torch.cuda.amp.GradScaler()
            elif args.half_precision_backend == "cpu_amp":
                self.use_cpu_amp = True
                self.amp_dtype = torch.bfloat16
            else:
                if not is_apex_available():
                    raise ImportError(
                        "Using FP16 with APEX but APEX is not installed, please refer to"
                        " https://www.github.com/nvidia/apex."
                    )
                self.use_apex = True

        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
        if (
            is_sagemaker_mp_enabled()
            and self.use_cuda_amp
            and args.max_grad_norm is not None
            and args.max_grad_norm > 0
        ):
            raise ValueError(
                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
                "along 'max_grad_norm': 0 in your hyperparameters."
            )

        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

        self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
        )

        self.control = TrainerControl()
        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
        default_label_names = find_labels(self.model.__class__)
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

        # Internal variables to keep track of the original batch size
        self._train_batch_size = args.train_batch_size

        # very last
        self._memory_tracker.stop_and_update_metrics()

    def add_callback(self, callback):
        """
        Add a callback to the current list of [`~transformers.TrainerCallback`].

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
               first case, will instantiate a member of that class.
        """
        self.callback_handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it.

        If the callback is not found, returns `None` (and no error is raised).

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
               first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            [`~transformers.TrainerCallback`]: The callback removed, if found.
        """
        return self.callback_handler.pop_callback(callback)

    def remove_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`].

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
               first case, will remove the first member of that class found in the list of callbacks.
        """
        self.callback_handler.remove_callback(callback)

    def _move_model_to_device(self, model, device):
        model = model.to(device)
        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
            model.tie_weights()

    def _set_signature_columns_if_needed(self):
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
            # Labels may be named label or label_ids, the default data collator handles that.
            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return dataset
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        if len(ignored_columns) > 0:
            dset_description = "" if description is None else f"in the {description} set"
            logger.info(
                f"The following columns {dset_description} don't have a corresponding argument in "
                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
                f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
                " you can safely ignore this message."
            )

        columns = [k for k in signature_columns if k in dataset.column_names]

        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            dataset.set_format(
                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
            )
            return dataset
        else:
            return dataset.remove_columns(ignored_columns)

    def _get_collator_with_removed_columns(
        self, data_collator: Callable, description: Optional[str] = None
    ) -> Callable:
        """Wrap the data collator in a callable removing unused columns."""
        if not self.args.remove_unused_columns:
            return data_collator
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        remove_columns_collator = RemoveColumnsCollator(
            data_collator=data_collator,
            signature_columns=signature_columns,
            logger=logger,
            description=description,
            model_name=self.model.__class__.__name__,
        )
        return remove_columns_collator

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
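        # Picks a length-grouped, distributed, or plain random sampler; returns None when the dataset has no length.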
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        generator = None
        if self.args.world_size <= 1 and _is_torch_generator_available:
            generator = torch.Generator()
            # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
            # `args.seed`) if data_seed isn't provided.
            # Further on in this method, we default to `args.seed` instead.
            if self.args.data_seed is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
            else:
                seed = self.args.data_seed
            generator.manual_seed(seed)

        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed

        # Build the sampler.
        if self.args.group_by_length:
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            if self.args.world_size <= 1:
                return LengthGroupedSampler(
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
                    dataset=self.train_dataset,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
                )
            else:
                return DistributedLengthGroupedSampler(
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
                    dataset=self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    seed=seed,
                )

        else:
            if self.args.world_size <= 1:
                if _is_torch_generator_available:
                    return RandomSampler(self.train_dataset, generator=generator)
                return RandomSampler(self.train_dataset)
            elif (
                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
            ):
                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    batch_size=self.args.per_device_train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=seed,
                )
            else:
                return DistributedSampler(
                    self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=seed,
                )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        if isinstance(train_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                train_dataset = IterableDatasetShard(
                    train_dataset,
                    batch_size=self._train_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )

            return DataLoader(
                train_dataset,
                batch_size=self.args.per_device_train_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        train_sampler = self._get_train_sampler()

        return DataLoader(
            train_dataset,
            batch_size=self._train_batch_size,
            sampler=train_sampler,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            worker_init_fn=seed_worker,
        )

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
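        # Sequential sampling on a single process, a ShardSampler split across processes; the legacy prediction loop
        # keeps its older samplers below.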
        # Deprecated code
        if self.args.use_legacy_prediction_loop:
            if is_torch_tpu_available():
                return SequentialDistributedSampler(
                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
                )
            elif is_sagemaker_mp_enabled():
                return SequentialDistributedSampler(
                    eval_dataset,
                    num_replicas=smp.dp_size(),
                    rank=smp.dp_rank(),
                    batch_size=self.args.per_device_eval_batch_size,
                )
            elif self.args.local_rank != -1:
                return SequentialDistributedSampler(eval_dataset)
            else:
                return SequentialSampler(eval_dataset)

        if self.args.world_size <= 1:
            return SequentialSampler(eval_dataset)
        else:
            return ShardSampler(
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
            )

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a `datasets.Dataset`, columns not accepted by
                the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                eval_dataset = IterableDatasetShard(
                    eval_dataset,
                    batch_size=self.args.per_device_eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                eval_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`, *optional*):
                The test dataset to use. If it is a `datasets.Dataset`, columns not accepted by the `model.forward()`
                method are automatically removed. It must implement `__len__`.
        """
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description="test")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")

        if isinstance(test_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                test_dataset = IterableDatasetShard(
                    test_dataset,
                    batch_size=self.args.eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                test_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
        `create_scheduler`) in a subclass.
        """
        self.create_optimizer()
        self.create_scheduler(
            num_training_steps=num_training_steps,
            optimizer=self.optimizer.optimizer if is_sagemaker_mp_enabled() and smp.state.cfg.fp16 else self.optimizer,
        )

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in opt_model.named_parameters() if n in decay_parameters],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in opt_model.named_parameters() if n not in decay_parameters],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer

    @staticmethod
    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
        """
        Returns the optimizer class and optimizer parameters based on the training arguments.

        Args:
            args (`transformers.training_args.TrainingArguments`):
                The training arguments for the training session.

        """
        optimizer_kwargs = {"lr": args.learning_rate}
        adam_kwargs = {
            "betas": (args.adam_beta1, args.adam_beta2),
            "eps": args.adam_epsilon,
        }
        if args.optim == OptimizerNames.ADAFACTOR:
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif args.optim == OptimizerNames.ADAMW_HF:
            from .optimization import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim == OptimizerNames.ADAMW_TORCH:
            from torch.optim import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
            try:
                from torch_xla.amp.syncfree import AdamW

                optimizer_cls = AdamW
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
            try:
                from apex.optimizers import FusedAdam

                optimizer_cls = FusedAdam
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
        elif args.optim == OptimizerNames.ADAMW_BNB:
            try:
                from bitsandbytes.optim import Adam8bit

                optimizer_cls = Adam8bit
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
        elif args.optim == OptimizerNames.SGD:
            optimizer_cls = torch.optim.SGD
        elif args.optim == OptimizerNames.ADAGRAD:
            optimizer_cls = torch.optim.Adagrad
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
        """
        if self.lr_scheduler is None:
            self.lr_scheduler = get_scheduler(
                self.args.lr_scheduler_type,
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )
        return self.lr_scheduler
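        # Note (hedged): rather than being called directly, this is normally driven by
        # `create_optimizer_and_scheduler(num_training_steps=max_steps)` from
        # `_inner_training_loop` below.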

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
        dataloader.dataset does not exist or has no length, it estimates as best it can.
        """
        try:
            dataset = dataloader.dataset
            # Special case for IterableDatasetShard, we need to dig deeper
            if isinstance(dataset, IterableDatasetShard):
                return len(dataloader.dataset.dataset)
            return len(dataloader.dataset)
        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
            return len(dataloader) * self.args.per_device_train_batch_size

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """HP search setup code"""
        self._trial = trial

        if self.hp_search_backend is None or trial is None:
            return
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            params = self.hp_space(trial)
        elif self.hp_search_backend == HPSearchBackend.RAY:
            params = trial
            params.pop("wandb", None)
        elif self.hp_search_backend == HPSearchBackend.SIGOPT:
            params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
        elif self.hp_search_backend == HPSearchBackend.WANDB:
            params = trial

        for key, value in params.items():
            if not hasattr(self.args, key):
                logger.warning(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
                    " `TrainingArguments`."
                )
                continue
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info(f"Trial: {trial.params}")
        if self.hp_search_backend == HPSearchBackend.SIGOPT:
            logger.info(f"SigOpt Assignments: {trial.assignments}")
        if self.hp_search_backend == HPSearchBackend.WANDB:
            logger.info(f"W&B Sweep parameters: {trial}")
        if self.args.deepspeed:
            # Rebuild the deepspeed config to reflect the updated training parameters
            from transformers.deepspeed import HfTrainerDeepSpeedConfig

            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
            self.args.hf_deepspeed_config.trainer_config_process(self.args)
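        # Hedged illustration (the search space below is an assumption, not a default from this
        # file): with the Optuna backend, `self.hp_space(trial)` should return a dict whose keys
        # are `TrainingArguments` fields, e.g.
        #   {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}
        # each value is then cast above to the type of the existing field before being set.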

    def _report_to_hp_search(
        self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
    ):
        if self.hp_search_backend is None or trial is None:
            return
        self.objective = self.compute_objective(metrics.copy())
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            import optuna

            trial.report(self.objective, epoch)
            if trial.should_prune():
                self.callback_handler.on_train_end(self.args, self.state, self.control)
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            from ray import tune

            if self.control.should_save:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

    def _tune_save_checkpoint(self):
        from ray import tune

        if not self.use_tune_checkpoints:
            return
        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
            self.save_model(output_dir, _internal_call=True)
            if self.args.should_save:
                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

    def call_model_init(self, trial=None):
        model_init_argcount = number_of_arguments(self.model_init)
        if model_init_argcount == 0:
            model = self.model_init()
        elif model_init_argcount == 1:
            model = self.model_init(trial)
        else:
            raise RuntimeError("model_init should have 0 or 1 argument.")

        if model is None:
            raise RuntimeError("model_init should not return None.")

        return model

    def torch_jit_model_eval(self, model, dataloader, training=False):
        if not training:
            if dataloader is None:
                logger.warning("failed to use PyTorch jit mode because the current dataloader is None.")
                return model
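            # Build positional placeholder inputs for tracing: one `torch.ones_like` tensor per key
            # of a sample batch (this assumes every entry in the batch is a tensor).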
            jit_inputs = []
            example_batch = next(iter(dataloader))
            for key in example_batch:
                example_tensor = torch.ones_like(example_batch[key])
                jit_inputs.append(example_tensor)
            jit_inputs = tuple(jit_inputs)
            try:
                jit_model = model.eval()
                with ContextManagers([self.autocast_smart_context_manager(), torch.no_grad()]):
                    jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
                jit_model = torch.jit.freeze(jit_model)
                jit_model(**example_batch)
                model = jit_model
            except (RuntimeError, TypeError) as e:
                logger.warning(f"failed to use PyTorch jit mode due to: {e}.")

        return model

    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
        if not is_ipex_available():
            raise ImportError(
                "Using IPEX but IPEX is not installed, please refer to"
                " https://github.com/intel/intel-extension-for-pytorch."
            )

        import intel_extension_for_pytorch as ipex

        if not training:
            model.eval()
            model = ipex.optimize(model, dtype=dtype, level="O1")
        else:
            if not model.training:
                model.train()
            model, self.optimizer = ipex.optimize(model, dtype=dtype, optimizer=self.optimizer, level="O1")

        return model

    def _wrap_model(self, model, training=True, dataloader=None):
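        # Wrapping order below: optional IPEX optimization and JIT tracing first, then early
        # returns for SageMaker MP and DeepSpeed (already wrapped), apex AMP, DataParallel for
        # multi-GPU, and finally Sharded DDP / FSDP / DistributedDataParallel when training.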
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if self.args.jit_mode_eval:
            model = self.torch_jit_model_eval(model, dataloader, training)

        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # already initialized its own DDP and AMP
        if self.deepspeed:
            return self.deepspeed

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = nn.DataParallel(model)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)

        # Distributed training using PyTorch FSDP
        if self.fsdp is not None:
            # PyTorch FSDP!
            from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
            from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
            from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

            if FSDPOption.OFFLOAD in self.args.fsdp:
                cpu_offload = CPUOffload(offload_params=True)
            else:
                cpu_offload = CPUOffload(offload_params=False)

            auto_wrap_policy = None
            if FSDPOption.AUTO_WRAP in self.args.fsdp:
                if self.args.fsdp_min_num_params > 0:
                    auto_wrap_policy = functools.partial(
                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
                    )

            if type(model) != FSDP:
                # XXX: Breaking the self.model convention but I see no way around it for now.
                self.model = model = FSDP(
                    model, sharding_strategy=self.fsdp, cpu_offload=cpu_offload, auto_wrap_policy=auto_wrap_policy
                )
                if FSDPOption.OFFLOAD not in self.args.fsdp:
                    model.to(self.args.device)

        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.local_rank != -1:
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
            model = nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
                **kwargs,
            )

        return model

    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Main training entry point.

        Args:
            resume_from_checkpoint (`str` or `bool`, *optional*):
                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
            ignore_keys_for_eval (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions for evaluation during the training.
            kwargs:
                Additional keyword arguments used to hide deprecated arguments
        """
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        args = self.args

        self.is_in_train = True

        # do_train is not a reliable argument, as it might not be set and .train() still called, so
        # the following is a workaround:
        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
            self._move_model_to_device(self.model, args.device)

        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
                "instead.",
                FutureWarning,
            )
        if len(kwargs) > 0:
            raise TypeError(f"train() got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        model_reloaded = False
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
            self.model = self.call_model_init(trial)
            model_reloaded = True
            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Load potential model checkpoint
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled():
            self._load_from_checkpoint(resume_from_checkpoint)

        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
            if self.place_model_on_device:
                self._move_model_to_device(self.model, args.device)
            self.model_wrapped = self.model

        inner_training_loop = find_executable_batch_size(
            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
        )
        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        self._train_batch_size = batch_size
        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size

        len_dataloader = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            num_examples = self.num_examples(train_dataloader)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
                )
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
            else:
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            max_steps = args.max_steps
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
            )
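        # Worked example (hedged, illustrative numbers only): with 1000 batches per epoch,
        # gradient_accumulation_steps=4 and num_train_epochs=3, num_update_steps_per_epoch is
        # 1000 // 4 = 250 and max_steps = ceil(3 * 250) = 750 optimization steps.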

        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torch.distributed.launch)."
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

        delay_optimizer_creation = (
            self.sharded_ddp is not None
            and self.sharded_ddp != ShardedDDPOption.SIMPLE
            or is_sagemaker_mp_enabled()
            or self.fsdp is not None
        )
        if args.deepspeed:
            deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
                self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler
        elif not delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None

        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        model = self._wrap_model(self.model_wrapped)

        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint, model)

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

        # important: at this point:
        # self.model         is the Transformers Model
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.

        # Train!
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        steps_trained_progress_bar = None

        # Check if continuing training from a checkpoint
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
                    "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
                    "flag to your launch command, but you will resume the training on data already seen by your model."
                )
                if self.is_local_process_zero() and not args.disable_tqdm:
                    steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                    steps_trained_progress_bar.set_description("Skipping the first batches")

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()

        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not args.ignore_data_skip:
            for epoch in range(epochs_trained):
                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
                    train_dataloader.sampler, RandomSampler
                )
                if version.parse(torch.__version__) < version.parse("1.11") or not is_random_sampler:
                    # We just need to begin an iteration to create the randomization of the sampler.
                    # That was before PyTorch 1.11 however...
                    for _ in train_dataloader:
                        break
                else:
                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
                    # AT THE VERY END!
                    _ = list(train_dataloader.sampler)

        for epoch in range(epochs_trained, num_train_epochs):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
            elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
                train_dataloader.dataset.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if args.past_index >= 0:
                self._past = None

            steps_in_epoch = (
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

            step = -1
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
                    continue
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None

                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                if (
                    ((step + 1) % args.gradient_accumulation_steps != 0)
                    and args.local_rank != -1
                    and args._no_sync_in_gradient_accumulation
                ):
                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                    with model.no_sync():
                        tr_loss_step = self.training_step(model, inputs)
                else:
                    tr_loss_step = self.training_step(model, inputs)

                if (
                    args.logging_nan_inf_filter
                    and not is_torch_tpu_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                else:
                    tr_loss += tr_loss_step

                self.current_flos += float(self.floating_point_ops(inputs))

                # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
                if self.deepspeed:
                    self.deepspeed.step()

                if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    steps_in_epoch <= args.gradient_accumulation_steps
                    and (step + 1) == steps_in_epoch
                ):
                    # Gradient clipping
                    if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
                        # deepspeed does its own clipping

                        if self.do_grad_scaling:
                            # Reduce gradients first for XLA
                            if is_torch_tpu_available():
                                gradients = xm._fetch_gradients(self.optimizer)
                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)

                        if is_sagemaker_mp_enabled() and args.fp16:
                            self.optimizer.clip_master_grads(args.max_grad_norm)
                        elif hasattr(self.optimizer, "clip_grad_norm"):
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
                            model.clip_grad_norm_(args.max_grad_norm)
                        else:
                            # Revert to normal clipping otherwise, handling Apex or full precision
                            nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
                                args.max_grad_norm,
                            )

                    # Optimizer step
                    optimizer_was_run = True
                    if self.deepspeed:
                        pass  # called outside the loop
                    elif is_torch_tpu_available():
                        if self.do_grad_scaling:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            xm.optimizer_step(self.optimizer)
                    elif self.do_grad_scaling:
                        scale_before = self.scaler.get_scale()
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        scale_after = self.scaler.get_scale()
                        optimizer_was_run = scale_before <= scale_after
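                        # A decreased scale means AMP found inf/nan gradients and skipped this
                        # optimizer step, so the LR scheduler below is not stepped either.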
                    else:
                        self.optimizer.step()

                    if optimizer_was_run and not self.deepspeed:
                        self.lr_scheduler.step()

                    model.zero_grad()
                    self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break
            if step < 0:
                logger.warning(
                    "There does not seem to be a single sample in your epoch_iterator, stopping training at step"
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True

            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
            elif args.local_rank != -1:
                dist.barrier()
            elif is_sagemaker_mp_enabled():
                smp.barrier()

            self._load_best_model()

        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        train_loss = self._total_loss_scalar / self.state.global_step

        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
        metrics["train_loss"] = train_loss

        self.is_in_train = False

        self._memory_tracker.stop_and_update_metrics(metrics)

        self.log(metrics)

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        return TrainOutput(self.state.global_step, train_loss, metrics)

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):

        if model is None:
            model = self.model
        strict_load = is_sagemaker_mp_enabled()

        if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
            os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
        ):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

        logger.info(f"Loading model from {resume_from_checkpoint}.")

        if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
            config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
            checkpoint_version = config.transformers_version
            if checkpoint_version is not None and checkpoint_version != __version__:
                logger.warning(
                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                    f"Transformers but your current version is {__version__}. This is not recommended and could "
                    "yield to errors or unwanted behaviors."
                )

        if self.args.deepspeed:
            # will be resumed in deepspeed_init
            pass
        elif os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
            # We load the model state dict on the CPU to avoid an OOM error.
            state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
            # If the model is on the GPU, it still works!
            load_result = model.load_state_dict(state_dict, strict=strict_load)
            if not strict_load:
                self._issue_warnings_after_load(load_result)
            # release memory
            del state_dict
        else:
            # We load the sharded checkpoint
            load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=strict_load)
            if not strict_load:
                self._issue_warnings_after_load(load_result)

    def _load_best_model(self):
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
        strict_load = is_sagemaker_mp_enabled()
        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if os.path.exists(best_model_path):
            if self.deepspeed:

                if self.model_wrapped is not None:
                    # this removes the pre-hooks from the previous engine
                    self.model_wrapped.destroy()
                    self.model_wrapped = None

                # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
                deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
                    self,
                    num_training_steps=self.args.max_steps,
                    resume_from_checkpoint=self.state.best_model_checkpoint,
                )
                self.model = deepspeed_engine.module
                self.model_wrapped = deepspeed_engine
                self.deepspeed = deepspeed_engine
                self.optimizer = optimizer
                self.lr_scheduler = lr_scheduler
            else:
                # We load the model state dict on the CPU to avoid an OOM error.
                state_dict = torch.load(best_model_path, map_location="cpu")
                # If the model is on the GPU, it still works!
                load_result = model.load_state_dict(state_dict, strict=strict_load)
                if not strict_load:
                    self._issue_warnings_after_load(load_result)
        elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
            load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=strict_load)
            if not strict_load:
                self._issue_warnings_after_load(load_result)
        else:
            logger.warning(
                f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
                "on multiple nodes, you should activate `--save_on_each_node`."
            )

    def _issue_warnings_after_load(self, load_result):
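        # If the only missing keys are ones the model deliberately skips on save (typically tied
        # weights), re-tie the weights instead of warning; otherwise surface the mismatch.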

        if len(load_result.missing_keys) != 0:
            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                self.model._keys_to_ignore_on_save
            ):
                self.model.tie_weights()
            else:
                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
        if len(load_result.unexpected_keys) != 0:
            logger.warning(
                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
            )

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
        if self.control.should_log:
            if is_torch_tpu_available():
                xm.mark_step()

            logs: Dict[str, float] = {}

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            logs["learning_rate"] = self._get_learning_rate()

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
            self._report_to_hp_search(trial, epoch, metrics)

        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def _load_rng_state(self, checkpoint):
        # Load RNG states from `checkpoint`
        if checkpoint is None:
            return

        if self.args.world_size > 1:
            process_index = self.args.process_index
            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
            if not os.path.isfile(rng_file):
                logger.info(
                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
                )
                return
        else:
            rng_file = os.path.join(checkpoint, "rng_state.pth")
            if not os.path.isfile(rng_file):
                logger.info(
                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
                    "fashion, reproducibility is not guaranteed."
                )
                return

        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        if torch.cuda.is_available():
            if self.args.local_rank != -1:
                torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
            else:
                try:
                    torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
                except Exception as e:
                    logger.info(
                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
        if is_torch_tpu_available():
            xm.set_rng_state(checkpoint_rng_state["xla"])

    def _save_checkpoint(self, model, trial, metrics=None):
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is not None and trial is not None:
            if self.hp_search_backend == HPSearchBackend.OPTUNA:
                run_id = trial.number
            elif self.hp_search_backend == HPSearchBackend.RAY:
                from ray import tune

                run_id = tune.get_trial_id()
            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
                run_id = trial.id
            elif self.hp_search_backend == HPSearchBackend.WANDB:
                import wandb

                run_id = wandb.run.id
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            run_dir = os.path.join(self.args.output_dir, run_name)
        else:
            run_dir = self.args.output_dir
            self.store_flos()

        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir, _internal_call=True)
        if self.deepspeed:
            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
            # config `stage3_gather_16bit_weights_on_model_save` is True
            self.deepspeed.save_checkpoint(output_dir)

        # Save optimizer and scheduler
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            self.optimizer.consolidate_state_dict()

        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
            if self.args.should_save:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling:
                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
        elif self.args.should_save and not self.deepspeed:
            # deepspeed.save_checkpoint above saves model/optim/sched
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.args.should_save:
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        # Save RNG state in non-distributed training
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
            if self.args.local_rank == -1:
                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                rng_states["cuda"] = torch.cuda.random.get_rng_state()

        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()

        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)

        if self.args.world_size <= 1:
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

    def _load_optimizer_and_scheduler(self, checkpoint):
        """If optimizer and scheduler states exist, load them."""
        if checkpoint is None:
            return

        if self.deepspeed:
            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
            return

        checkpoint_file_exists = (
            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
            if is_sagemaker_mp_enabled()
            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
        )
        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
            # Load in optimizer and scheduler states
            if is_torch_tpu_available():
                # On TPU we have to take some extra precautions to properly load the states on the right device.
                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                with warnings.catch_warnings(record=True) as caught_warnings:
                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
                reissue_pt_warnings(caught_warnings)

                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

                self.optimizer.load_state_dict(optimizer_state)
                self.lr_scheduler.load_state_dict(lr_scheduler_state)
            else:
                map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
                if is_sagemaker_mp_enabled():

                    def opt_load_hook(mod, opt):
                        opt.load_state_dict(
                            smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True), gather_if_shard=False
                        )

                    self.model_wrapped.register_post_step_hook(opt_load_hook)
                else:
                    self.optimizer.load_state_dict(
                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
                    )
                with warnings.catch_warnings(record=True) as caught_warnings:
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
                    self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))

    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
        **kwargs,
    ) -> BestRun:
        """
        Launch a hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined
        by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
        the sum of all metrics otherwise.
        <Tip warning={true}>
        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
        optimizer/scheduler.

        </Tip>
        Args:
            hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
                A function that defines the hyperparameter search space. Will default to
                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
                method. Will default to [`~trainer_utils.default_compute_objective`].
            n_trials (`int`, *optional*, defaults to 20):
                The number of trial runs to test.
            direction (`str`, *optional*, defaults to `"minimize"`):
                Whether to minimize or maximize the objective. Can be `"minimize"` or `"maximize"`; pick `"minimize"`
                when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
            backend (`str` or [`~trainer_utils.HPSearchBackend`], *optional*):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
                on which one is installed. If all are installed, will default to optuna.
            kwargs:
                Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more
                information see:
                - the documentation of
                  [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
                - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
                - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)

        Returns:
            [`trainer_utils.BestRun`]: All the information about the best run.
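
        Example (a minimal sketch, assuming the `Trainer` was created with a `model_init` function, that `optuna` is
        installed, and that the search space below fits your task):

        ```python
        def my_hp_space(trial):
            # illustrative, task-specific search space
            return {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}


        best_run = trainer.hyperparameter_search(
            hp_space=my_hp_space,
            backend="optuna",
            n_trials=10,
            direction="minimize",
        )
        print(best_run.hyperparameters)
        ```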
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna, ray, or sigopt should be installed. "
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`. "
                    "To install sigopt run `pip install sigopt`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_tune_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
            raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
        if backend == HPSearchBackend.WANDB and not is_wandb_available():
            raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
        self.hp_search_backend = backend
        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.hp_name = hp_name
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        backend_dict = {
            HPSearchBackend.OPTUNA: run_hp_search_optuna,
            HPSearchBackend.RAY: run_hp_search_ray,
            HPSearchBackend.SIGOPT: run_hp_search_sigopt,
            HPSearchBackend.WANDB: run_hp_search_wandb,
        }
        best_run = backend_dict[backend](self, n_trials, direction, **kwargs)

        self.hp_search_backend = None
        return best_run

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
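
        Example (a minimal sketch, assuming an already-instantiated `trainer`):

        ```python
        trainer.log({"my_custom_metric": 0.5})
        ```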
        """
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)
        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
        """
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            kwargs = dict(device=self.args.device)
            if self.deepspeed and data.dtype != torch.int64:
                # NLP models inputs are int64 and those get adjusted to the right dtype of the
                # embedding. Other models such as wav2vec2's inputs are already float and thus
                # may need special handling to match the dtypes of the model
                kwargs.update(dict(dtype=self.args.hf_deepspeed_config.dtype()))
            return data.to(**kwargs)
        return data

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past
        return inputs

    def compute_loss_context_manager(self):
        """
        A helper wrapper to group together context managers.
        """
        return ContextManagers(
            [
                self.torchdynamo_smart_context_manager(),
                self.autocast_smart_context_manager(),
            ]
        )

    def torchdynamo_smart_context_manager(self):
        """
        A helper wrapper that creates an appropriate context manager for `torchdynamo`.
        """
        ctx_manager = contextlib.nullcontext()
        if is_torchdynamo_available():
            import torchdynamo
            from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

            if self.args.torchdynamo == "eager":
                ctx_manager = torchdynamo.optimize("eager")
            elif self.args.torchdynamo == "nvfuser":
                ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
        return ctx_manager

    def autocast_smart_context_manager(self):
        """
        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
        arguments, depending on the situation.
        """
        if self.use_cuda_amp or self.use_cpu_amp:
            if version.parse(torch.__version__) >= version.parse("1.10"):
                ctx_manager = (
                    torch.cpu.amp.autocast(dtype=self.amp_dtype)
                    if self.use_cpu_amp
                    else torch.cuda.amp.autocast(dtype=self.amp_dtype)
                )
            else:
                ctx_manager = torch.cuda.amp.autocast()
        else:
            ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()

        return ctx_manager

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)
        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
            loss = loss / self.args.gradient_accumulation_steps

        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            # loss gets scaled under gradient_accumulation_steps in deepspeed
            loss = self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
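
        Example (a sketch of a custom override; the two class weights and the classification head returning `logits`
        are illustrative assumptions):

        ```python
        class WeightedLossTrainer(Trainer):
            def compute_loss(self, model, inputs, return_outputs=False):
                labels = inputs.pop("labels")
                outputs = model(**inputs)
                logits = outputs.get("logits")
                # illustrative weights for an assumed 2-class problem
                loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0], device=logits.device))
                loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
                return (loss, outputs) if return_outputs else loss
        ```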
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        if labels is not None:
            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss
    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
        machines) main process.
        """
        return self.args.local_process_index == 0
    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be `True` for one process).
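
        Example (a minimal sketch for gating side effects, such as printing or uploading, to a single process):

        ```python
        if trainer.is_world_process_zero():
            print("only the main process runs this")
        ```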
        """
        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
        # process index.
        if is_sagemaker_mp_enabled():
            return smp.rank() == 0
        else:
            return self.args.process_index == 0
    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Will save the model, so you can reload it using `from_pretrained()`.
Julien Chaumond's avatar
Julien Chaumond committed
2421

2422
        Will only save from the main process.
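
        Example (a minimal sketch; the output path and the Auto class below are illustrative):

        ```python
        trainer.save_model("./my-finetuned-model")
        # later, reload with the matching Auto class, e.g.:
        # model = AutoModelForSequenceClassification.from_pretrained("./my-finetuned-model")
        ```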
Julien Chaumond's avatar
Julien Chaumond committed
2423
        """

        if output_dir is None:
            output_dir = self.args.output_dir

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif is_sagemaker_mp_enabled():
            # Calling the state_dict needs to be done on the wrapped model and on all processes.
            state_dict = self.model_wrapped.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
        elif (
            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
            or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
            or self.fsdp is not None
        ):
            state_dict = self.model.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
        elif self.deepspeed:

            # this takes care of everything as long as we aren't under zero3
            if self.args.should_save:
                self._save(output_dir)

            if is_deepspeed_zero3_enabled():
                # It's too complicated to try to override different places where the weights dump gets
                # saved, so since under zero3 the file is bogus, simply delete it. The user should
                # either use the deepspeed checkpoint to resume training, or recover the full weights with
                # zero_to_fp32.py stored in the checkpoint.
                if self.args.should_save:
                    file = os.path.join(output_dir, WEIGHTS_NAME)
                    if os.path.isfile(file):
                        # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
                        os.remove(file)

                # now save the real model if stage3_gather_16bit_weights_on_model_save=True
                # if false it will not be saved.
                # This must be called on all ranks
                if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME):
                    logger.warning(
                        "deepspeed.save_16bit_model didn't save the model, since"
                        " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                        " zero_to_fp32.py to recover weights"
                    )
                    self.deepspeed.save_checkpoint(output_dir)
        elif self.args.should_save:
            self._save(output_dir)
        # Push to the Hub when `save_model` is called by the user.
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info(f"Saving model checkpoint to {output_dir}")

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if not isinstance(self.model, PreTrainedModel):
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                unwrap_model(self.model).save_pretrained(
                    output_dir,
                    is_main_process=self.args.should_save,
                    state_dict=self.model.state_dict(),
                    save_function=xm.save,
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                state_dict = self.model.state_dict()
                xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
        if self.tokenizer is not None and self.args.should_save:
            self.tokenizer.save_pretrained(output_dir)
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                if state_dict is None:
                    state_dict = self.model.state_dict()
                unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if state_dict is None:
                    state_dict = self.model.state_dict()
                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir, state_dict=state_dict)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
    def store_flos(self):
        # Storing the number of floating-point operations that went into the model
        if self.args.local_rank != -1:
            self.state.total_flos += (
                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
            )
            self.current_flos = 0
        else:
            self.state.total_flos += self.current_flos
            self.current_flos = 0
    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match is not None and regex_match.groups() is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        # Make sure we don't delete the best model.
        if self.state.best_model_checkpoint is not None:
            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
            for i in range(best_model_index, len(checkpoints_sorted) - 2):
                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
        # we don't do to allow resuming.
        save_total_limit = self.args.save_total_limit
        if (
            self.state.best_model_checkpoint is not None
            and self.args.save_total_limit == 1
            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
        ):
            save_total_limit = 2

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
            shutil.rmtree(checkpoint)

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.
        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).
        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is a `datasets.Dataset`, columns not
                accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default)
        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
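
        Example (a minimal sketch, assuming the `Trainer` was built with an `eval_dataset` and, optionally, a
        `compute_metrics` function):

        ```python
        metrics = trainer.evaluate()
        print(metrics["eval_loss"])
        ```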
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        start_time = time.time()
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )
        total_batch_size = self.args.eval_batch_size * self.args.world_size
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )
        self.log(output.metrics)
        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics

    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Run prediction and returns predictions and potential metrics.
        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the
                `model.forward()` method are automatically removed. Has to implement the method `__len__`.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "test_bleu" if the prefix is "test" (default)
        <Tip>

        If your predictions or labels have different sequence lengths (for instance because you're doing dynamic padding
        in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
        one array. The padding index is -100.
        </Tip>
        Returns: *NamedTuple* A namedtuple with the following keys:
            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
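
        Example (a minimal sketch for a classification setup; `test_dataset` is assumed to be preprocessed already):

        ```python
        import numpy as np

        predictions = trainer.predict(test_dataset)
        predicted_ids = np.argmax(predictions.predictions, axis=-1)
        ```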
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        test_dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )
        total_batch_size = self.args.eval_batch_size * self.args.world_size
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """
        args = self.args

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
        # if eval is called w/o train init deepspeed here
        if args.deepspeed and not self.deepspeed:

            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(
                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)
        batch_size = self.args.eval_batch_size
        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")
        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)
        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
        if args.past_index >= 0:
            self._past = None
        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        inputs_host = None

        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        all_inputs = None
        # Will be useful when we have an iterable dataset so don't know its length.

        observed_num_examples = 0
        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
            if is_torch_tpu_available():
                xm.mark_step()

            # Update containers on host
            if loss is not None:
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if labels is not None:
                labels = self._pad_across_processes(labels)
                labels = self._nested_gather(labels)
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_decode = self._pad_across_processes(inputs_decode)
                inputs_decode = self._nested_gather(inputs_decode)
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            if logits is not None:
                logits = self._pad_across_processes(logits)
                logits = self._nested_gather(logits)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                if preds_host is not None:
                    logits = nested_numpify(preds_host)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
                if inputs_host is not None:
                    inputs_decode = nested_numpify(inputs_host)
                    all_inputs = (
                        inputs_decode
                        if all_inputs is None
                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                    )
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                    )

                # Set back to None to begin a new accumulation
                losses_host, preds_host, inputs_host, labels_host = None, None, None, None
        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")
        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
        if inputs_host is not None:
            inputs_decode = nested_numpify(inputs_host)
            all_inputs = (
                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
            )
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
        # samples has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
        if all_inputs is not None:
            all_inputs = nested_truncate(all_inputs, num_samples)

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}
        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
    def _nested_gather(self, tensors, name=None):
        """
        Gather the value of `tensors` (a tensor or a nested list/tuple of tensors) from all processes so they can be
        concatenated.
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            if name is None:
                name = "nested_gather"
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif self.args.local_rank != -1:
            tensors = distributed_concat(tensors)
        return tensors
    # Copied from Accelerate.
    def _pad_across_processes(self, tensor, pad_index=-100):
        """
        Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
        they can safely be gathered.
        """
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )

        if len(tensor.shape) < 2:
            return tensor
        # Gather all sizes
        size = torch.tensor(tensor.shape, device=tensor.device)[None]
        sizes = self._nested_gather(size).cpu()

        max_size = max(s[1] for s in sizes)
        if tensor.shape[1] == max_size:
            return tensor

        # Then pad to the maximum size
        old_size = tensor.shape
        new_size = list(old_size)
        new_size[1] = max_size
        new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
        new_tensor[:, : old_size[1]] = tensor
        return new_tensor
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        has_labels = all(inputs.get(k) is not None for k in self.label_names)
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []
        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels:
                    with self.compute_loss_context_manager():
                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                    loss = loss.mean().detach()

                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        logits = outputs[1:]
                else:
                    loss = None
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point
        operations for every backward + forward pass. If using another model, either implement such a method in the
        model or subclass and override this method.

        Args:
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            `int`: The number of floating-point operations.
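
        Example (illustrative; assumes an instantiated `trainer` and a prepared `inputs` batch):

        ```python
        flops = trainer.floating_point_ops(inputs)  # 0 if the model does not implement `floating_point_ops`
        ```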
        """
        if hasattr(self.model, "floating_point_ops"):
            return self.model.floating_point_ops(inputs)
        else:
            return 0

    def init_git_repo(self, at_init: bool = False):
        """
        Initializes a git repo in `self.args.hub_model_id`.

        Args:
            at_init (`bool`, *optional*, defaults to `False`):
                Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
                `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
                out.
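
        Example (a minimal sketch; assumes the hub-related fields were set on `TrainingArguments`):

        ```python
        trainer.init_git_repo(at_init=True)  # clones or creates the repo inside `self.args.output_dir`
        ```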
        """
        if not self.is_world_process_zero():
            return
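        # Without an explicit `hub_token`, `use_auth_token=True` falls back to the locally stored login token.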
        use_auth_token = True if self.args.hub_token is None else self.args.hub_token
        if self.args.hub_model_id is None:
            repo_name = Path(self.args.output_dir).absolute().name
        else:
            repo_name = self.args.hub_model_id
        if "/" not in repo_name:
            repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)

        try:
            self.repo = Repository(
                self.args.output_dir,
                clone_from=repo_name,
                use_auth_token=use_auth_token,
                private=self.args.hub_private_repo,
            )
        except EnvironmentError:
            if self.args.overwrite_output_dir and at_init:
                # Try again after wiping output_dir
                shutil.rmtree(self.args.output_dir)
                self.repo = Repository(
                    self.args.output_dir,
                    clone_from=repo_name,
                    use_auth_token=use_auth_token,
                )
            else:
                raise

        self.repo.git_pull()

        # By default, ignore the checkpoint folders
        if (
            not os.path.exists(os.path.join(self.args.output_dir, ".gitignore"))
            and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS
        ):
            with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
                writer.writelines(["checkpoint-*/"])

        self.push_in_progress = None

    def create_model_card(
        self,
        language: Optional[str] = None,
        license: Optional[str] = None,
        tags: Optional[str] = None,
        model_name: Optional[str] = None,
        finetuned_from: Optional[str] = None,
        tasks: Optional[str] = None,
        dataset_tags: Optional[Union[str, List[str]]] = None,
        dataset: Optional[Union[str, List[str]]] = None,
        dataset_args: Optional[Union[str, List[str]]] = None,
    ):
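        # Drafts a model card (README.md) in `self.args.output_dir` from the training summary; main process only.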
        if not self.is_world_process_zero():
            return

        training_summary = TrainingSummary.from_trainer(
            self,
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
        )
        model_card = training_summary.to_model_card()
        with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
            f.write(model_card)

    def _push_from_checkpoint(self, checkpoint_folder):
        # Only push from one node.
        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
            return
        # If we haven't finished the last push, we don't do this one.
        if self.push_in_progress is not None and not self.push_in_progress.is_done:
            return

        output_dir = self.args.output_dir
        # To avoid a new synchronization of all model weights, we just copy the files from the checkpoint folder
        modeling_files = [CONFIG_NAME, WEIGHTS_NAME]
        for modeling_file in modeling_files:
            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Same for the training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        try:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Temporarily move the checkpoint just saved for the push
                tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
                # subfolder.
                if os.path.isdir(tmp_checkpoint):
                    shutil.rmtree(tmp_checkpoint)
                shutil.move(checkpoint_folder, tmp_checkpoint)

            if self.args.save_strategy == IntervalStrategy.STEPS:
                commit_message = f"Training in progress, step {self.state.global_step}"
            else:
                commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
            _, self.push_in_progress = self.repo.push_to_hub(
                commit_message=commit_message, blocking=False, auto_lfs_prune=True
            )
        finally:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Move back the checkpoint to its place
                shutil.move(tmp_checkpoint, checkpoint_folder)

    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.

        Parameters:
            commit_message (`str`, *optional*, defaults to `"End of training"`):
                Message to commit while pushing.
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the `git push` has finished.
            kwargs:
                Additional keyword arguments passed along to [`~Trainer.create_model_card`].

        Returns:
            The URL of the commit of your model in the given repository if `blocking=True`, or a tuple with that URL
            and an object tracking the progress of the push if `blocking=False`.
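
        Example (illustrative usage on an already-trained `trainer` with hub arguments configured):

        ```python
        url = trainer.push_to_hub(commit_message="Training complete")
        ```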
        """
        # If a user manually calls `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
        # it might fail.
        if not hasattr(self, "repo"):
            self.init_git_repo()

        if self.args.should_save:
            if self.args.hub_model_id is None:
                model_name = Path(self.args.output_dir).name
            else:
                model_name = self.args.hub_model_id.split("/")[-1]

        # Needs to be executed on all processes for TPU training, but will only save on the process determined by
        # self.args.should_save.
        self.save_model(_internal_call=True)

        # Only push from one node.
        if not self.is_world_process_zero():
            return

        # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
            self.push_in_progress._process.kill()
            self.push_in_progress = None

        git_head_commit_url = self.repo.push_to_hub(
            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
        )
        # Push the model card separately so it is independent from the rest of the model save.
        if self.args.should_save:
            self.create_model_card(model_name=model_name, **kwargs)
            try:
                self.repo.push_to_hub(
                    commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
                )
            except EnvironmentError as exc:
                logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")

        return git_head_commit_url

    #
    # Deprecated code
    #

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """
        args = self.args

        if not has_length(dataloader):
            raise ValueError("dataloader must implement a working __len__")

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # If evaluation is called without a prior training run, initialize DeepSpeed here.
        if args.deepspeed and not self.deepspeed:
            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
            # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
            # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
            # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
            deepspeed_engine.optimizer.optimizer = None
            deepspeed_engine.lr_scheduler = None

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # If full fp16 or bf16 evaluation is requested and `evaluate` or `predict` isn't called while `train` is
        # running, cast the model to the right dtype first and then put it on the device.
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info(f"***** Running {description} *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Batch size = {batch_size}")
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
        inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = max(1, args.world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
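        # Gatherers accumulate per-process tensors on CPU as numpy arrays and restore dataset order at finalize().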
        if not prediction_loss_only:
            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
            # a batch size to the sampler)
            make_multiple_of = None
            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
                make_multiple_of = dataloader.sampler.batch_size
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None

            if loss is not None:
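                # The batch loss is a scalar; repeat it per example so it can be gathered over the dataset.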
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                if not prediction_loss_only:
                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
                    inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host, inputs_host = None, None, None, None

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
            inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
        inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def _gather_and_numpify(self, tensors, name):
        """
        Gather the values in `tensors` (a tensor or a list/tuple of nested tensors) across processes and convert them
        to numpy arrays before concatenation.
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif self.args.local_rank != -1:
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)