trainer.py 184 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
"""

19
import contextlib
20
import functools
21
import glob
22
import inspect
23
import math
Julien Chaumond's avatar
Julien Chaumond committed
24
import os
25
import random
Julien Chaumond's avatar
Julien Chaumond committed
26
27
import re
import shutil
28
import sys
29
import time
30
import warnings
31
from collections.abc import Mapping
Julien Chaumond's avatar
Julien Chaumond committed
32
from pathlib import Path
33
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
Julien Chaumond's avatar
Julien Chaumond committed
34

35
36
from tqdm.auto import tqdm

Julien Chaumond's avatar
Julien Chaumond committed
37

38
# Integrations must be imported before ML frameworks:
39
40
# isort: off
from .integrations import (
41
    default_hp_search_backend,
42
    get_reporting_integration_callbacks,
43
    hp_params,
44
    is_fairscale_available,
45
    is_optuna_available,
46
    is_ray_tune_available,
47
    is_sigopt_available,
48
    is_wandb_available,
49
50
    run_hp_search_optuna,
    run_hp_search_ray,
51
    run_hp_search_sigopt,
52
    run_hp_search_wandb,
53
)
54

55
56
# isort: on

57
58
import numpy as np
import torch
Lai Wei's avatar
Lai Wei committed
59
import torch.distributed as dist
60
from huggingface_hub import Repository, create_repo
61
62
from packaging import version
from torch import nn
63
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
64
65
from torch.utils.data.distributed import DistributedSampler

66
67
from . import __version__
from .configuration_utils import PretrainedConfig
68
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
69
from .debug_utils import DebugOption, DebugUnderflowOverflow
70
from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
71
from .dependency_versions_check import dep_version_check
Sylvain Gugger's avatar
Sylvain Gugger committed
72
from .modelcard import TrainingSummary
73
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
74
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
75
from .optimization import Adafactor, get_scheduler
76
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
77
from .tokenization_utils_base import PreTrainedTokenizerBase
Sylvain Gugger's avatar
Sylvain Gugger committed
78
79
80
81
82
83
84
85
86
87
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
88
    DistributedLengthGroupedSampler,
89
    DistributedSamplerWithLoop,
90
    DistributedTensorGatherer,
91
    IterableDatasetShard,
Sylvain Gugger's avatar
Sylvain Gugger committed
92
    LabelSmoother,
93
    LengthGroupedSampler,
Sylvain Gugger's avatar
Sylvain Gugger committed
94
    SequentialDistributedSampler,
95
    ShardSampler,
Sylvain Gugger's avatar
Sylvain Gugger committed
96
97
    distributed_broadcast_scalars,
    distributed_concat,
98
    find_batch_size,
99
    get_model_param_count,
100
    get_module_class_from_name,
101
    get_parameter_names,
Sylvain Gugger's avatar
Sylvain Gugger committed
102
103
104
    nested_concat,
    nested_detach,
    nested_numpify,
105
    nested_truncate,
Sylvain Gugger's avatar
Sylvain Gugger committed
106
107
108
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
109
110
111
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
112
    EvalLoopOutput,
113
    EvalPrediction,
114
    FSDPOption,
115
    HPSearchBackend,
116
117
    HubStrategy,
    IntervalStrategy,
118
    PredictionOutput,
119
    RemoveColumnsCollator,
120
    ShardedDDPOption,
121
    TrainerMemoryTracker,
122
123
124
    TrainOutput,
    default_compute_objective,
    default_hp_space,
125
    denumpify_detensorize,
126
    enable_full_determinism,
127
    find_executable_batch_size,
128
    get_last_checkpoint,
129
    has_length,
130
    number_of_arguments,
131
    seed_worker,
132
    set_seed,
133
    speed_metrics,
134
)
135
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
136
137
from .utils import (
    CONFIG_NAME,
138
139
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
140
    WEIGHTS_INDEX_NAME,
141
    WEIGHTS_NAME,
142
    can_return_loss,
143
    find_labels,
144
    get_full_repo_name,
145
    is_accelerate_available,
146
147
148
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
149
    is_ipex_available,
150
    is_safetensors_available,
151
152
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
153
    is_torch_compile_available,
154
    is_torch_neuroncore_available,
155
156
    is_torch_tpu_available,
    logging,
157
    strtobool,
158
)
159
from .utils.generic import ContextManagers
Julien Chaumond's avatar
Julien Chaumond committed
160
161


162
_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10
163

Sylvain Gugger's avatar
Sylvain Gugger committed
164
DEFAULT_CALLBACKS = [DefaultFlowCallback]
165
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
Sylvain Gugger's avatar
Sylvain Gugger committed
166

167
168
169
170
if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
171

172
173
if is_apex_available():
    from apex import amp
174

175
176
if is_datasets_available():
    import datasets
Julien Chaumond's avatar
Julien Chaumond committed
177

178
if is_torch_tpu_available(check_device=False):
Lysandre Debut's avatar
Lysandre Debut committed
179
180
181
182
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

183
if is_fairscale_available():
184
    dep_version_check("fairscale")
185
    import fairscale
186
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
187
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
188
    from fairscale.nn.wrap import auto_wrap
189
190
191
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

192

Sylvain Gugger's avatar
Sylvain Gugger committed
193
194
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
195
196
197
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
Sylvain Gugger's avatar
Sylvain Gugger committed
198
199

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
200
201
else:
    IS_SAGEMAKER_MP_POST_1_10 = False
Sylvain Gugger's avatar
Sylvain Gugger committed
202

203

204
205
206
207
if is_safetensors_available():
    import safetensors.torch


208
209
210
211
212
213
214
215
# Optional helper from `accelerate`: `skip_first_batches` stays `None` unless `accelerate` is installed at
# version >= 0.16, in which case it is rebound to `accelerate.skip_first_batches`.
# NOTE(review): code that uses it presumably checks for `None` before calling it — confirm at the call sites.
skip_first_batches = None
if is_accelerate_available():
    from accelerate import __version__ as accelerate_version

    # Version gate: only import the helper when the installed accelerate is >= 0.16.
    if version.parse(accelerate_version) >= version.parse("0.16"):
        from accelerate import skip_first_batches


216
217
218
if TYPE_CHECKING:
    import optuna

Lysandre Debut's avatar
Lysandre Debut committed
219
logger = logging.get_logger(__name__)
Julien Chaumond's avatar
Julien Chaumond committed
220
221


222
223
224
225
226
227
228
229
# Names of the files used for checkpointing. These are bare file names only; the directory they are written to
# is not determined here. The per-file descriptions below are inferred from the names — confirm against the
# save/load code.
TRAINING_ARGS_NAME = "training_args.bin"  # presumably the serialized `TrainingArguments`
TRAINER_STATE_NAME = "trainer_state.json"  # presumably the `TrainerState`, stored as JSON
OPTIMIZER_NAME = "optimizer.pt"  # presumably the optimizer state
SCHEDULER_NAME = "scheduler.pt"  # presumably the learning-rate scheduler state
SCALER_NAME = "scaler.pt"  # presumably the mixed-precision gradient-scaler state


Julien Chaumond's avatar
Julien Chaumond committed
230
231
class Trainer:
    """
Sylvain Gugger's avatar
Sylvain Gugger committed
232
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
233
234

    Args:
235
236
        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
Sylvain Gugger's avatar
Sylvain Gugger committed
237

238
            <Tip>
Sylvain Gugger's avatar
Sylvain Gugger committed
239

Sylvain Gugger's avatar
Sylvain Gugger committed
240
241
242
            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
            models.
243
244
245
246

            </Tip>

        args ([`TrainingArguments`], *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
247
248
            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
249
        data_collator (`DataCollator`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
250
251
            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
252
253
            [`DataCollatorWithPadding`] otherwise.
        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
254
            The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
255
256
            `model.forward()` method are automatically removed.

Sylvain Gugger's avatar
Sylvain Gugger committed
257
258
259
260
261
            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
            distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a
            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
            sets the seed of the RNGs used.
262
        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
263
             The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
264
265
             `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
             dataset prepending the dictionary key to the metric name.
266
        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
267
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
268
269
            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
            interrupted training or reuse the fine-tuned model.
270
        model_init (`Callable[[], PreTrainedModel]`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
271
272
            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
            from a new instance of the model as given by this function.
273

274
275
276
            The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
            be able to choose different architectures according to hyper parameters (such as layer count, sizes of
            inner layers, dropout probabilities etc).
277
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
278
279
            The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
            a dictionary string to metric values.
280
        callbacks (List of [`TrainerCallback`], *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
281
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
282
            detailed in [here](callback).
Sylvain Gugger's avatar
Sylvain Gugger committed
283

284
285
            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
Sylvain Gugger's avatar
Sylvain Gugger committed
286
287
            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
288
289
290
291
292
293
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocess the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.
294

295
296
    Important attributes:

Sylvain Gugger's avatar
Sylvain Gugger committed
297
298
        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
          subclass.
299
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
300
          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
Sylvain Gugger's avatar
Sylvain Gugger committed
301
302
          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
303
304
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
305
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
306
307
          to `False` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False` .
Sylvain Gugger's avatar
Sylvain Gugger committed
308
309
        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
          in `train`)
310

Julien Chaumond's avatar
Julien Chaumond committed
311
312
    """

313
    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
314

Julien Chaumond's avatar
Julien Chaumond committed
315
316
    def __init__(
        self,
317
        model: Union[PreTrainedModel, nn.Module] = None,
318
        args: TrainingArguments = None,
Julien Chaumond's avatar
Julien Chaumond committed
319
320
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
321
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
322
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
323
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
Julien Chaumond's avatar
Julien Chaumond committed
324
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
Sylvain Gugger's avatar
Sylvain Gugger committed
325
        callbacks: Optional[List[TrainerCallback]] = None,
326
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
327
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
Julien Chaumond's avatar
Julien Chaumond committed
328
    ):
Sylvain Gugger's avatar
Sylvain Gugger committed
329
        if args is None:
330
331
332
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
Sylvain Gugger's avatar
Sylvain Gugger committed
333
334
        self.args = args
        # Seed must be set before instantiating the model when using model
335
        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
336
        self.hp_name = None
337
        self.deepspeed = None
338
        self.is_in_train = False
339

340
341
342
343
        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

344
        # set the correct log level depending on the node
345
        log_level = args.get_process_log_level()
346
347
        logging.set_verbosity(log_level)

348
349
350
        # force device and distributed setup init explicitly
        args._setup_devices

351
352
353
354
355
356
357
358
359
        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
Sylvain Gugger's avatar
Sylvain Gugger committed
360
361
362
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
                    " overwrite your model when calling the `train` method. This will become a fatal error in the next"
                    " release.",
363
364
365
                    FutureWarning,
                )
            self.model_init = model_init
366

367
368
369
370
371
372
373
374
        if model.__class__.__name__ in MODEL_MAPPING_NAMES:
            raise ValueError(
                f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
                "computes hidden states and does not accept any labels. You should choose a model with a head "
                "suitable for your task like any of the `AutoModelForXxx` listed at "
                "https://huggingface.co/docs/transformers/model_doc/auto."
            )

375
376
377
378
379
        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = False

380
381
382
383
384
385
386
387
388
389
390
391
392
        if (
            getattr(model, "hf_device_map", None) is not None
            and len([device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]) > 1
            and not self.is_model_parallel
        ):
            self.is_model_parallel = True

            # warn users
            logger.info(
                "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
                " to `True` to avoid any unexpected behavior such as device placement mismatching."
            )

393
394
        # At this stage the model is already loaded
        if getattr(model, "is_loaded_in_8bit", False):
395
396
397
398
399
400
401
402
403
404
405
406
            if getattr(model, "_is_int8_training_enabled", False):
                logger.info(
                    "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
                    " inside the model such as adapters using `peft` library and freeze the model weights. Please"
                    " check "
                    " the examples in https://github.com/huggingface/peft for more details."
                )
            else:
                raise ValueError(
                    "The model you want to train is loaded in 8-bit precision.  if you want to fine-tune an 8-bit"
                    " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
                )
407

408
409
410
411
412
413
414
        # Setup Sharded DDP training
        self.sharded_ddp = None
        if len(args.sharded_ddp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
415
416
417
418
            if len(args.fsdp) > 0:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
                )
419
            if args.parallel_mode != ParallelMode.DISTRIBUTED:
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
                raise ImportError(
                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
                )
            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.SIMPLE
            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

435
436
437
438
439
440
        self.fsdp = None
        if len(args.fsdp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
441
            if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
442
443
                raise ValueError("Using fsdp only works in distributed training.")

444
445
446
            # dep_version_check("torch>=1.12.0")
            # Would have to update setup.py with torch>=1.12.0
            # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
447
            # below is the current alternative.
448
            if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
449
                raise ValueError("FSDP requires PyTorch >= 1.12.0")
450

451
            from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy
452
453
454
455
456

            if FSDPOption.FULL_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.FULL_SHARD
            elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
                self.fsdp = ShardingStrategy.SHARD_GRAD_OP
457
458
            elif FSDPOption.NO_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.NO_SHARD
459

460
461
462
463
            self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
            if "backward_prefetch" in self.args.fsdp_config and "backward_pos" not in self.backward_prefetch:
                self.backward_prefetch = BackwardPrefetch.BACKWARD_POST

Seung-Moo Yang's avatar
Seung-Moo Yang committed
464
465
466
            self.forward_prefetch = False
            if self.args.fsdp_config.get("forward_prefect", False):
                self.forward_prefetch = True
467

468
469
470
471
            self.limit_all_gathers = False
            if self.args.fsdp_config.get("limit_all_gathers", False):
                self.limit_all_gathers = True

472
        # one place to sort out whether to place the model on device or not
473
474
475
476
        # postpone switching model to cuda when:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        #    and we only use deepspeed for training at the moment
477
        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
478
        # 4. Sharded DDP - same as MP
479
        # 5. FSDP - same as MP
480
        self.place_model_on_device = args.place_model_on_device
481
482
        if (
            self.is_model_parallel
483
            or args.deepspeed
484
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
485
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
486
            or (self.fsdp is not None)
487
        ):
488
489
            self.place_model_on_device = False

490
491
        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
Julien Chaumond's avatar
Julien Chaumond committed
492
493
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
494
        self.tokenizer = tokenizer
495

496
        if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False):
Sylvain Gugger's avatar
Sylvain Gugger committed
497
            self._move_model_to_device(model, args.device)
Stas Bekman's avatar
Stas Bekman committed
498
499
500

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
        if self.is_model_parallel:
501
            self.args._n_gpu = 1
502
503
504
505
506

        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
        self.model_wrapped = model
        self.model = model

Julien Chaumond's avatar
Julien Chaumond committed
507
        self.compute_metrics = compute_metrics
508
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
509
        self.optimizer, self.lr_scheduler = optimizers
Sylvain Gugger's avatar
Sylvain Gugger committed
510
511
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
512
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
Sylvain Gugger's avatar
Sylvain Gugger committed
513
514
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
        if is_torch_tpu_available() and self.optimizer is not None:
            for param in self.model.parameters():
                model_device = param.device
                break
            for param_group in self.optimizer.param_groups:
                if len(param_group["params"]) > 0:
                    optimizer_device = param_group["params"][0].device
                    break
            if model_device != optimizer_device:
                raise ValueError(
                    "The model and the optimizer parameters are not on the same device, which probably means you"
                    " created an optimizer around your model **before** putting on the device and passing it to the"
                    " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                    " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
                )
530
        if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
531
532
533
            self.optimizer is not None or self.lr_scheduler is not None
        ):
            raise RuntimeError(
534
                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
535
536
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
537
538
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
539
540
541
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
542
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
Sylvain Gugger's avatar
Sylvain Gugger committed
543

544
545
546
        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

547
548
        # Create clone of distant repo and output directory if needed
        if self.args.push_to_hub:
549
            self.init_git_repo(at_init=True)
550
551
552
            # In case of pull, we need to make sure every process has the latest.
            if is_torch_tpu_available():
                xm.rendezvous("init git repo")
553
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
554
555
                dist.barrier()

556
        if self.args.should_save:
Julien Chaumond's avatar
Julien Chaumond committed
557
            os.makedirs(self.args.output_dir, exist_ok=True)
558

559
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
Sylvain Gugger's avatar
Sylvain Gugger committed
560
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
561

562
563
564
        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

565
        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
566
567
568
569
            raise ValueError(
                "The train_dataset does not implement __len__, max_steps has to be specified. "
                "The number of steps needs to be known in advance for the learning rate scheduler."
            )
570

571
572
573
574
575
576
577
        if (
            train_dataset is not None
            and isinstance(train_dataset, torch.utils.data.IterableDataset)
            and args.group_by_length
        ):
            raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")

578
        self._signature_columns = None
579

580
581
        # Mixed precision setup
        self.use_apex = False
582
583
        self.use_cuda_amp = False
        self.use_cpu_amp = False
584

585
586
587
588
589
        # Mixed precision setup for SageMaker Model Parallel
        if is_sagemaker_mp_enabled():
            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
            if args.bf16:
                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606

            if IS_SAGEMAKER_MP_POST_1_10:
                # When there's mismatch between SMP config and trainer argument, use SMP config as truth
                if args.fp16 != smp.state.cfg.fp16:
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
                        f"but FP16 provided in trainer argument is {args.fp16},"
                        f"setting to {smp.state.cfg.fp16}"
                    )
                    args.fp16 = smp.state.cfg.fp16
            else:
                # smp < 1.10 does not support fp16 in trainer.
                if hasattr(smp.state.cfg, "fp16"):
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
                    )
607

608
609
        if args.fp16 or args.bf16:
            if args.half_precision_backend == "auto":
610
                if args.device == torch.device("cpu"):
611
612
613
614
615
616
                    if args.fp16:
                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                    elif _is_native_cpu_amp_available:
                        args.half_precision_backend = "cpu_amp"
                    else:
                        raise ValueError("Tried to use cpu amp but native cpu amp is not available")
617
                else:
618
                    args.half_precision_backend = "cuda_amp"
619

620
            logger.info(f"Using {args.half_precision_backend} half precision backend")
621

622
        self.do_grad_scaling = False
623
        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
624
            # deepspeed and SageMaker Model Parallel manage their own half precision
625
626
            if args.half_precision_backend == "cuda_amp":
                self.use_cuda_amp = True
627
                self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
628
629
630
631
632
633
                #  bf16 does not need grad scaling
                self.do_grad_scaling = self.amp_dtype == torch.float16
                if self.do_grad_scaling:
                    if self.sharded_ddp is not None:
                        self.scaler = ShardedGradScaler()
                    elif self.fsdp is not None:
634
635
636
                        from torch.distributed.fsdp.sharded_grad_scaler import (
                            ShardedGradScaler as FSDPShardedGradScaler,
                        )
637

638
                        self.scaler = FSDPShardedGradScaler()
639
640
                    elif is_torch_tpu_available():
                        from torch_xla.amp import GradScaler
641

642
643
644
                        self.scaler = GradScaler()
                    else:
                        self.scaler = torch.cuda.amp.GradScaler()
645
646
647
            elif args.half_precision_backend == "cpu_amp":
                self.use_cpu_amp = True
                self.amp_dtype = torch.bfloat16
648
649
650
            else:
                if not is_apex_available():
                    raise ImportError(
Sylvain Gugger's avatar
Sylvain Gugger committed
651
652
                        "Using FP16 with APEX but APEX is not installed, please refer to"
                        " https://www.github.com/nvidia/apex."
653
654
655
                    )
                self.use_apex = True

656
657
658
659
660
661
662
663
664
665
666
667
        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
        if (
            is_sagemaker_mp_enabled()
            and self.use_cuda_amp
            and args.max_grad_norm is not None
            and args.max_grad_norm > 0
        ):
            raise ValueError(
                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
                "along 'max_grad_norm': 0 in your hyperparameters."
            )

Sylvain Gugger's avatar
Sylvain Gugger committed
668
669
670
671
672
673
        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

674
675
676
677
678
        self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
        )

Sylvain Gugger's avatar
Sylvain Gugger committed
679
        self.control = TrainerControl()
680
681
682
        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
683
        self.hp_search_backend = None
684
        self.use_tune_checkpoints = False
685
        default_label_names = find_labels(self.model.__class__)
686
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
687
        self.can_return_loss = can_return_loss(self.model.__class__)
Sylvain Gugger's avatar
Sylvain Gugger committed
688
689
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

690
691
692
        # Internal variables to keep track of the original batch size
        self._train_batch_size = args.train_batch_size

693
694
695
        # very last
        self._memory_tracker.stop_and_update_metrics()

696
697
        # torch.compile
        if args.torch_compile and not is_torch_compile_available():
698
            raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
699

Sylvain Gugger's avatar
Sylvain Gugger committed
700
701
    def add_callback(self, callback):
        """
        Register a callback in the current list of [`~transformers.TrainerCallback`].

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               Either a [`~transformers.TrainerCallback`] class, in which case a member of that class is instantiated,
               or an already-built [`~transformers.TrainerCallback`] instance.
        """
        handler = self.callback_handler
        handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Take a callback out of the current list of [`~transformers.TrainerCallback`] and hand it back.

        No error is raised when the callback is absent; `None` is returned instead.

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               Either a [`~transformers.TrainerCallback`] class, in which case the first member of that class found in
               the callback list is popped, or a specific [`~transformers.TrainerCallback`] instance.

        Returns:
            [`~transformers.TrainerCallback`]: The removed callback, if one was found.
        """
        popped = self.callback_handler.pop_callback(callback)
        return popped

    def remove_callback(self, callback):
        """
        Drop a callback from the current list of [`~transformers.TrainerCallback`].

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               Either a [`~transformers.TrainerCallback`] class, in which case the first member of that class found in
               the callback list is removed, or a specific [`~transformers.TrainerCallback`] instance.
        """
        handler = self.callback_handler
        handler.remove_callback(callback)
Julien Chaumond's avatar
Julien Chaumond committed
737

Sylvain Gugger's avatar
Sylvain Gugger committed
738
739
740
741
742
743
    def _move_model_to_device(self, model, device):
        """
        Call `model.to(device)` and, when running in TPU parallel mode, re-tie the model's weights.

        Returns `None`.
        """
        model = model.to(device)
        on_tpu = self.args.parallel_mode == ParallelMode.TPU
        if on_tpu and hasattr(model, "tie_weights"):
            # Moving a model to an XLA device disconnects the tied weights, so they have to be tied again.
            model.tie_weights()

744
    def _set_signature_columns_if_needed(self):
745
746
747
748
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
749
750
            # Labels may be named label or label_ids, the default data collator handles that.
            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
751

752
753
754
755
    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return dataset
        self._set_signature_columns_if_needed()
756
        signature_columns = self._signature_columns
757
758

        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
759
        if len(ignored_columns) > 0:
760
            dset_description = "" if description is None else f"in the {description} set"
761
762
763
            logger.info(
                f"The following columns {dset_description} don't have a corresponding argument in "
                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
764
                f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
765
                " you can safely ignore this message."
766
            )
767

768
        columns = [k for k in signature_columns if k in dataset.column_names]
769

770
771
772
773
774
775
776
        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            dataset.set_format(
                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
            )
            return dataset
        else:
            return dataset.remove_columns(ignored_columns)
777

778
779
780
781
782
783
784
    def _get_collator_with_removed_columns(
        self, data_collator: Callable, description: Optional[str] = None
    ) -> Callable:
        """Wrap the data collator in a callable removing unused columns."""
        if not self.args.remove_unused_columns:
            return data_collator
        self._set_signature_columns_if_needed()
785
        signature_columns = self._signature_columns
786
787
788
789
790
791
792
793
794
795

        remove_columns_collator = RemoveColumnsCollator(
            data_collator=data_collator,
            signature_columns=signature_columns,
            logger=logger,
            description=description,
            model_name=self.model.__class__.__name__,
        )
        return remove_columns_collator

796
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
797
        if self.train_dataset is None or not has_length(self.train_dataset):
798
            return None
799

800
        generator = None
801
        if self.args.world_size <= 1:
802
            generator = torch.Generator()
803
804
805
806
807
808
809
810
811
812
            # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
            # `args.seed`) if data_seed isn't provided.
            # Further on in this method, we default to `args.seed` instead.
            if self.args.data_seed is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
            else:
                seed = self.args.data_seed
            generator.manual_seed(seed)

        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
813

814
815
        # Build the sampler.
        if self.args.group_by_length:
816
817
818
819
820
821
822
823
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
824
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
825
            if self.args.world_size <= 1:
826
                return LengthGroupedSampler(
827
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
828
                    dataset=self.train_dataset,
829
830
831
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
832
                )
833
834
            else:
                return DistributedLengthGroupedSampler(
835
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
836
                    dataset=self.train_dataset,
837
838
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
839
                    lengths=lengths,
840
                    model_input_name=model_input_name,
841
                    seed=seed,
842
843
844
                )

        else:
845
            if self.args.world_size <= 1:
846
                return RandomSampler(self.train_dataset, generator=generator)
Sylvain Gugger's avatar
Sylvain Gugger committed
847
848
849
850
            elif (
                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
            ):
851
852
853
854
855
856
                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    batch_size=self.args.per_device_train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
857
                    seed=seed,
858
                )
859
            else:
860
                return DistributedSampler(
861
862
863
                    self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
864
                    seed=seed,
865
                )
866
867
868

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise. `args.dataloader_drop_last` is honored for both map-style and iterable
        datasets.

        Subclass and override this method if you want to inject some custom behavior.

        Raises:
            ValueError: If the trainer has no `train_dataset`.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        if isinstance(train_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                # Shard the stream across processes (see `IterableDatasetShard`).
                train_dataset = IterableDatasetShard(
                    train_dataset,
                    batch_size=self._train_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )

            # Iterable datasets take no sampler. `drop_last` is forwarded here too; in single-process runs there is
            # no `IterableDatasetShard`, and the setting used to be silently ignored.
            return DataLoader(
                train_dataset,
                batch_size=self._train_batch_size,
                collate_fn=data_collator,
                drop_last=self.args.dataloader_drop_last,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        train_sampler = self._get_train_sampler()

        return DataLoader(
            train_dataset,
            batch_size=self._train_batch_size,
            sampler=train_sampler,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            worker_init_fn=seed_worker,
        )

917
    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
918
919
920
921
922
923
924
925
926
927
928
929
930
        # Deprecated code
        if self.args.use_legacy_prediction_loop:
            if is_torch_tpu_available():
                return SequentialDistributedSampler(
                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
                )
            elif is_sagemaker_mp_enabled():
                return SequentialDistributedSampler(
                    eval_dataset,
                    num_replicas=smp.dp_size(),
                    rank=smp.dp_rank(),
                    batch_size=self.args.per_device_eval_batch_size,
                )
931
            elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
932
933
934
935
936
937
938
939
                return SequentialDistributedSampler(eval_dataset)
            else:
                return SequentialSampler(eval_dataset)

        if self.args.world_size <= 1:
            return SequentialSampler(eval_dataset)
        else:
            return ShardSampler(
Sylvain Gugger's avatar
Sylvain Gugger committed
940
941
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
942
943
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
Sylvain Gugger's avatar
Sylvain Gugger committed
944
            )
Lysandre Debut's avatar
Lysandre Debut committed
945

Julien Chaumond's avatar
Julien Chaumond committed
946
    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                Overrides `self.eval_dataset` when given. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. A `torch.utils.data.IterableDataset` is also
                accepted. It is sharded across processes when `args.world_size > 1` and loaded without a sampler.

        Raises:
            ValueError: If neither `eval_dataset` nor `self.eval_dataset` is set.
        """
        if eval_dataset is None:
            if self.eval_dataset is None:
                raise ValueError("Trainer: evaluation requires an eval_dataset.")
            eval_dataset = self.eval_dataset

        collator = self.data_collator
        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            collator = self._get_collator_with_removed_columns(collator, description="evaluation")

        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            return DataLoader(
                eval_dataset,
                sampler=self._get_eval_sampler(eval_dataset),
                batch_size=self.args.eval_batch_size,
                collate_fn=collator,
                drop_last=self.args.dataloader_drop_last,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        if self.args.world_size > 1:
            eval_dataset = IterableDatasetShard(
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
                drop_last=self.args.dataloader_drop_last,
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
            )
        return DataLoader(
            eval_dataset,
            batch_size=self.args.eval_batch_size,
            collate_fn=collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`):
                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. A `torch.utils.data.IterableDataset` is also
                accepted. It is sharded across processes when `args.world_size > 1` and loaded without a sampler.
        """
        collator = self.data_collator
        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description="test")
        else:
            collator = self._get_collator_with_removed_columns(collator, description="test")

        if not isinstance(test_dataset, torch.utils.data.IterableDataset):
            # The test loader uses the same batch size as evaluation.
            return DataLoader(
                test_dataset,
                sampler=self._get_eval_sampler(test_dataset),
                batch_size=self.args.eval_batch_size,
                collate_fn=collator,
                drop_last=self.args.dataloader_drop_last,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        if self.args.world_size > 1:
            test_dataset = IterableDatasetShard(
                test_dataset,
                batch_size=self.args.eval_batch_size,
                drop_last=self.args.dataloader_drop_last,
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
            )
        return DataLoader(
            test_dataset,
            batch_size=self.args.eval_batch_size,
            collate_fn=collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )
Lysandre Debut's avatar
Lysandre Debut committed
1043

1044
    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Set up both the optimizer and the learning rate scheduler.

        A reasonable default is provided. To use something else, pass a tuple through the `optimizers` argument of the
        Trainer's init. Alternatively, subclass and override this method (or `create_optimizer` and/or
        `create_scheduler`).

        Args:
            num_training_steps (int): The number of training steps, forwarded to `create_scheduler`.
        """
        self.create_optimizer()
        optimizer = self.optimizer
        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
            # With smp >= 1.10 and fp16, the scheduler is attached to the unwrapped inner optimizer.
            optimizer = optimizer.optimizer
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
1059
1060
1061
1062
1063

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.

        When `self.optimizer` is not yet set, trainable parameters are split into two groups:
        - parameters whose name is returned by `get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)` and does not
          contain "bias" get `args.weight_decay`;
        - every other trainable parameter gets a weight decay of 0.0.
        Parameters with `requires_grad=False` are left out.

        Returns:
            `self.optimizer`. Under SageMaker Model Parallel it is wrapped in `smp.DistributedOptimizer`.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            # A set keeps the per-parameter membership checks below O(1) instead of scanning a list.
            decay_parameters = {
                name for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) if "bias" not in name
            }
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    skipped = 0
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                            # Report through the module logger rather than print(), so verbosity settings apply and
                            # non-main processes can be silenced.
                            logger.info(f"skipped {module}: {skipped/2**20}M params")
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                    logger.info(f"skipped: {skipped/2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer

1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
    @staticmethod
    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
        """
        Returns the optimizer class and optimizer parameters based on the training arguments.

        Args:
            args (`transformers.training_args.TrainingArguments`):
                The training arguments for the training session.

        Returns:
            `Tuple[Any, Any]`: The optimizer class, and a dict of keyword arguments to instantiate it with. The dict
            always contains `"lr"`.

        Raises:
            ValueError:
                - an `args.optim_args` entry has no `=`;
                - the requested optimizer's backing package cannot be imported;
                - `args.optim` is not supported.
        """
        # Parse args.optim_args: comma-separated "key=value" pairs. Spaces are ignored, and empty entries (e.g. from a
        # trailing comma) are skipped. Only the first "=" separates key from value, so values may contain "=".
        optim_args = {}
        if args.optim_args:
            for mapping in args.optim_args.replace(" ", "").split(","):
                if not mapping:
                    continue
                key, sep, value = mapping.partition("=")
                if not sep:
                    raise ValueError(
                        f"Invalid entry {mapping!r} in `optim_args`: expected comma-separated `key=value` pairs."
                    )
                optim_args[key] = value

        optimizer_kwargs = {"lr": args.learning_rate}

        adam_kwargs = {
            "betas": (args.adam_beta1, args.adam_beta2),
            "eps": args.adam_epsilon,
        }
        if args.optim == OptimizerNames.ADAFACTOR:
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif args.optim == OptimizerNames.ADAMW_HF:
            from .optimization import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
            from torch.optim import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
            if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
                optimizer_kwargs.update({"fused": True})
        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
            try:
                from torch_xla.amp.syncfree import AdamW

                optimizer_cls = AdamW
                optimizer_kwargs.update(adam_kwargs)
            except ImportError as err:
                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.") from err
        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
            try:
                from apex.optimizers import FusedAdam

                optimizer_cls = FusedAdam
                optimizer_kwargs.update(adam_kwargs)
            except ImportError as err:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!") from err
        elif args.optim == OptimizerNames.ADAMW_BNB:
            try:
                from bitsandbytes.optim import Adam8bit

                optimizer_cls = Adam8bit
                optimizer_kwargs.update(adam_kwargs)
            except ImportError as err:
                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!") from err
        elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
            try:
                from torchdistx.optimizers import AnyPrecisionAdamW

                optimizer_cls = AnyPrecisionAdamW
                optimizer_kwargs.update(adam_kwargs)

                # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
                optimizer_kwargs.update(
                    {
                        "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
                        "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
                        "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
                        "compensation_buffer_dtype": getattr(
                            torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
                        ),
                    }
                )
            except ImportError as err:
                raise ValueError("Please install https://github.com/pytorch/torchdistx") from err
        elif args.optim == OptimizerNames.SGD:
            optimizer_cls = torch.optim.SGD
        elif args.optim == OptimizerNames.ADAGRAD:
            optimizer_cls = torch.optim.Adagrad
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs

1207
    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Create the learning-rate scheduler (only if none exists yet) and return it.

        The scheduler is attached to `optimizer` when one is given, otherwise to `self.optimizer`, so the trainer's
        optimizer must already be set up when no optimizer is passed.

        Args:
            num_training_steps (int): The number of training steps to do.
            optimizer (`torch.optim.Optimizer`, *optional*): Optimizer to schedule instead of `self.optimizer`.

        Returns:
            The trainer's `lr_scheduler` (pre-existing or newly created).
        """
        if self.lr_scheduler is not None:
            return self.lr_scheduler

        scheduled_optimizer = self.optimizer if optimizer is None else optimizer
        self.lr_scheduler = get_scheduler(
            self.args.lr_scheduler_type,
            optimizer=scheduled_optimizer,
            num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
            num_training_steps=num_training_steps,
        )
        return self.lr_scheduler

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Return the number of samples behind a [`~torch.utils.data.DataLoader`], read from its dataset.

        For an `IterableDatasetShard` the length of the wrapped dataset is used. If the dataloader has no dataset,
        or the dataset has no length, the count is estimated as
        `len(dataloader) * args.per_device_train_batch_size`.
        """
        try:
            source = dataloader.dataset
            # Special case for IterableDatasetShard: the real dataset is one level deeper.
            if isinstance(source, IterableDatasetShard):
                source = source.dataset
            return len(source)
        except (NameError, AttributeError, TypeError):
            # No dataset or no length: estimate from the number of batches.
            return len(dataloader) * self.args.per_device_train_batch_size

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
Patrick von Platen's avatar
Patrick von Platen committed
1239
        """HP search setup code"""
1240
1241
        self._trial = trial

1242
1243
        if self.hp_search_backend is None or trial is None:
            return
1244
1245
1246
1247
1248
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            params = self.hp_space(trial)
        elif self.hp_search_backend == HPSearchBackend.RAY:
            params = trial
            params.pop("wandb", None)
1249
1250
        elif self.hp_search_backend == HPSearchBackend.SIGOPT:
            params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
1251
1252
        elif self.hp_search_backend == HPSearchBackend.WANDB:
            params = trial
1253

1254
1255
        for key, value in params.items():
            if not hasattr(self.args, key):
1256
                logger.warning(
Sylvain Gugger's avatar
Sylvain Gugger committed
1257
1258
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
                    " `TrainingArguments`."
1259
                )
1260
                continue
1261
1262
1263
1264
1265
1266
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
1267
            logger.info(f"Trial: {trial.params}")
1268
1269
        if self.hp_search_backend == HPSearchBackend.SIGOPT:
            logger.info(f"SigOpt Assignments: {trial.assignments}")
1270
1271
        if self.hp_search_backend == HPSearchBackend.WANDB:
            logger.info(f"W&B Sweep parameters: {trial}")
1272
1273
        if self.args.deepspeed:
            # Rebuild the deepspeed config to reflect the updated training parameters
1274
            from transformers.deepspeed import HfTrainerDeepSpeedConfig
1275

1276
1277
            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
            self.args.hf_deepspeed_config.trainer_config_process(self.args)
1278

1279
    def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
1280
1281
        if self.hp_search_backend is None or trial is None:
            return
1282
        self.objective = self.compute_objective(metrics.copy())
1283
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
1284
1285
            import optuna

1286
            trial.report(self.objective, step)
1287
            if trial.should_prune():
1288
                self.callback_handler.on_train_end(self.args, self.state, self.control)
1289
1290
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
1291
1292
            from ray import tune

1293
            if self.control.should_save:
1294
                self._tune_save_checkpoint()
1295
1296
            tune.report(objective=self.objective, **metrics)

1297
    def _tune_save_checkpoint(self):
        """
        Save a checkpoint inside the Ray Tune checkpoint directory for the current global step.

        This is a no-op unless `self.use_tune_checkpoints` is set. The model is saved to a
        `{PREFIX_CHECKPOINT_DIR}-{global_step}` sub-folder. When `args.should_save` is true, the trainer state
        (JSON), the optimizer state dict and the LR-scheduler state dict are written to that folder as well.
        """
        from ray import tune

        if not self.use_tune_checkpoints:
            return

        global_step = self.state.global_step
        with tune.checkpoint_dir(step=global_step) as tune_dir:
            ckpt_dir = os.path.join(tune_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
            self.save_model(ckpt_dir, _internal_call=True)
            if not self.args.should_save:
                return
            self.state.save_to_json(os.path.join(ckpt_dir, TRAINER_STATE_NAME))
            for component, file_name in ((self.optimizer, OPTIMIZER_NAME), (self.lr_scheduler, SCHEDULER_NAME)):
                torch.save(component.state_dict(), os.path.join(ckpt_dir, file_name))

    def call_model_init(self, trial=None):
        """
        Build a fresh model by calling `self.model_init`.

        `trial` is passed only when `model_init` takes exactly one argument. Otherwise `model_init` must take none.

        Raises:
            RuntimeError: If `model_init` takes more than one argument, or returns `None`.
        """
        n_args = number_of_arguments(self.model_init)
        if n_args not in (0, 1):
            raise RuntimeError("model_init should have 0 or 1 argument.")

        model = self.model_init(trial) if n_args == 1 else self.model_init()
        if model is None:
            raise RuntimeError("model_init should not return None.")
        return model

    def torch_jit_model_eval(self, model, dataloader, training=False):
        """
        Trace `model` with TorchScript for evaluation and return the traced module.

        Nothing is done when `training` is true, or when no `dataloader` is given (a warning is logged). Otherwise
        the first batch of `dataloader` is the tracing example. The traced module is frozen and run twice on that
        batch as warm-up, and the trainer's `use_cpu_amp` / `use_cuda_amp` flags are cleared. Any
        `RuntimeError`, `TypeError`, `ValueError`, `NameError` or `IndexError` raised while tracing is logged, and
        the untraced `model` is returned.
        """
        if training:
            return model
        if dataloader is None:
            logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
            return model

        example_batch = self._prepare_inputs(next(iter(dataloader)))
        try:
            traced = model.eval()
            with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]):
                torch_version = version.parse(version.parse(torch.__version__).base_version)
                if torch_version >= version.parse("1.14.0"):
                    # Newer torch can trace with keyword inputs; non-dict mappings are converted to a plain dict.
                    kwarg_inputs = (
                        example_batch
                        if isinstance(example_batch, dict)
                        else {name: example_batch[name] for name in example_batch}
                    )
                    traced = torch.jit.trace(traced, example_kwarg_inputs=kwarg_inputs, strict=False)
                else:
                    # Older torch: trace positionally with all-ones tensors shaped like the example batch.
                    positional_inputs = tuple(torch.ones_like(example_batch[name]) for name in example_batch)
                    traced = torch.jit.trace(traced, positional_inputs, strict=False)
            traced = torch.jit.freeze(traced)
            with torch.no_grad():
                for _ in range(2):  # warm-up passes
                    traced(**example_batch)
            model = traced
            self.use_cpu_amp = False
            self.use_cuda_amp = False
        except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
            logger.warning(f"failed to use PyTorch jit mode due to: {e}.")

        return model

    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
        """
        Optimize `model` with Intel Extension for PyTorch (IPEX) at level "O1".

        Args:
            model: The model to optimize.
            training (`bool`, *optional*, defaults to `False`): When true, the model is put in train mode and
                optimized together with `self.optimizer`, which is replaced by the IPEX-returned optimizer.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Target dtype. For evaluation outside of
                training with `args.bf16_full_eval`, bfloat16 is used instead.

        Returns:
            The IPEX-optimized model.

        Raises:
            ImportError: If IPEX is not installed or does not match the installed PyTorch.
        """
        if not is_ipex_available():
            raise ImportError(
                "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
                " to https://github.com/intel/intel-extension-for-pytorch."
            )

        import intel_extension_for_pytorch as ipex

        if training:
            if not model.training:
                model.train()
            model, self.optimizer = ipex.optimize(
                model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
            )
            return model

        model.eval()
        if not self.is_in_train and self.args.bf16_full_eval:
            dtype = torch.bfloat16
        # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings
        return ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)

    def _wrap_model(self, model, training=True, dataloader=None):
        """
        Wrap `model` for the configured acceleration / parallelism setup and return the object to run.

        The steps run in this order:

        1. IPEX optimization (when `args.use_ipex`).
        2. SageMaker model parallel, which returns immediately.
        3. DeepSpeed: the existing engine is returned.
        4. Early return if `model` is already wrapped.
        5. Apex AMP (training only).
        6. `nn.DataParallel` on multi-GPU, except for 8-bit models.
        7. TorchScript JIT (when `args.jit_mode_eval`).
        8. Training only: one of fairscale sharded DDP, PyTorch / XLA FSDP, SageMaker data parallel or PyTorch DDP,
           followed by `torch.compile` (when `args.torch_compile`).

        Args:
            model: The model to wrap (may already be wrapped).
            training (`bool`, *optional*, defaults to `True`): When `False`, the model is returned right after
                step 7, before any step 8 wrapper and before `torch.compile`.
            dataloader (*optional*): Passed to `torch_jit_model_eval` to obtain an example batch.

        Returns:
            The wrapped model. This may be the DeepSpeed engine or a SageMaker `DistributedModel`.
        """
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # already initialized its own DDP and AMP
        if self.deepspeed:
            return self.deepspeed

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP
        if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
            model = nn.DataParallel(model)

        if self.args.jit_mode_eval:
            start_time = time.time()
            model = self.torch_jit_model_eval(model, dataloader, training)
            self.jit_compilation_time = round(time.time() - start_time, 4)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)
        # Distributed training using PyTorch FSDP
        elif self.fsdp is not None:
            if not self.args.fsdp_config["xla"]:
                # PyTorch FSDP!
                from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
                from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
                from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

                cpu_offload = CPUOffload(offload_params=FSDPOption.OFFLOAD in self.args.fsdp)

                auto_wrap_policy = None
                if FSDPOption.AUTO_WRAP in self.args.fsdp:
                    if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                        auto_wrap_policy = functools.partial(
                            size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                        )
                    elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                        auto_wrap_policy = functools.partial(
                            transformer_auto_wrap_policy,
                            # Transformer layer class to wrap
                            transformer_layer_cls=self._get_fsdp_transformer_cls_to_wrap(model),
                        )

                mixed_precision_policy = None
                dtype = None
                if self.args.fp16:
                    dtype = torch.float16
                elif self.args.bf16:
                    dtype = torch.bfloat16
                if dtype is not None:
                    mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)

                # NOTE(review): exact type comparison (not isinstance), so a subclass of FSDP would be wrapped again.
                if type(model) != FSDP:
                    # XXX: Breaking the self.model convention but I see no way around it for now.
                    # Only forward the optional FSDP arguments that the installed torch version accepts.
                    signature = inspect.signature(FSDP.__init__).parameters.keys()
                    kwargs = {}
                    for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]:
                        if arg in signature:
                            kwargs[arg] = getattr(self, arg)
                    self.model = model = FSDP(
                        model,
                        sharding_strategy=self.fsdp,
                        cpu_offload=cpu_offload,
                        auto_wrap_policy=auto_wrap_policy,
                        mixed_precision=mixed_precision_policy,
                        device_id=self.args.device,
                        **kwargs,
                    )
            else:
                try:
                    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                    from torch_xla.distributed.fsdp import checkpoint_module
                    from torch_xla.distributed.fsdp.wrap import (
                        size_based_auto_wrap_policy,
                        transformer_auto_wrap_policy,
                    )
                except ImportError as err:
                    raise ImportError(
                        "Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0."
                    ) from err
                auto_wrap_policy = None
                auto_wrapper_callable = None
                if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                    auto_wrap_policy = functools.partial(
                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                    )
                elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                    auto_wrap_policy = functools.partial(
                        transformer_auto_wrap_policy,
                        # Transformer layer class to wrap
                        transformer_layer_cls=self._get_fsdp_transformer_cls_to_wrap(model),
                    )
                fsdp_kwargs = self.args.xla_fsdp_config
                if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                    # Apply gradient checkpointing to auto-wrapped sub-modules if specified
                    def auto_wrapper_callable(m, *args, **kwargs):
                        return FSDP(checkpoint_module(m), *args, **kwargs)

                # Wrap the base model with an outer FSDP wrapper
                self.model = model = FSDP(
                    model,
                    auto_wrap_policy=auto_wrap_policy,
                    auto_wrapper_callable=auto_wrapper_callable,
                    **fsdp_kwargs,
                )

                # Patch `xm.optimizer_step` should not reduce gradients in this case,
                # as FSDP does not need gradient reduction over sharded parameters.
                def patched_optimizer_step(optimizer, barrier=False, optimizer_args=None):
                    # `None` default instead of a shared mutable `{}`.
                    loss = optimizer.step(**({} if optimizer_args is None else optimizer_args))
                    if barrier:
                        xm.mark_step()
                    return loss

                xm.optimizer_step = patched_optimizer_step
        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
            if is_torch_neuroncore_available():
                # NOTE(review): this early return also skips the `torch.compile` step below — confirm intended.
                return model
            # DDP requires at least one parameter that needs gradients.
            if any(p.requires_grad for p in model.parameters()):
                model = nn.parallel.DistributedDataParallel(
                    model,
                    device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
                    output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
                    **kwargs,
                )

        # torch.compile() needs to be called after wrapping the model with FSDP or DDP
        # to ensure that it accounts for the graph breaks required by those wrappers
        if self.args.torch_compile:
            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)

        return model

    def _get_fsdp_transformer_cls_to_wrap(self, model):
        """
        Resolve the class names in `args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]` to module classes of
        `model`. The result is used as `transformer_layer_cls` by the FSDP transformer auto-wrap policy.

        Returns:
            `set` of module classes.

        Raises:
            ValueError: If one of the names does not match a module class found in `model`.
        """
        transformer_cls_to_wrap = set()
        for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
            transformer_cls = get_module_class_from_name(model, layer_class)
            if transformer_cls is None:
                raise ValueError("Could not find the transformer layer class to wrap in the model.")
            transformer_cls_to_wrap.add(transformer_cls)
        return transformer_cls_to_wrap

    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Main training entry point.

        Args:
            resume_from_checkpoint (`str` or `bool`, *optional*):
                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
            ignore_keys_for_eval (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions for evaluation during the training.
            kwargs:
                Additional keyword arguments used to hide deprecated arguments (currently only `model_path`, an alias
                of `resume_from_checkpoint` that emits a `FutureWarning`).

        Returns:
            Whatever `self._inner_training_loop` returns. The loop is called through `find_executable_batch_size`.

        Raises:
            TypeError: If unexpected keyword arguments are passed. This is checked before any trainer state is
                modified.
            ValueError: If `resume_from_checkpoint=True` but no checkpoint exists in `args.output_dir`.
        """
        # Validate keyword arguments before touching any trainer state, so a bad call does not leave the memory
        # tracker running or `is_in_train` set.
        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
                "instead.",
                FutureWarning,
            )
        if kwargs:
            raise TypeError(f"train() got unexpected keyword arguments: {', '.join(kwargs)}.")

        # `False` means "do not resume", same as `None`.
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        args = self.args

        self.is_in_train = True

        # do_train is not a reliable argument, as it might not be set and .train() still called, so
        # the following is a workaround:
        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
            self._move_model_to_device(self.model, args.device)

        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)
        self._train_batch_size = self.args.train_batch_size

        # Model re-init
        model_reloaded = False
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            if self.args.full_determinism:
                enable_full_determinism(self.args.seed)
            else:
                set_seed(self.args.seed)
            self.model = self.call_model_init(trial)
            model_reloaded = True
            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Load potential model checkpoint
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

        # SageMaker MP and DeepSpeed load the checkpoint elsewhere.
        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and args.deepspeed is None:
            self._load_from_checkpoint(resume_from_checkpoint)

        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
            if self.place_model_on_device:
                self._move_model_to_device(self.model, args.device)
            self.model_wrapped = self.model

        inner_training_loop = find_executable_batch_size(
            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
        )
        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        self._train_batch_size = batch_size
1673
        # Data loader and number of training steps
Julien Chaumond's avatar
Julien Chaumond committed
1674
        train_dataloader = self.get_train_dataloader()
1675
1676
1677
1678
1679

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
1680
        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
1681
1682
1683
1684
1685

        len_dataloader = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
1686
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
1687
            num_examples = self.num_examples(train_dataloader)
1688
1689
1690
1691
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
1692
                )
1693
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
1694
1695
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
1696
            else:
1697
1698
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
1699
1700
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
1701
            max_steps = args.max_steps
1702
1703
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
1704
            num_update_steps_per_epoch = max_steps
1705
            num_examples = total_train_batch_size * args.max_steps
1706
            num_train_samples = args.max_steps * total_train_batch_size
1707
1708
        else:
            raise ValueError(
Sylvain Gugger's avatar
Sylvain Gugger committed
1709
1710
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
1711
            )
Julien Chaumond's avatar
Julien Chaumond committed
1712

1713
        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
1714
1715
1716
1717
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
Sylvain Gugger's avatar
Sylvain Gugger committed
1718
1719
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torch.distributed.launch)."
1720
1721
1722
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
1723

1724
        delay_optimizer_creation = (
1725
1726
1727
1728
            self.sharded_ddp is not None
            and self.sharded_ddp != ShardedDDPOption.SIMPLE
            or is_sagemaker_mp_enabled()
            or self.fsdp is not None
1729
        )
1730
        if args.deepspeed:
1731
            deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
1732
1733
                self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
            )
1734
1735
1736
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
1737
1738
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler
1739
        elif not delay_optimizer_creation:
1740
1741
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

1742
        self.state = TrainerState()
1743
        self.state.is_hyper_param_search = trial is not None
Julien Chaumond's avatar
Julien Chaumond committed
1744

1745
1746
1747
1748
        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

1749
        model = self._wrap_model(self.model_wrapped)
Julien Chaumond's avatar
Julien Chaumond committed
1750

1751
1752
1753
        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint, model)

1754
1755
1756
1757
        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

1758
1759
1760
        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

1761
1762
1763
        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

1764
1765
        # important: at this point:
        # self.model         is the Transformers Model
1766
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
1767

Julien Chaumond's avatar
Julien Chaumond committed
1768
1769
        # Train!
        logger.info("***** Running training *****")
1770
1771
1772
1773
        logger.info(f"  Num examples = {num_examples:,}")
        logger.info(f"  Num Epochs = {num_train_epochs:,}")
        logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size:,}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
1774
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1775
1776
        logger.info(f"  Total optimization steps = {max_steps:,}")
        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
Julien Chaumond's avatar
Julien Chaumond committed
1777

1778
        self.state.epoch = 0
1779
        start_time = time.time()
Julien Chaumond's avatar
Julien Chaumond committed
1780
1781
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
1782
        steps_trained_progress_bar = None
1783

Julien Chaumond's avatar
Julien Chaumond committed
1784
        # Check if continuing training from a checkpoint
1785
        if resume_from_checkpoint is not None and os.path.isfile(
1786
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
1787
        ):
1788
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
1789
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
1790
            if not args.ignore_data_skip:
1791
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
1792
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
1793
1794
            else:
                steps_trained_in_current_epoch = 0
1795
1796

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
1797
1798
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
1799
            if not args.ignore_data_skip:
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
                if skip_first_batches is None:
                    logger.info(
                        f"  Will skip the first {epochs_trained} epochs then the first"
                        f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
                        " you can install the latest version of Accelerate with `pip install -U accelerate`.You can"
                        " also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
                        " training on data already seen by your model."
                    )
                else:
                    logger.info(
                        f"  Will skip the first {epochs_trained} epochs then the first"
                        f" {steps_trained_in_current_epoch} batches in the first epoch."
                    )
                if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
1814
1815
                    steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                    steps_trained_progress_bar.set_description("Skipping the first batches")
1816

Sylvain Gugger's avatar
Sylvain Gugger committed
1817
1818
1819
1820
1821
        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
1822
1823
1824
1825
        if self.hp_name is not None and self._trial is not None:
            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
            # parameter to Train when using DDP.
            self.state.trial_name = self.hp_name(self._trial)
1826
1827
1828
1829
1830
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
1831
1832
1833
1834
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
Sylvain Gugger's avatar
Sylvain Gugger committed
1835
1836
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()
Julien Chaumond's avatar
Julien Chaumond committed
1837

1838
        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
1839
        tr_loss = torch.tensor(0.0).to(args.device)
1840
1841
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
1842
        self._globalstep_last_logged = self.state.global_step
1843
        model.zero_grad()
Sylvain Gugger's avatar
Sylvain Gugger committed
1844

1845
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1846

1847
        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
1848
        if not args.ignore_data_skip:
1849
            for epoch in range(epochs_trained):
1850
1851
1852
                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
                    train_dataloader.sampler, RandomSampler
                )
1853
                if is_torch_less_than_1_11 or not is_random_sampler:
1854
1855
1856
1857
1858
1859
1860
1861
                    # We just need to begin an iteration to create the randomization of the sampler.
                    # That was before PyTorch 1.11 however...
                    for _ in train_dataloader:
                        break
                else:
                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
                    # AT THE VERY END!
                    _ = list(train_dataloader.sampler)
1862

1863
        total_batched_samples = 0
1864
        for epoch in range(epochs_trained, num_train_epochs):
1865
1866
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
1867
            elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
1868
                train_dataloader.dataset.set_epoch(epoch)
1869

1870
            if is_torch_tpu_available():
1871
                parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
1872
                epoch_iterator = parallel_loader
1873
            else:
1874
                epoch_iterator = train_dataloader
1875

1876
            # Reset the past mems state at the beginning of each epoch if necessary.
1877
            if args.past_index >= 0:
1878
1879
                self._past = None

1880
            steps_in_epoch = (
1881
1882
1883
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
1884
            )
1885
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1886

1887
1888
1889
            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

1890
            rng_to_sync = False
1891
            steps_skipped = 0
1892
1893
            if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
1894
                steps_skipped = steps_trained_in_current_epoch
1895
1896
1897
                steps_trained_in_current_epoch = 0
                rng_to_sync = True

1898
            step = -1
Julien Chaumond's avatar
Julien Chaumond committed
1899
            for step, inputs in enumerate(epoch_iterator):
1900
                total_batched_samples += 1
1901
1902
1903
                if rng_to_sync:
                    self._load_rng_state(resume_from_checkpoint)
                    rng_to_sync = False
Julien Chaumond's avatar
Julien Chaumond committed
1904
1905
1906
1907

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
1908
1909
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
1910
1911
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
Julien Chaumond's avatar
Julien Chaumond committed
1912
                    continue
1913
1914
1915
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None
Julien Chaumond's avatar
Julien Chaumond committed
1916

1917
1918
                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1919

1920
                if (
1921
                    (total_batched_samples % args.gradient_accumulation_steps != 0)
1922
                    and args.parallel_mode == ParallelMode.DISTRIBUTED
1923
                    and args._no_sync_in_gradient_accumulation
Wing Lian's avatar
Wing Lian committed
1924
                    and hasattr(model, "no_sync")
1925
                ):
1926
                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
1927
                    with model.no_sync():
1928
                        tr_loss_step = self.training_step(model, inputs)
1929
                else:
1930
1931
                    tr_loss_step = self.training_step(model, inputs)

1932
1933
1934
1935
1936
1937
1938
                if (
                    args.logging_nan_inf_filter
                    and not is_torch_tpu_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
1939
1940
1941
                else:
                    tr_loss += tr_loss_step

1942
                self.current_flos += float(self.floating_point_ops(inputs))
Julien Chaumond's avatar
Julien Chaumond committed
1943

1944
1945
1946
1947
                # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
                if self.deepspeed:
                    self.deepspeed.step()

1948
                if total_batched_samples % args.gradient_accumulation_steps == 0 or (
Julien Chaumond's avatar
Julien Chaumond committed
1949
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
1950
                    steps_in_epoch <= args.gradient_accumulation_steps
1951
                    and (step + 1) == steps_in_epoch
Julien Chaumond's avatar
Julien Chaumond committed
1952
                ):
1953
                    # Gradient clipping
1954
                    if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
1955
1956
                        # deepspeed does its own clipping

1957
                        if self.do_grad_scaling:
1958
1959
1960
1961
                            # Reduce gradients first for XLA
                            if is_torch_tpu_available():
                                gradients = xm._fetch_gradients(self.optimizer)
                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
1962
1963
1964
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)

1965
1966
1967
                        if is_sagemaker_mp_enabled() and args.fp16:
                            self.optimizer.clip_master_grads(args.max_grad_norm)
                        elif hasattr(self.optimizer, "clip_grad_norm"):
1968
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
1969
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
1970
1971
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
1972
                            model.clip_grad_norm_(args.max_grad_norm)
1973
1974
                        else:
                            # Revert to normal clipping otherwise, handling Apex or full precision
1975
                            nn.utils.clip_grad_norm_(
1976
                                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
1977
                                args.max_grad_norm,
1978
1979
1980
                            )

                    # Optimizer step
1981
                    optimizer_was_run = True
Stas Bekman's avatar
Stas Bekman committed
1982
                    if self.deepspeed:
1983
                        pass  # called outside the loop
Stas Bekman's avatar
Stas Bekman committed
1984
                    elif is_torch_tpu_available():
1985
1986
1987
1988
1989
                        if self.do_grad_scaling:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            xm.optimizer_step(self.optimizer)
1990
                    elif self.do_grad_scaling:
1991
                        scale_before = self.scaler.get_scale()
1992
                        self.scaler.step(self.optimizer)
1993
                        self.scaler.update()
1994
1995
                        scale_after = self.scaler.get_scale()
                        optimizer_was_run = scale_before <= scale_after
Lysandre Debut's avatar
Lysandre Debut committed
1996
                    else:
1997
                        self.optimizer.step()
Lysandre Debut's avatar
Lysandre Debut committed
1998

1999
                    if optimizer_was_run and not self.deepspeed:
2000
2001
                        self.lr_scheduler.step()

2002
                    model.zero_grad()
2003
                    self.state.global_step += 1
2004
                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
2005
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
2006

2007
                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
wulu473's avatar
wulu473 committed
2008
2009
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
Julien Chaumond's avatar
Julien Chaumond committed
2010

Sylvain Gugger's avatar
Sylvain Gugger committed
2011
                if self.control.should_epoch_stop or self.control.should_training_stop:
Julien Chaumond's avatar
Julien Chaumond committed
2012
                    break
2013
2014
            if step < 0:
                logger.warning(
Sylvain Gugger's avatar
Sylvain Gugger committed
2015
                    "There seems to be not a single sample in your epoch_iterator, stopping training at step"
2016
2017
2018
2019
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True
2020

2021
            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
2022
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
2023

2024
            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
2025
2026
2027
2028
2029
2030
2031
2032
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
Sylvain Gugger's avatar
Sylvain Gugger committed
2033
            if self.control.should_training_stop:
2034
                break
Julien Chaumond's avatar
Julien Chaumond committed
2035

2036
        if args.past_index and hasattr(self, "_past"):
2037
2038
            # Clean the state at the end of training
            delattr(self, "_past")
Julien Chaumond's avatar
Julien Chaumond committed
2039
2040

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
2041
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
2042
2043
2044
            # Wait for everyone to get here so we are sur the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
2045
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
2046
                dist.barrier()
2047
2048
            elif is_sagemaker_mp_enabled():
                smp.barrier()
2049

2050
            self._load_best_model()
2051

2052
2053
2054
2055
        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        train_loss = self._total_loss_scalar / self.state.global_step

2056
        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
2057
2058
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
2059
        metrics["train_loss"] = train_loss
Sylvain Gugger's avatar
Sylvain Gugger committed
2060

2061
        self.is_in_train = False
2062

2063
2064
        self._memory_tracker.stop_and_update_metrics(metrics)

2065
2066
        self.log(metrics)

raghavanone's avatar
raghavanone committed
2067
2068
2069
        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

2070
2071
        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
raghavanone's avatar
raghavanone committed
2072
2073
2074
2075
2076
            for checkpoint in checkpoints_sorted:
                if checkpoint != self.state.best_model_checkpoint:
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint)

2077
2078
2079
        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        return TrainOutput(self.state.global_step, train_loss, metrics)
2080

raghavanone's avatar
raghavanone committed
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
    def _get_output_dir(self, trial):
        if self.hp_search_backend is not None and trial is not None:
            if self.hp_search_backend == HPSearchBackend.OPTUNA:
                run_id = trial.number
            elif self.hp_search_backend == HPSearchBackend.RAY:
                from ray import tune

                run_id = tune.get_trial_id()
            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
                run_id = trial.id
            elif self.hp_search_backend == HPSearchBackend.WANDB:
                import wandb

                run_id = wandb.run.id
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            run_dir = os.path.join(self.args.output_dir, run_name)
        else:
            run_dir = self.args.output_dir
        return run_dir

2101
2102
2103
2104
    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        """
        Load the model weights stored in the checkpoint directory `resume_from_checkpoint`.

        Single-file checkpoints (`WEIGHTS_NAME` / `SAFE_WEIGHTS_NAME`) are loaded through `load_state_dict`. Sharded
        checkpoints (`WEIGHTS_INDEX_NAME` / `SAFE_WEIGHTS_INDEX_NAME`) are loaded with `load_sharded_checkpoint`. When
        SageMaker model parallelism is enabled, the smp checkpoint APIs are used for single-file checkpoints.

        Args:
            resume_from_checkpoint (`str`):
                Path to the checkpoint directory.
            model (*optional*):
                Model to load the weights into. Defaults to `self.model`.

        Raises:
            `ValueError`: If none of the recognized weight or index files exists in `resume_from_checkpoint`.
        """
        if model is None:
            model = self.model

        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)

        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
        safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
        safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)

        if not any(
            os.path.isfile(f) for f in (weights_file, safe_weights_file, weights_index_file, safe_weights_index_file)
        ):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

        logger.info(f"Loading model from {resume_from_checkpoint}.")

        if os.path.isfile(config_file):
            config = PretrainedConfig.from_json_file(config_file)
            checkpoint_version = config.transformers_version
            if checkpoint_version is not None and checkpoint_version != __version__:
                logger.warning(
                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                    f"Transformers but your current version is {__version__}. This is not recommended and could "
                    "yield to errors or unwanted behaviors."
                )

        has_weights_file = os.path.isfile(weights_file)
        has_safe_weights_file = os.path.isfile(safe_weights_file)

        if has_weights_file or has_safe_weights_file:
            # If the model is on the GPU, it still works!
            if is_sagemaker_mp_enabled():
                if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
                    # If the 'user_content.pt' file exists, load with the new smp api.
                    # Checkpoint must have been saved with the new smp api.
                    smp.resume_from_checkpoint(
                        path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
                    )
                else:
                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                    # Checkpoint must have been saved with the old smp api.
                    if hasattr(self.args, "fp16") and self.args.fp16 is True:
                        logger.warning(
                            "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
                        )
                    # Only fall back to safetensors when `WEIGHTS_NAME` is missing: `torch.load` would fail on
                    # the absent file otherwise.
                    if has_weights_file:
                        state_dict = torch.load(weights_file, map_location="cpu")
                    else:
                        state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
                    # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                    state_dict["_smp_is_partial"] = False
                    load_result = model.load_state_dict(state_dict, strict=True)
                    # release memory
                    del state_dict
            else:
                # We load the model state dict on the CPU to avoid an OOM error.
                # Use safetensors when requested via `save_safetensors`, or when it is the only single-file
                # checkpoint present (`torch.load` would otherwise fail on the missing `WEIGHTS_NAME` file).
                if has_safe_weights_file and (self.args.save_safetensors or not has_weights_file):
                    state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
                else:
                    state_dict = torch.load(weights_file, map_location="cpu")

                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                # which takes *args instead of **kwargs
                load_result = model.load_state_dict(state_dict, False)
                # release memory
                del state_dict
                self._issue_warnings_after_load(load_result)
        else:
            # We load the sharded checkpoint
            load_result = load_sharded_checkpoint(
                model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
            )
            if not is_sagemaker_mp_enabled():
                self._issue_warnings_after_load(load_result)
2171
2172
2173
2174

    def _load_best_model(self):
        """
        Reload the weights of the best checkpoint (`self.state.best_model_checkpoint`) into the model.

        With DeepSpeed, the current engine is destroyed and re-initialized from the best checkpoint. This replaces
        `self.model`, `self.model_wrapped`, `self.deepspeed`, `self.optimizer` and `self.lr_scheduler`.

        Without DeepSpeed:
        - Single-file checkpoints (`WEIGHTS_NAME` / `SAFE_WEIGHTS_NAME`) are loaded through `load_state_dict`, or
          through the smp checkpoint API under SageMaker model parallelism.
        - Sharded checkpoints (`WEIGHTS_INDEX_NAME` / `SAFE_WEIGHTS_INDEX_NAME`) are loaded with
          `load_sharded_checkpoint`.

        If no weights are found, only a warning is logged.
        """
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
        best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
        best_index_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)
        best_safe_index_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
            if self.deepspeed:
                if self.model_wrapped is not None:
                    # this removes the pre-hooks from the previous engine
                    self.model_wrapped.destroy()
                    self.model_wrapped = None

                # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
                deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
                    self,
                    num_training_steps=self.args.max_steps,
                    resume_from_checkpoint=self.state.best_model_checkpoint,
                )
                self.model = deepspeed_engine.module
                self.model_wrapped = deepspeed_engine
                self.deepspeed = deepspeed_engine
                self.optimizer = optimizer
                self.lr_scheduler = lr_scheduler
            else:
                # Use safetensors when requested via `save_safetensors`, or when it is the only single-file
                # checkpoint present (`torch.load` would otherwise fail on the missing `WEIGHTS_NAME` file).
                use_safetensors = os.path.isfile(best_safe_model_path) and (
                    self.args.save_safetensors or not os.path.isfile(best_model_path)
                )
                if is_sagemaker_mp_enabled():
                    if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                        # If the 'user_content.pt' file exists, load with the new smp api.
                        # Checkpoint must have been saved with the new smp api.
                        smp.resume_from_checkpoint(
                            path=self.state.best_model_checkpoint,
                            tag=WEIGHTS_NAME,
                            partial=False,
                            load_optimizer=False,
                        )
                    else:
                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                        # Checkpoint must have been saved with the old smp api.
                        if use_safetensors:
                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                        else:
                            state_dict = torch.load(best_model_path, map_location="cpu")

                        # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                        state_dict["_smp_is_partial"] = False
                        load_result = model.load_state_dict(state_dict, strict=True)
                        # release memory
                        del state_dict
                else:
                    # We load the model state dict on the CPU to avoid an OOM error.
                    if use_safetensors:
                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                    else:
                        state_dict = torch.load(best_model_path, map_location="cpu")

                    # If the model is on the GPU, it still works!
                    # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                    # which takes *args instead of **kwargs
                    load_result = model.load_state_dict(state_dict, False)
                    # release memory
                    del state_dict
                if not is_sagemaker_mp_enabled():
                    self._issue_warnings_after_load(load_result)
        elif os.path.exists(best_index_path) or os.path.exists(best_safe_index_path):
            # Sharded checkpoint, saved either as torch `.bin` shards or as safetensors shards.
            load_result = load_sharded_checkpoint(
                model,
                self.state.best_model_checkpoint,
                strict=is_sagemaker_mp_enabled(),
                prefer_safe=self.args.save_safetensors,
            )
            if not is_sagemaker_mp_enabled():
                self._issue_warnings_after_load(load_result)
        else:
            logger.warning(
                f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
                "on multiple nodes, you should activate `--save_on_each_node`."
            )

2241
    def _issue_warnings_after_load(self, load_result):
2242
        if len(load_result.missing_keys) != 0:
2243
2244
2245
            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                self.model._keys_to_ignore_on_save
            ):
2246
2247
                self.model.tie_weights()
            else:
2248
                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
2249
        if len(load_result.unexpected_keys) != 0:
2250
2251
2252
            logger.warning(
                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
            )
2253

2254
    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
Sylvain Gugger's avatar
Sylvain Gugger committed
2255
        if self.control.should_log:
2256
2257
2258
            if is_torch_tpu_available():
                xm.mark_step()

Sylvain Gugger's avatar
Sylvain Gugger committed
2259
            logs: Dict[str, float] = {}
2260
2261
2262
2263

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

2264
2265
2266
            # reset tr_loss to zero
            tr_loss -= tr_loss

2267
            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
2268
            logs["learning_rate"] = self._get_learning_rate()
2269

2270
            self._total_loss_scalar += tr_loss_scalar
2271
            self._globalstep_last_logged = self.state.global_step
Teven's avatar
Teven committed
2272
            self.store_flos()
Sylvain Gugger's avatar
Sylvain Gugger committed
2273
2274
2275
2276
2277

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
2278
            if isinstance(self.eval_dataset, dict):
2279
                metrics = {}
2280
                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
2281
                    dataset_metrics = self.evaluate(
2282
2283
2284
2285
                        eval_dataset=eval_dataset,
                        ignore_keys=ignore_keys_for_eval,
                        metric_key_prefix=f"eval_{eval_dataset_name}",
                    )
2286
                    metrics.update(dataset_metrics)
2287
2288
            else:
                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
2289
            self._report_to_hp_search(trial, self.state.global_step, metrics)
2290

Sylvain Gugger's avatar
Sylvain Gugger committed
2291
2292
2293
2294
        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

2295
2296
2297
2298
2299
    def _load_rng_state(self, checkpoint):
        # Load RNG states from `checkpoint`
        if checkpoint is None:
            return

2300
2301
2302
        if self.args.world_size > 1:
            process_index = self.args.process_index
            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
2303
            if not os.path.isfile(rng_file):
2304
                logger.info(
2305
                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
2306
2307
2308
2309
2310
                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
                )
                return
        else:
            rng_file = os.path.join(checkpoint, "rng_state.pth")
2311
            if not os.path.isfile(rng_file):
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
                logger.info(
                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
                    "fashion, reproducibility is not guaranteed."
                )
                return

        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        if torch.cuda.is_available():
2323
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
2324
2325
                torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
            else:
2326
2327
2328
                try:
                    torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
                except Exception as e:
2329
                    logger.info(
2330
2331
2332
                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
2333
2334
2335
        if is_torch_tpu_available():
            xm.set_rng_state(checkpoint_rng_state["xla"])

Sylvain Gugger's avatar
Sylvain Gugger committed
2336
    def _save_checkpoint(self, model, trial, metrics=None):
        """
        Write a training checkpoint to `<run_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>`.

        The checkpoint contains the model (via `save_model`, plus the DeepSpeed checkpoint when DeepSpeed is used),
        the optimizer / learning-rate scheduler / grad-scaler states, the Trainer state JSON and this process's RNG
        states. When both `metrics` and `args.metric_for_best_model` are set, `state.best_metric` and
        `state.best_model_checkpoint` are updated before the Trainer state is written. Older checkpoints are rotated
        out at the end.

        Args:
            model: The model being trained. Not used in the body of this method; weights are saved through
                `self.save_model`.
            trial: The hyperparameter-search trial, if any. Used to resolve the run directory.
            metrics (`Dict[str, float]`, *optional*):
                Evaluation metrics for the current step, used to track the best checkpoint.

        Raises:
            KeyError: If `metric_for_best_model` is set but its (`eval_`-prefixed) key is not in `metrics`.
        """
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir, _internal_call=True)
        if self.deepspeed:
            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
            # config `stage3_gather_16bit_weights_on_model_save` is True
            self.deepspeed.save_checkpoint(output_dir)

        # Save optimizer and scheduler
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            self.optimizer.consolidate_state_dict()

        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
            if self.args.should_save:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling:
                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
        elif self.args.should_save and not self.deepspeed:
            # deepspeed.save_checkpoint above saves model/optim/sched
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            try:
                metric_value = metrics[metric_to_check]
            except KeyError as exc:
                # Re-raise with a message naming the configured metric and what was actually computed.
                raise KeyError(
                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not "
                    f"found in the evaluation metrics. The available evaluation metrics are: {list(metrics.keys())}. "
                    "Consider changing `metric_for_best_model` in the TrainingArguments."
                ) from exc

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.args.should_save:
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        # Save RNG states (on every process, distributed or not) so a resumed run can restore them.
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
            # The format must match `_load_rng_state`, which calls `set_rng_state` (a single state) in distributed
            # mode and `set_rng_state_all` (a list of per-device states) otherwise.
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                rng_states["cuda"] = torch.cuda.random.get_rng_state()
            else:
                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()

        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()

        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)

        if self.args.world_size <= 1:
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

2442
    def _load_optimizer_and_scheduler(self, checkpoint):
Sylvain Gugger's avatar
Sylvain Gugger committed
2443
        """If optimizer and scheduler states exist, load them."""
2444
        if checkpoint is None:
2445
2446
            return

2447
        if self.deepspeed:
2448
            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
2449
2450
            return

2451
2452
2453
2454
2455
2456
        checkpoint_file_exists = (
            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
            if is_sagemaker_mp_enabled()
            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
        )
        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
Sylvain Gugger's avatar
Sylvain Gugger committed
2457
2458
2459
            # Load in optimizer and scheduler states
            if is_torch_tpu_available():
                # On TPU we have to take some extra precautions to properly load the states on the right device.
2460
                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
Sylvain Gugger's avatar
Sylvain Gugger committed
2461
                with warnings.catch_warnings(record=True) as caught_warnings:
2462
                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
Sylvain Gugger's avatar
Sylvain Gugger committed
2463
2464
2465
2466
2467
2468
2469
2470
                reissue_pt_warnings(caught_warnings)

                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

                self.optimizer.load_state_dict(optimizer_state)
                self.lr_scheduler.load_state_dict(lr_scheduler_state)
            else:
2471
                if is_sagemaker_mp_enabled():
2472
2473
2474
2475
                    if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
                        # Optimizer checkpoint was saved with smp >= 1.10
                        def opt_load_hook(mod, opt):
                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
2476

2477
2478
2479
2480
2481
2482
2483
2484
2485
                    else:
                        # Optimizer checkpoint was saved with smp < 1.10
                        def opt_load_hook(mod, opt):
                            if IS_SAGEMAKER_MP_POST_1_10:
                                opt.load_state_dict(
                                    smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
                                )
                            else:
                                opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
2486
2487
2488

                    self.model_wrapped.register_post_step_hook(opt_load_hook)
                else:
2489
2490
2491
2492
                    # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
                    # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
                    # likely to get OOM on CPU (since we load num_gpu times the optimizer state
                    map_location = self.args.device if self.args.world_size > 1 else "cpu"
2493
                    self.optimizer.load_state_dict(
2494
                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
2495
                    )
Sylvain Gugger's avatar
Sylvain Gugger committed
2496
                with warnings.catch_warnings(record=True) as caught_warnings:
2497
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
Sylvain Gugger's avatar
Sylvain Gugger committed
2498
                reissue_pt_warnings(caught_warnings)
2499
                if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
2500
                    self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
Sylvain Gugger's avatar
Sylvain Gugger committed
2501

2502
2503
2504
2505
2506
2507
2508
    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
        **kwargs,
    ) -> BestRun:
        """
        Launch an hyperparameter search using `optuna`, `Ray Tune`, `SigOpt` or `wandb`. The optimized quantity is
        determined by `compute_objective`, which defaults to a function returning the evaluation loss when no metric
        is provided, the sum of all metrics otherwise.

        <Tip warning={true}>

        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
        optimizer/scheduler.

        </Tip>

        Args:
            hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
                A function that defines the hyperparameter search space. Will default to
                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
                method. Will default to [`~trainer_utils.default_compute_objective`].
            n_trials (`int`, *optional*, defaults to 20):
                The number of trial runs to test.
            direction (`str`, *optional*, defaults to `"minimize"`):
                Whether to optimize greater or lower objects. Can be `"minimize"` or `"maximize"`, you should pick
                `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
            backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
                The backend to use for hyperparameter search (`optuna`, `ray`, `sigopt` or `wandb`). Will default to
                optuna or Ray Tune or SigOpt, depending on which one is installed. If all are installed, will default
                to optuna.
            hp_name (`Callable[["optuna.Trial"], str]]`, *optional*):
                A function that defines the trial/run name. Will default to None.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the backend's search runner (e.g. `optuna.create_study`
                or `ray.tune.run`). For more information see:

                - the documentation of
                  [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
                - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
                - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)

        Returns:
            [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in
            `run_summary` attribute for Ray backend.

        Raises:
            RuntimeError: If no backend is available, the requested backend is not installed, or no `model_init` was
                passed to the Trainer.
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna, ray or sigopt should be installed. "
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`. "
                    "To install sigopt run `pip install sigopt`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_tune_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
            raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
        if backend == HPSearchBackend.WANDB and not is_wandb_available():
            raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
        self.hp_search_backend = backend
        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.hp_name = hp_name
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        backend_dict = {
            HPSearchBackend.OPTUNA: run_hp_search_optuna,
            HPSearchBackend.RAY: run_hp_search_ray,
            HPSearchBackend.SIGOPT: run_hp_search_sigopt,
            HPSearchBackend.WANDB: run_hp_search_wandb,
        }
        try:
            best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
        finally:
            # Always leave "search mode", even if the search raised, so later checkpoints on this Trainer are not
            # treated as part of a hyperparameter search (see the `hp_search_backend` check in `_save_checkpoint`).
            self.hp_search_backend = None
        return best_run

Sylvain Gugger's avatar
Sylvain Gugger committed
2598
    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.

        When the current epoch is known, it is added to `logs` in place, rounded to 2 decimals. A copy of `logs`
        tagged with the current global step is appended to `state.log_history`, and the callbacks' `on_log` hook is
        invoked with the (updated) `logs`.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        current_epoch = self.state.epoch
        if current_epoch is not None:
            logs["epoch"] = round(current_epoch, 2)

        history_entry = {**logs, "step": self.state.global_step}
        self.state.log_history.append(history_entry)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
Julien Chaumond's avatar
Julien Chaumond committed
2614

2615
2616
    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
2617
        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
2618
        """
2619
2620
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
2621
2622
2623
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
2624
            kwargs = {"device": self.args.device}
2625
2626
            if self.deepspeed and (torch.is_floating_point(data) or torch.is_complex(data)):
                # NLP models inputs are int/uint and those get adjusted to the right dtype of the
2627
2628
                # embedding. Other models such as wav2vec2's inputs are already float and thus
                # may need special handling to match the dtypes of the model
2629
                kwargs.update({"dtype": self.args.hf_deepspeed_config.dtype()})
2630
2631
2632
            return data.to(**kwargs)
        return data

sgugger's avatar
Fix CI  
sgugger committed
2633
    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
2634
        """
2635
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
2636
2637
        handling potential state.
        """
2638
        inputs = self._prepare_input(inputs)
2639
2640
2641
2642
2643
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )
2644
2645
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past
2646

2647
2648
        return inputs

2649
2650
2651
2652
    def compute_loss_context_manager(self):
        """
        Return the context manager the loss computation runs under.

        This currently delegates to `autocast_smart_context_manager`. It is a separate method so that other
        context managers can be grouped here.
        """
        loss_context = self.autocast_smart_context_manager()
        return loss_context
2654

2655
    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
        """
        Build the `autocast` context manager that matches the current mixed-precision setup.

        - With neither CUDA nor CPU AMP enabled, returns a no-op context manager.
        - When `is_torch_greater_or_equal_than_1_10` is false, returns `torch.cuda.amp.autocast()` with default
          arguments.
        - Otherwise returns CPU or CUDA autocast configured with `cache_enabled` and `self.amp_dtype`.

        Args:
            cache_enabled (`bool`, *optional*, defaults to `True`): Forwarded to `autocast` on the torch >= 1.10 path.
        """
        if not (self.use_cuda_amp or self.use_cpu_amp):
            return contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()

        if not is_torch_greater_or_equal_than_1_10:
            return torch.cuda.amp.autocast()

        if self.use_cpu_amp:
            return torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
        return torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)

2674
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        batch = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            # SageMaker MP does forward and backward in `smp_forward_backward`; reduce its result to one loss.
            mp_loss = smp_forward_backward(model, batch, self.args.gradient_accumulation_steps)
            return mp_loss.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, batch)

        if self.args.n_gpu > 1:
            # mean() to average on multi-gpu parallel training
            loss = loss.mean()

        # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
        needs_manual_accumulation_scaling = self.args.gradient_accumulation_steps > 1 and not self.deepspeed
        if needs_manual_accumulation_scaling:
            loss = loss / self.args.gradient_accumulation_steps

        # Backward pass: pick the mechanism matching the precision / engine setup.
        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            # loss gets scaled under gradient_accumulation_steps in deepspeed
            loss = self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()
Julien Chaumond's avatar
Julien Chaumond committed
2721

2722
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        When a label smoother is configured and `inputs` has a `"labels"` entry, the labels are popped from `inputs`
        (mutating the caller's dict) and the loss is computed by the label smoother. Labels are shifted for models
        whose class name is a causal-LM model. Otherwise the loss is read from the model outputs.

        Subclass and override for custom behavior.

        Raises:
            ValueError: If the model returns a dict without a `"loss"` key and no smoothed loss is computed.
        """
        smoothing = self.label_smoother is not None and "labels" in inputs
        labels = inputs.pop("labels") if smoothing else None

        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is None:
            outputs_is_dict = isinstance(outputs, dict)
            if outputs_is_dict and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if outputs_is_dict else outputs[0]
        elif unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            loss = self.label_smoother(outputs, labels, shift_labels=True)
        else:
            loss = self.label_smoother(outputs, labels)

        return (loss, outputs) if return_outputs else loss
Sylvain Gugger's avatar
Sylvain Gugger committed
2753

2754
2755
    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
        machines) main process, i.e. whether `args.local_process_index` is 0.
        """
        local_index = self.args.local_process_index
        return local_index == 0
Lysandre Debut's avatar
Lysandre Debut committed
2760

2761
2762
    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be `True` for one process).
        """
        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
        # process index.
        global_rank = smp.rank() if is_sagemaker_mp_enabled() else self.args.process_index
        return global_rank == 0
Julien Chaumond's avatar
Julien Chaumond committed
2772

Sylvain Gugger's avatar
Sylvain Gugger committed
2773
    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Will save the model, so you can reload it using `from_pretrained()`.

        Will only save from the main process.

        Args:
            output_dir (`str`, *optional*):
                Directory to save to. Defaults to `self.args.output_dir`.
            _internal_call (`bool`, *optional*, defaults to `False`):
                `True` when the Trainer saves on its own. Only calls with `False` (i.e. user calls) push to the
                Hub when `args.push_to_hub` is set.
        """

        if output_dir is None:
            output_dir = self.args.output_dir

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif is_sagemaker_mp_enabled():
            # Calling the state_dict needs to be done on the wrapped model and on all processes.
            os.makedirs(output_dir, exist_ok=True)
            state_dict = self.model_wrapped.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
            if IS_SAGEMAKER_MP_POST_1_10:
                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
                Path(os.path.join(output_dir, "user_content.pt")).touch()
        elif (
            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
            or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
            or self.fsdp is not None
        ):
            # state_dict() is called on every process, before the should_save check.
            state_dict = self.model.state_dict()

            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
        elif self.deepspeed:
            # this takes care of everything as long as we aren't under zero3
            if self.args.should_save:
                self._save(output_dir)

            if is_deepspeed_zero3_enabled():
                # It's too complicated to try to override different places where the weights dump gets
                # saved, so since under zero3 the file is bogus, simply delete it. The user should
                # either use the deepspeed checkpoint to resume or, to recover full weights, use
                # zero_to_fp32.py stored in the checkpoint.
                if self.args.should_save:
                    # `_save` writes either format depending on `args.save_safetensors`, so remove both.
                    # A leftover bogus safetensors file could otherwise be picked up instead of the
                    # weights saved by `save_16bit_model` below.
                    for weights_name in (WEIGHTS_NAME, SAFE_WEIGHTS_NAME):
                        file = os.path.join(output_dir, weights_name)
                        if os.path.isfile(file):
                            # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
                            os.remove(file)

                # now save the real model if stage3_gather_16bit_weights_on_model_save=True
                # if false it will not be saved.
                # This must be called on all ranks
                if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME):
                    logger.warning(
                        "deepspeed.save_16bit_model didn't save the model, since"
                        " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                        " zero_to_fp32.py to recover weights"
                    )
                    self.deepspeed.save_checkpoint(output_dir)

        elif self.args.should_save:
            self._save(output_dir)

        # Push to the Hub when `save_model` is called by the user.
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")

2837
2838
    def _save_tpu(self, output_dir: Optional[str] = None):
        """
        Save the model, tokenizer and training arguments to `output_dir` when running on TPU (XLA).

        The master ordinal creates the directory and stores the training arguments. All ordinals then meet at
        an XLA rendezvous and go through the saving calls, which use `xm.save` and take `args.should_save` as
        `is_main_process`.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        logger.info(f"Saving model checkpoint to {output_dir}")

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`.
        # Synchronize every ordinal before writing weights (the directory was created on the master only).
        xm.rendezvous("saving_checkpoint")
        model = self.model
        if isinstance(model, PreTrainedModel):
            model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
        else:
            inner_model = unwrap_model(model)
            if isinstance(inner_model, PreTrainedModel):
                inner_model.save_pretrained(
                    output_dir,
                    is_main_process=self.args.should_save,
                    state_dict=model.state_dict(),
                    save_function=xm.save,
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                weights = model.state_dict()
                xm.save(weights, os.path.join(output_dir, WEIGHTS_NAME))

        if self.tokenizer is not None and self.args.should_save:
            self.tokenizer.save_pretrained(output_dir)
2864

2865
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """
        Write the model (and its configuration, via `save_pretrained()` when available), the tokenizer and the
        training arguments to `output_dir`.

        Args:
            output_dir (`str`, *optional*):
                Target directory. Defaults to `self.args.output_dir` and is created if missing.
            state_dict (*optional*):
                State dict to save instead of the model's own.

        Whoever calls this is the process zero, so no rank check is performed here.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`.
        if isinstance(self.model, PreTrainedModel):
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )
        else:
            if state_dict is None:
                state_dict = self.model.state_dict()

            inner_model = unwrap_model(self.model)
            if isinstance(inner_model, PreTrainedModel):
                inner_model.save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if not self.args.save_safetensors:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
                else:
                    safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
2896

2897
    def store_flos(self):
        """
        Add the floating-point operations counted since the last call (`self.current_flos`) to
        `self.state.total_flos`, then reset the counter to zero.

        In distributed mode, the per-process counts are combined with `distributed_broadcast_scalars(...).sum()`
        before being added.
        """
        if self.args.parallel_mode != ParallelMode.DISTRIBUTED:
            flos_to_add = self.current_flos
        else:
            flos_to_add = distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
        self.state.total_flos += flos_to_add
        self.current_flos = 0
Julien Chaumond's avatar
Julien Chaumond committed
2907

2908
2909
2910
    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        """
        List the checkpoint directories found in `output_dir`, oldest first.

        Args:
            output_dir (`str`, *optional*):
                Directory to search. Defaults to `self.args.output_dir`.
            checkpoint_prefix (`str`, *optional*, defaults to `PREFIX_CHECKPOINT_DIR`):
                Checkpoint folders are matched by the glob `f"{checkpoint_prefix}-*"`.
            use_mtime (`bool`, *optional*, defaults to `False`):
                Order by modification time instead of by the step number parsed from the folder name. When ordering
                by step, folders whose name has no `-<digits>` suffix after the prefix are skipped.

        Returns:
            `List[str]`: Checkpoint paths in ascending order. If the best model checkpoint is among them and is not
            already one of the last two, it is moved to the second-to-last position. Callers that delete from the
            front of the list therefore keep it.
        """
        if output_dir is None:
            output_dir = self.args.output_dir

        # Folder names look like "<prefix>-<step>"; capture the step number.
        step_pattern = re.compile(f".*{re.escape(checkpoint_prefix)}-([0-9]+)")
        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]

        ordering_and_checkpoint_path = []
        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = step_pattern.match(path)
                if regex_match is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.group(1)), path))

        checkpoints_sorted = [path for _, path in sorted(ordering_and_checkpoint_path)]

        # Make sure we don't delete the best model. The best checkpoint may not be in this listing (e.g. it was
        # saved to another directory or already removed). In that case there is nothing to protect, and
        # `list.index` would raise `ValueError`.
        if self.state.best_model_checkpoint is not None:
            best_model_path = str(Path(self.state.best_model_checkpoint))
            if best_model_path in checkpoints_sorted:
                best_model_index = checkpoints_sorted.index(best_model_path)
                for i in range(best_model_index, len(checkpoints_sorted) - 2):
                    checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]

        return checkpoints_sorted

2932
    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
        """
        Delete the oldest checkpoints so that at most `args.save_total_limit` remain.

        Does nothing when `args.save_total_limit` is `None` or not positive. Deletion removes entries from the
        front of the list returned by `_sorted_checkpoints`.

        Args:
            use_mtime (`bool`, *optional*, defaults to `False`):
                Order checkpoints by modification time instead of by step number.
            output_dir (`str`, *optional*):
                Directory holding the checkpoints, forwarded to `_sorted_checkpoints`.
        """
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint,
        # which we don't want since it is needed to allow resuming. In that case keep one extra checkpoint.
        save_total_limit = self.args.save_total_limit
        best_model_checkpoint = self.state.best_model_checkpoint
        if (
            best_model_checkpoint is not None
            and save_total_limit == 1
            # Compare normalized paths, the same form `_sorted_checkpoints` returns: "./out/checkpoint-2" and
            # "out/checkpoint-2" name the same checkpoint.
            and checkpoints_sorted[-1] != str(Path(best_model_checkpoint))
        ):
            save_total_limit = 2

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
            shutil.rmtree(checkpoint, ignore_errors=True)
Julien Chaumond's avatar
Julien Chaumond committed
2956

2957
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Evaluate the model and return the resulting metrics.

        Metrics other than the loss come from the `compute_metrics` function passed at init, because they are
        task-dependent. Subclass and override this method to customize evaluation.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Dataset to use instead of `self.eval_dataset`. For a [`~datasets.Dataset`], columns not accepted by
                `model.forward()` are dropped automatically. It must implement `__len__`.
            ignore_keys (`List[str]`, *optional*):
                Keys of the model output (when it is a dictionary) to leave out when gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                Prefix prepended to every metric name. For example, `"bleu"` becomes `"eval_bleu"` with the default.

        Returns:
            A dictionary with the evaluation loss, any metrics computed from the predictions, and the epoch number
            taken from the training state.
        """
        # Memory tracking must begin before any other work is done.
        self._memory_tracker.start()

        dataloader = self.get_eval_dataloader(eval_dataset)
        start_time = time.time()

        if self.args.use_legacy_prediction_loop:
            run_loop = self.prediction_loop
        else:
            run_loop = self.evaluation_loop

        # Without `compute_metrics` the predictions are never used, so only the loss is requested.
        # Passing `None` lets the loop fall back to `self.args.prediction_loss_only`.
        loss_only = None if self.compute_metrics is not None else True
        output = run_loop(
            dataloader,
            description="Evaluation",
            prediction_loss_only=loss_only,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )
        metrics = output.metrics

        global_batch_size = self.args.eval_batch_size * self.args.world_size
        jit_time_key = f"{metric_key_prefix}_jit_compilation_time"
        if jit_time_key in metrics:
            # Move the start forward by the JIT compilation time, excluding it from the timing given to speed_metrics.
            start_time += metrics[jit_time_key]
        metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / global_batch_size),
            )
        )

        self.log(metrics)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)

        self._memory_tracker.stop_and_update_metrics(metrics)

        return metrics

3028
    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Run the model on `test_dataset` and return its predictions, plus metrics when labels are available.

        If the test dataset contains labels, metrics are computed as in `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to predict on. For a `datasets.Dataset`, columns not accepted by `model.forward()` are
                removed automatically. It has to implement `__len__`.
            ignore_keys (`List[str]`, *optional*):
                Keys of the model output (when it is a dictionary) to leave out when gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                Prefix prepended to every metric name. For example, `"bleu"` becomes `"test_bleu"` with the default.

        <Tip>

        When predictions or labels have different sequence lengths (for instance with dynamic padding in a token
        classification task), they are padded on the right with -100 so they can be concatenated into one array.

        </Tip>

        Returns: *NamedTuple* A namedtuple with the following keys:

            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
        """
        # Memory tracking must begin before any other work is done.
        self._memory_tracker.start()

        dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()

        if self.args.use_legacy_prediction_loop:
            run_loop = self.prediction_loop
        else:
            run_loop = self.evaluation_loop
        output = run_loop(
            dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )
        metrics = output.metrics

        global_batch_size = self.args.eval_batch_size * self.args.world_size
        jit_time_key = f"{metric_key_prefix}_jit_compilation_time"
        if jit_time_key in metrics:
            # Move the start forward by the JIT compilation time, excluding it from the timing given to speed_metrics.
            start_time += metrics[jit_time_key]
        metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / global_batch_size),
            )
        )

        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, metrics)
        self._memory_tracker.stop_and_update_metrics(metrics)

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=metrics)
Julien Chaumond's avatar
Julien Chaumond committed
3089

3090
    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.

        Args:
            dataloader (`DataLoader`):
                The dataloader to iterate over.
            description (`str`):
                Label used in the log header (e.g. `"Evaluation"` or `"Prediction"`).
            prediction_loss_only (`bool`, *optional*):
                Whether to only compute the loss. Falls back to `args.prediction_loss_only` when `None`.
            ignore_keys (`List[str]`, *optional*):
                Keys of the model output (when it is a dictionary) to ignore when gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                Prefix added (followed by `_`) to every metric name that does not already start with it.

        Returns:
            `EvalLoopOutput`: the predictions, label ids, metrics and number of samples. Predictions and labels are
            truncated to the number of samples and are `None` when nothing was gathered.
        """
        args = self.args

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train init deepspeed here
        if args.deepspeed and not self.deepspeed:
            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(
                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        # Key of the model's main input, used when `args.include_inputs_for_metrics` is set. Models without a
        # `main_input_name` attribute fall back to "input_ids" (the key that used to be hard-coded here).
        main_input_name = getattr(self.model, "main_input_name", "input_ids")

        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        inputs_host = None

        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        all_inputs = None
        # Will be useful when we have an iterable dataset so don't know its length.

        observed_num_examples = 0
        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None

            if is_torch_tpu_available():
                xm.mark_step()

            # Update containers on host
            if loss is not None:
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if labels is not None:
                labels = self._pad_across_processes(labels)
            if inputs_decode is not None:
                inputs_decode = self._pad_across_processes(inputs_decode)
                inputs_decode = self._nested_gather(inputs_decode)
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            if logits is not None:
                logits = self._pad_across_processes(logits)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                logits = self._nested_gather(logits)
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels = self._nested_gather(labels)
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                if preds_host is not None:
                    logits = nested_numpify(preds_host)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
                if inputs_host is not None:
                    inputs_decode = nested_numpify(inputs_host)
                    all_inputs = (
                        inputs_decode
                        if all_inputs is None
                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                    )
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                    )

                # Set back to None to begin a new accumulation
                losses_host, preds_host, inputs_host, labels_host = None, None, None, None

        # This loop sets `_past` only when `args.past_index >= 0`. Testing the truthiness of `past_index` instead
        # would skip the cleanup for a valid index of 0.
        if args.past_index >= 0 and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
        if inputs_host is not None:
            inputs_decode = nested_numpify(inputs_host)
            all_inputs = (
                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
            )
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples
        if num_samples == 0 and observed_num_examples > 0:
            num_samples = observed_num_examples

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
        # samplers has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
        if all_inputs is not None:
            all_inputs = nested_truncate(all_inputs, num_samples)

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
        if hasattr(self, "jit_compilation_time"):
            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
3301

3302
    def _nested_gather(self, tensors, name=None):
        """
        Gather `tensors` (a tensor or a list/tuple of nested tensors) from all processes.

        This method does no numpy conversion itself; the evaluation loop calls `nested_numpify` on the result
        separately. On TPU it uses `nested_xla_mesh_reduce`, under SageMaker Model Parallel `smp_gather`, and in
        `ParallelMode.DISTRIBUTED` `distributed_concat`. Otherwise `tensors` is returned unchanged.

        Args:
            tensors: The tensor(s) to gather. `None` is passed through.
            name (`str`, *optional*):
                Name passed to `nested_xla_mesh_reduce` on TPU. Defaults to `"nested_gather"`; unused elsewhere.

        Returns:
            The gathered tensor(s), or `None` when `tensors` is `None`.
        """
        if tensors is None:
            return None
        if is_torch_tpu_available():
            if name is None:
                name = "nested_gather"
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            tensors = distributed_concat(tensors)
        return tensors
3318

3319
3320
3321
3322
3323
3324
3325
3326
3327
3328
3329
3330
3331
3332
3333
3334
3335
3336
3337
3338
3339
3340
    # Adapted from Accelerate.
    def _pad_across_processes(self, tensor, pad_index=-100):
        """
        Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
        they can safely be gathered.

        Only dimension 1 is padded, on the right, with `pad_index`. Tensors with fewer than 2 dimensions are returned
        unchanged. The target size is the maximum dimension-1 size gathered across processes via `_nested_gather`.

        Args:
            tensor:
                A `torch.Tensor`, or a (possibly nested) list/tuple/namedtuple/dict of tensors.
            pad_index (`int`, *optional*, defaults to -100):
                Value written into the padded positions.

        Returns:
            The same structure (same container types) with each tensor padded.

        Raises:
            TypeError: If a leaf of the structure is not a `torch.Tensor`.
        """
        if isinstance(tensor, (list, tuple)):
            padded = [self._pad_across_processes(t, pad_index=pad_index) for t in tensor]
            # Namedtuples take their fields as positional arguments, whereas list/tuple take a single iterable.
            if hasattr(tensor, "_fields"):
                return type(tensor)(*padded)
            return type(tensor)(padded)
        elif isinstance(tensor, dict):
            return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )

        if len(tensor.shape) < 2:
            return tensor
        # Gather all sizes
        size = torch.tensor(tensor.shape, device=tensor.device)[None]
        sizes = self._nested_gather(size).cpu()

        max_size = max(s[1] for s in sizes)
        # When extracting XLA graphs for compilation, max_size is 0,
        # so use inequality to avoid errors.
        if tensor.shape[1] >= max_size:
            return tensor

        # Then pad to the maximum size
        old_size = tensor.shape
        new_size = list(old_size)
        new_size[1] = max_size
        new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
        new_tensor[:, : old_size[1]] = tensor
        return new_tensor
3353

3354
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Run `model` on one batch of `inputs` under `torch.no_grad()` and return its loss, logits and labels.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions. When not given, `self.model.config.keys_to_ignore_at_inference` is used if the
                model has a config, otherwise no key is ignored.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional). When `prediction_loss_only` is `True`, logits and labels are
            `None`.
        """
        # Both flags are computed from the raw inputs, before `_prepare_inputs` runs.
        has_labels = len(self.label_names) > 0 and all(inputs.get(k) is not None for k in self.label_names)
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = bool(len(self.label_names) == 0 and return_loss)
        wants_loss = has_labels or loss_without_labels

        inputs = self._prepare_inputs(inputs)

        if ignore_keys is None:
            ignore_keys = (
                getattr(self.model.config, "keys_to_ignore_at_inference", [])
                if hasattr(self.model, "config")
                else []
            )

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        labels = None
        if wants_loss:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if wants_loss:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]
                    loss = loss_mb.reduce_mean().detach().cpu()
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                logits = smp_nested_concat(logits_mb)
            elif wants_loss:
                with self.compute_loss_context_manager():
                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                loss = loss.mean().detach()
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                else:
                    # The first element of a non-dict output is the loss, so skip it.
                    logits = outputs[1:]
            else:
                loss = None
                with self.compute_loss_context_manager():
                    outputs = model(**inputs)
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                else:
                    logits = outputs
                # TODO: this needs to be fixed and made cleaner later.
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]
        return (loss, logits, labels)
3458
3459
3460

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        Return the number of floating point operations for one backward + forward pass on `inputs`.

        The count is delegated to the model's own `floating_point_ops` method, which [`PreTrainedModel`] provides.
        When the model has no such method, 0 is returned. Implement it on your model, or subclass and override this
        method, to get a real count for other models.

        Args:
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            `int`: The number of floating-point operations.
        """
        if not hasattr(self.model, "floating_point_ops"):
            return 0
        return self.model.floating_point_ops(inputs)
3476

3477
    def init_git_repo(self, at_init: bool = False):
        """
        Create the Hub repository if needed, clone it into `self.args.output_dir` and store it as `self.repo`.

        The repo name is `self.args.hub_model_id` when set, otherwise the name of the output directory. A name
        without a `/` is expanded with `get_full_repo_name`. Non-main processes return immediately.

        Args:
            at_init (`bool`, *optional*, defaults to `False`):
                Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
                `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
                out.
        """
        if not self.is_world_process_zero():
            return

        args = self.args
        repo_name = Path(args.output_dir).absolute().name if args.hub_model_id is None else args.hub_model_id
        if "/" not in repo_name:
            repo_name = get_full_repo_name(repo_name, token=args.hub_token)

        # Make sure the repo exists.
        create_repo(repo_name, token=args.hub_token, private=args.hub_private_repo, exist_ok=True)
        try:
            self.repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
        except EnvironmentError:
            if not (args.overwrite_output_dir and at_init):
                raise
            # Try again after wiping output_dir
            shutil.rmtree(args.output_dir)
            self.repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

        self.repo.git_pull()

        # By default, ignore the checkpoint folders
        gitignore_path = os.path.join(args.output_dir, ".gitignore")
        if args.hub_strategy != HubStrategy.ALL_CHECKPOINTS and not os.path.exists(gitignore_path):
            with open(gitignore_path, "w", encoding="utf-8") as writer:
                writer.writelines(["checkpoint-*/"])

        # Add "*.sagemaker" to .gitignore if using SageMaker
        if os.environ.get("SM_TRAINING_ENV"):
            self._add_sm_patterns_to_gitignore()

        self.push_in_progress = None

Sylvain Gugger's avatar
Sylvain Gugger committed
3524
3525
3526
3527
    def create_model_card(
        self,
        language: Optional[str] = None,
        license: Optional[str] = None,
        tags: Union[str, List[str], None] = None,
        model_name: Optional[str] = None,
        finetuned_from: Optional[str] = None,
        tasks: Union[str, List[str], None] = None,
        dataset_tags: Union[str, List[str], None] = None,
        dataset: Union[str, List[str], None] = None,
        dataset_args: Union[str, List[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        The card is written, UTF-8 encoded, to `README.md` in `self.args.output_dir`. Only the main process writes it.

        Args:
            language (`str`, *optional*):
                The language of the model (if applicable)
            license (`str`, *optional*):
                The license of the model. Will default to the license of the pretrained model used, if the original
                model given to the `Trainer` comes from a repo on the Hub.
            tags (`str` or `List[str]`, *optional*):
                Some tags to be included in the metadata of the model card.
            model_name (`str`, *optional*):
                The name of the model.
            finetuned_from (`str`, *optional*):
                The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
                of the original model given to the `Trainer` (if it comes from the Hub).
            tasks (`str` or `List[str]`, *optional*):
                One or several task identifiers, to be included in the metadata of the model card.
            dataset_tags (`str` or `List[str]`, *optional*):
                One or several dataset tags, to be included in the metadata of the model card.
            dataset (`str` or `List[str]`, *optional*):
                One or several dataset identifiers, to be included in the metadata of the model card.
            dataset_args (`str` or `List[str]`, *optional*):
                One or several dataset arguments, to be included in the metadata of the model card.
        """
        if not self.is_world_process_zero():
            return

        training_summary = TrainingSummary.from_trainer(
            self,
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
        )
        model_card = training_summary.to_model_card()
        # Model cards may contain non-ASCII text (language names, emoji, dataset names). Write them as UTF-8
        # explicitly instead of relying on the platform's default encoding, which may not be able to encode them.
        with open(os.path.join(self.args.output_dir, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card)

3580
3581
3582
3583
3584
3585
3586
3587
3588
3589
    def _push_from_checkpoint(self, checkpoint_folder):
        # Only push from one node.
        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
            return
        # If we haven't finished the last push, we don't do this one.
        if self.push_in_progress is not None and not self.push_in_progress.is_done:
            return

        output_dir = self.args.output_dir
        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
3590
        modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
3591
3592
3593
3594
3595
3596
3597
3598
3599
3600
3601
3602
3603
3604
3605
3606
3607
3608
3609
3610
3611
3612
3613
        for modeling_file in modeling_files:
            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Same for the training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        try:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Temporarily move the checkpoint just saved for the push
                tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
                # subfolder.
                if os.path.isdir(tmp_checkpoint):
                    shutil.rmtree(tmp_checkpoint)
                shutil.move(checkpoint_folder, tmp_checkpoint)

            if self.args.save_strategy == IntervalStrategy.STEPS:
                commit_message = f"Training in progress, step {self.state.global_step}"
            else:
                commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
3614
3615
3616
            _, self.push_in_progress = self.repo.push_to_hub(
                commit_message=commit_message, blocking=False, auto_lfs_prune=True
            )
3617
3618
3619
3620
3621
3622
        finally:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Move back the checkpoint to its place
                shutil.move(tmp_checkpoint, checkpoint_folder)

    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.

        Parameters:
            commit_message (`str`, *optional*, defaults to `"End of training"`):
                Message to commit while pushing.
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the `git push` has finished.
            kwargs:
                Additional keyword arguments passed along to [`~Trainer.create_model_card`].

        Returns:
            If `blocking=True`, the url of the commit of your model in the given repository. If `blocking=False`, a
            tuple with the url of the commit and an object to track the progress of the commit. Processes other than
            the main one return `None`.
        """
        # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
        # it might fail.
        if not hasattr(self, "repo"):
            self.init_git_repo()

        model_name = kwargs.pop("model_name", None)
        if model_name is None and self.args.should_save:
            if self.args.hub_model_id is None:
                model_name = Path(self.args.output_dir).name
            else:
                model_name = self.args.hub_model_id.split("/")[-1]

        # Needs to be executed on all processes for TPU training, but will only save on the process determined by
        # self.args.should_save.
        self.save_model(_internal_call=True)

        # Only push from one node.
        if not self.is_world_process_zero():
            return

        # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
            self.push_in_progress._process.kill()
            self.push_in_progress = None

        git_head_commit_url = self.repo.push_to_hub(
            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
        )
        # push separately the model card to be independent from the rest of the model
        if self.args.should_save:
            self.create_model_card(model_name=model_name, **kwargs)
            try:
                self.repo.push_to_hub(
                    commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
                )
            except EnvironmentError as exc:
                logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")

        return git_head_commit_url
Sylvain Gugger's avatar
Sylvain Gugger committed
3677

3678
3679
3680
3681
3682
3683
3684
3685
3686
3687
3688
    #
    # Deprecated code
    #

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels. This is the legacy loop that lives in the deprecated-code section. It
        accumulates per-step results on device and hands them to `DistributedTensorGatherer`s every
        `eval_accumulation_steps` steps, and once more at the end.

        Raises:
            `ValueError`: if `dataloader` does not implement a working `__len__`.
        """
        args = self.args

        if not has_length(dataloader):
            raise ValueError("dataloader must implement a working __len__")

        if prediction_loss_only is None:
            prediction_loss_only = args.prediction_loss_only

        # if eval is called w/o train init deepspeed here
        if args.deepspeed and not self.deepspeed:
            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
            # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
            # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
            # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
            deepspeed_engine.optimizer.optimizer = None
            deepspeed_engine.lr_scheduler = None

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train and (args.fp16_full_eval or args.bf16_full_eval):
            full_eval_dtype = torch.float16 if args.fp16_full_eval else torch.bfloat16
            model = model.to(dtype=full_eval_dtype, device=args.device)

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info(f"***** Running {description} *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Batch size = {batch_size}")
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
        inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = max(1, args.world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
            # a batch size to the sampler)
            make_multiple_of = None
            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
                make_multiple_of = dataloader.sampler.batch_size

            def new_gatherer():
                # One gatherer per kind of output (predictions, labels, inputs), all sharing the same geometry.
                return DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

            preds_gatherer = new_gatherer()
            labels_gatherer = new_gatherer()
            inputs_gatherer = new_gatherer()

        def extend(host, new):
            # Start a new accumulation with `new`, or append it to the existing one (padding with -100).
            return new if host is None else nested_concat(host, new, padding_index=-100)

        def flush_hosts():
            # Gather the accumulated tensors from all processes and hand them (as numpy) to the gatherers.
            eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
            if not prediction_loss_only:
                preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
                inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            if loss is not None:
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds_host = extend(preds_host, logits)
            if labels is not None:
                labels_host = extend(labels_host, labels)
            if inputs_decode is not None:
                inputs_host = extend(inputs_host, inputs_decode)
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                flush_hosts()
                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host, inputs_host = None, None, None, None

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        flush_hosts()

        eval_loss = eval_losses_gatherer.finalize()
        if prediction_loss_only:
            preds = label_ids = inputs_ids = None
        else:
            preds = preds_gatherer.finalize()
            label_ids = labels_gatherer.finalize()
            inputs_ids = inputs_gatherer.finalize()

        if self.compute_metrics is None or preds is None or label_ids is None:
            metrics = {}
        elif args.include_inputs_for_metrics:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids))
        else:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        key_prefix = f"{metric_key_prefix}_"
        for key in list(metrics.keys()):
            if not key.startswith(key_prefix):
                metrics[key_prefix + key] = metrics.pop(key)

        return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
3827
3828
3829
3830
3831
3832
3833
3834
3835
3836
3837
3838

    def _gather_and_numpify(self, tensors, name):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
3839
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
3840
3841
3842
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)
3843
3844
3845
3846
3847
3848
3849
3850
3851
3852
3853
3854
3855
3856
3857
3858
3859
3860
3861
3862
3863
3864
3865
3866
3867
3868
3869
3870
3871
3872
3873
3874
3875
3876
3877
3878
3879
3880
3881

    def _add_sm_patterns_to_gitignore(self) -> None:
        """Add SageMaker Checkpointing patterns to .gitignore file."""
        # Make sure we only do this on the main process
        if not self.is_world_process_zero():
            return

        patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]

        # Get current .gitignore content
        if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
            with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f:
                current_content = f.read()
        else:
            current_content = ""

        # Add the patterns to .gitignore
        content = current_content
        for pattern in patterns:
            if pattern not in content:
                if content.endswith("\n"):
                    content += pattern
                else:
                    content += f"\n{pattern}"

        # Write the .gitignore file if it has changed
        if content != current_content:
            with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
                logger.debug(f"Writing .gitignore file. Content: {content}")
                f.write(content)

        self.repo.git_add(".gitignore")

        # avoid race condition with git status
        time.sleep(0.5)

        if not self.repo.is_repo_clean():
            self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
            self.repo.git_push()