# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
"""

import contextlib
import functools
import glob
import inspect
import math
import os
import random
import re
import shutil
import sys
import time
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

from tqdm.auto import tqdm

# Integrations must be imported before ML frameworks:
# isort: off
from .integrations import (
    default_hp_search_backend,
    get_reporting_integration_callbacks,
    hp_params,
    is_fairscale_available,
    is_optuna_available,
    is_ray_tune_available,
    is_sigopt_available,
    is_wandb_available,
    run_hp_search_optuna,
    run_hp_search_ray,
    run_hp_search_sigopt,
    run_hp_search_wandb,
)

# isort: on

import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import Repository, create_repo
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
from .optimization import Adafactor, get_scheduler
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
    DistributedLengthGroupedSampler,
    DistributedSamplerWithLoop,
    DistributedTensorGatherer,
    IterableDatasetShard,
    LabelSmoother,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    ShardSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    find_batch_size,
    get_model_param_count,
    get_module_class_from_name,
    get_parameter_names,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_truncate,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalLoopOutput,
    EvalPrediction,
    FSDPOption,
    HPSearchBackend,
    HubStrategy,
    IntervalStrategy,
    PredictionOutput,
    RemoveColumnsCollator,
    ShardedDDPOption,
    TrainerMemoryTracker,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    denumpify_detensorize,
    enable_full_determinism,
    find_executable_batch_size,
    get_last_checkpoint,
    has_length,
    number_of_arguments,
    seed_worker,
    set_seed,
    speed_metrics,
)
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import (
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    can_return_loss,
    find_labels,
    get_full_repo_name,
    is_accelerate_available,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
    is_ipex_available,
    is_peft_available,
    is_safetensors_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_torch_compile_available,
    is_torch_neuroncore_available,
    is_torch_tpu_available,
    logging,
    strtobool,
)
from .utils.generic import ContextManagers


_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10

DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

if is_apex_available():
    from apex import amp

if is_datasets_available():
    import datasets

if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_fairscale_available():
    dep_version_check("fairscale")
    import fairscale
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.nn.wrap import auto_wrap
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler


if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
    IS_SAGEMAKER_MP_POST_1_10 = False


if is_safetensors_available():
    import safetensors.torch


if is_peft_available():
    from peft import PeftModel


skip_first_batches = None
if is_accelerate_available():
    from accelerate import __version__ as accelerate_version

    if version.parse(accelerate_version) >= version.parse("0.16"):
        from accelerate import skip_first_batches

    from accelerate import Accelerator
    from accelerate.utils import DistributedDataParallelKwargs


if TYPE_CHECKING:
    import optuna

logger = logging.get_logger(__name__)


# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.

    Args:
        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.

            <Tip>

            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
            models.

            </Tip>

        args ([`TrainingArguments`], *optional*):
            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
        data_collator (`DataCollator`, *optional*):
            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
            [`DataCollatorWithPadding`] otherwise.
        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
            The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
            `model.forward()` method are automatically removed.

            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
            distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
            sets the seed of the RNGs used.
        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]], *optional*):
            The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
            `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
            dataset prepending the dictionary key to the metric name.
        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
            interrupted training or reuse the fine-tuned model.
        model_init (`Callable[[], PreTrainedModel]`, *optional*):
            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
            from a new instance of the model as given by this function.

            The function may have zero arguments, or a single one containing the optuna/Ray Tune/SigOpt trial object,
            to be able to choose different architectures according to hyperparameters (such as layer count, sizes of
            inner layers, dropout probabilities etc.).
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
            The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
            a dictionary mapping metric names to metric values.
        callbacks (List of [`TrainerCallback`], *optional*):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed in [here](callback).

            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.

    Important attributes:

        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
          subclass.
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
          to `False` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False`.
        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
          in `train`)
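
    Example (a minimal usage sketch; `my_model`, `my_train_dataset` and `my_eval_dataset` are placeholders for your
    own objects, not defined in this module):

    ```python
    from transformers import Trainer, TrainingArguments

    training_args = TrainingArguments(output_dir="my_output_dir", num_train_epochs=1)
    trainer = Trainer(
        model=my_model,
        args=training_args,
        train_dataset=my_train_dataset,
        eval_dataset=my_eval_dataset,
    )
    trainer.train()
    metrics = trainer.evaluate()
    ```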
    """

    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        if args is None:
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
        self.args = args
        # Seed must be set before instantiating the model when using `model_init`.
        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
        self.hp_name = None
        self.is_in_train = False

        self.create_accelerator_and_postprocess()

        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

        # set the correct log level depending on the node
        log_level = args.get_process_log_level()
        logging.set_verbosity(log_level)

        # force device and distributed setup init explicitly
        args._setup_devices

        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
                    " overwrite your model when calling the `train` method. This will become a fatal error in the next"
                    " release.",
                    FutureWarning,
                )
            self.model_init = model_init

        if model.__class__.__name__ in MODEL_MAPPING_NAMES:
            raise ValueError(
                f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
                "computes hidden states and does not accept any labels. You should choose a model with a head "
                "suitable for your task like any of the `AutoModelForXxx` listed at "
                "https://huggingface.co/docs/transformers/model_doc/auto."
            )

        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = False

        if getattr(model, "hf_device_map", None) is not None:
            devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
            if len(devices) > 1:
                self.is_model_parallel = True
            else:
                self.is_model_parallel = self.args.device != torch.device(devices[0])

            # warn users
            logger.info(
                "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
                " to `True` to avoid any unexpected behavior such as device placement mismatching."
            )

        # At this stage the model is already loaded
        if getattr(model, "is_quantized", False):
            if getattr(model, "_is_quantized_training_enabled", False):
                logger.info(
                    "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
                    " inside the model such as adapters using `peft` library and freeze the model weights. Please"
                    " check "
                    " the examples in https://github.com/huggingface/peft for more details."
                )
            else:
                raise ValueError(
                    "The model you want to train is loaded in 8-bit precision. If you want to fine-tune an 8-bit"
                    " model, please make sure that you have installed `bitsandbytes>=0.37.0`."
                )

        # Setup Sharded DDP training
        self.sharded_ddp = None
        if len(args.sharded_ddp) > 0:
            if self.is_deepspeed_enabled:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if len(args.fsdp) > 0:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
                )
            if args.parallel_mode != ParallelMode.DISTRIBUTED:
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
                raise ImportError(
                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
                )
            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.SIMPLE
            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

        self.fsdp = None
        if len(args.fsdp) > 0:
            if self.is_deepspeed_enabled:
                raise ValueError(
                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
                raise ValueError("Using fsdp only works in distributed training.")

            # dep_version_check("torch>=1.12.0")
            # Would have to update setup.py with torch>=1.12.0
            # which isn't ideal given that it would force people not using FSDP to also use torch>=1.12.0
            # below is the current alternative.
            if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
                raise ValueError("FSDP requires PyTorch >= 1.12.0")

            from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy

            if FSDPOption.FULL_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.FULL_SHARD
            elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
                self.fsdp = ShardingStrategy.SHARD_GRAD_OP
            elif FSDPOption.NO_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.NO_SHARD

            self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
            if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get(
                "backward_prefetch", []
            ):
                self.backward_prefetch = BackwardPrefetch.BACKWARD_POST

            self.forward_prefetch = False
            if self.args.fsdp_config.get("forward_prefetch", False):
                self.forward_prefetch = True

            self.limit_all_gathers = False
            if self.args.fsdp_config.get("limit_all_gathers", False):
                self.limit_all_gathers = True

        # one place to sort out whether to place the model on device or not
        # postpone switching model to cuda when:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        #    and we only use deepspeed for training at the moment
        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
        # 4. Sharded DDP - same as MP
        # 5. FSDP - same as MP
        self.place_model_on_device = args.place_model_on_device
        if (
            self.is_model_parallel
            or self.is_deepspeed_enabled
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
            or (self.fsdp is not None)
            or self.is_fsdp_enabled
        ):
            self.place_model_on_device = False

        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer

        if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False):
            self._move_model_to_device(model, args.device)

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
        if self.is_model_parallel:
            self.args._n_gpu = 1

        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
        self.model_wrapped = model
        self.model = model

        self.compute_metrics = compute_metrics
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        if is_torch_tpu_available() and self.optimizer is not None:
            for param in self.model.parameters():
                model_device = param.device
                break
            for param_group in self.optimizer.param_groups:
                if len(param_group["params"]) > 0:
                    optimizer_device = param_group["params"][0].device
                    break
            if model_device != optimizer_device:
                raise ValueError(
                    "The model and the optimizer parameters are not on the same device, which probably means you"
                    " created an optimizer around your model **before** putting on the device and passing it to the"
                    " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                    " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
                )
        if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
            self.optimizer is not None or self.lr_scheduler is not None
        ):
            raise RuntimeError(
                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create clone of distant repo and output directory if needed
        if self.args.push_to_hub:
            self.init_git_repo(at_init=True)
            # In case of pull, we need to make sure every process has the latest.
            if is_torch_tpu_available():
                xm.rendezvous("init git repo")
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
                dist.barrier()

        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)

        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")

        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
            raise ValueError(
                "The train_dataset does not implement __len__, max_steps has to be specified. "
                "The number of steps needs to be known in advance for the learning rate scheduler."
            )

        if (
            train_dataset is not None
            and isinstance(train_dataset, torch.utils.data.IterableDataset)
            and args.group_by_length
        ):
            raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset`")

        self._signature_columns = None

        # Mixed precision setup
        self.use_apex = False
        self.use_cuda_amp = False
        self.use_cpu_amp = False

        # Mixed precision setup for SageMaker Model Parallel
        if is_sagemaker_mp_enabled():
            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
            if args.bf16:
                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")

            if IS_SAGEMAKER_MP_POST_1_10:
                # When there's mismatch between SMP config and trainer argument, use SMP config as truth
                if args.fp16 != smp.state.cfg.fp16:
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                        f"but FP16 provided in trainer argument is {args.fp16}, "
                        f"setting to {smp.state.cfg.fp16}"
                    )
                    args.fp16 = smp.state.cfg.fp16
            else:
                # smp < 1.10 does not support fp16 in trainer.
                if hasattr(smp.state.cfg, "fp16"):
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
                    )

        if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
            if args.half_precision_backend == "auto":
                if args.device == torch.device("cpu"):
                    if args.fp16:
                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                    elif _is_native_cpu_amp_available:
                        args.half_precision_backend = "cpu_amp"
                    else:
                        raise ValueError("Tried to use cpu amp but native cpu amp is not available")
                else:
                    args.half_precision_backend = "cuda_amp"

            logger.info(f"Using {args.half_precision_backend} half precision backend")

        self.do_grad_scaling = False
        if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
            # deepspeed and SageMaker Model Parallel manage their own half precision
            if self.sharded_ddp is not None:
                if args.half_precision_backend == "cuda_amp":
                    self.use_cuda_amp = True
                    self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
                    #  bf16 does not need grad scaling
                    self.do_grad_scaling = self.amp_dtype == torch.float16
                    if self.do_grad_scaling:
                        if self.sharded_ddp is not None:
                            self.scaler = ShardedGradScaler()
                        elif self.fsdp is not None:
                            from torch.distributed.fsdp.sharded_grad_scaler import (
                                ShardedGradScaler as FSDPShardedGradScaler,
                            )

                            self.scaler = FSDPShardedGradScaler()
                        elif is_torch_tpu_available():
                            from torch_xla.amp import GradScaler

                            self.scaler = GradScaler()
                        else:
                            self.scaler = torch.cuda.amp.GradScaler()
                elif args.half_precision_backend == "cpu_amp":
                    self.use_cpu_amp = True
                    self.amp_dtype = torch.bfloat16
            elif args.half_precision_backend == "apex":
                if not is_apex_available():
                    raise ImportError(
                        "Using FP16 with APEX but APEX is not installed, please refer to"
                        " https://www.github.com/nvidia/apex."
                    )
                self.use_apex = True

        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
        if (
            is_sagemaker_mp_enabled()
            and self.use_cuda_amp
            and args.max_grad_norm is not None
            and args.max_grad_norm > 0
        ):
            raise ValueError(
                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
                "along 'max_grad_norm': 0 in your hyperparameters."
            )

        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

        self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
        )

        self.control = TrainerControl()
        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
        default_label_names = find_labels(self.model.__class__)
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        self.can_return_loss = can_return_loss(self.model.__class__)
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

        # Internal variables to keep track of the original batch size
        self._train_batch_size = args.train_batch_size

        # very last
        self._memory_tracker.stop_and_update_metrics()

        # torch.compile
        if args.torch_compile and not is_torch_compile_available():
            raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")

    def add_callback(self, callback):
        """
        Add a callback to the current list of [`~transformer.TrainerCallback`].

        Args:
           callback (`type` or [`~transformer.TrainerCallback`]):
               A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the
               first case, will instantiate a member of that class.
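
        Example (a minimal sketch; `PrinterCallback` is one of the callbacks already provided by this library):

        ```python
        trainer.add_callback(PrinterCallback)  # pass the class, or equivalently an instance: PrinterCallback()
        ```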
        """
        self.callback_handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformer.TrainerCallback`] and returns it.

        If the callback is not found, returns `None` (and no error is raised).

        Args:
           callback (`type` or [`~transformer.TrainerCallback`]):
               A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the
               first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            [`~transformer.TrainerCallback`]: The callback removed, if found.
        """
        return self.callback_handler.pop_callback(callback)

    def remove_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformer.TrainerCallback`].

        Args:
           callback (`type` or [`~transformer.TrainerCallback`]):
               A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the
               first case, will remove the first member of that class found in the list of callbacks.
        """
        self.callback_handler.remove_callback(callback)

    def _move_model_to_device(self, model, device):
        model = model.to(device)
        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
            model.tie_weights()

    def _set_signature_columns_if_needed(self):
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
            # Labels may be named label or label_ids, the default data collator handles that.
            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return dataset
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        if len(ignored_columns) > 0:
            dset_description = "" if description is None else f"in the {description} set"
            logger.info(
                f"The following columns {dset_description} don't have a corresponding argument in "
                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
                f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
                " you can safely ignore this message."
            )

        columns = [k for k in signature_columns if k in dataset.column_names]

        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            dataset.set_format(
                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
            )
            return dataset
        else:
            return dataset.remove_columns(ignored_columns)

    def _get_collator_with_removed_columns(
        self, data_collator: Callable, description: Optional[str] = None
    ) -> Callable:
        """Wrap the data collator in a callable removing unused columns."""
        if not self.args.remove_unused_columns:
            return data_collator
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        remove_columns_collator = RemoveColumnsCollator(
            data_collator=data_collator,
            signature_columns=signature_columns,
            logger=logger,
            description=description,
            model_name=self.model.__class__.__name__,
        )
        return remove_columns_collator

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        generator = None
        if self.args.world_size <= 1:
            generator = torch.Generator()
            # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
            # `args.seed`) if data_seed isn't provided.
            # Further on in this method, we default to `args.seed` instead.
            if self.args.data_seed is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
            else:
                seed = self.args.data_seed
            generator.manual_seed(seed)

        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed

        # Build the sampler.
        if self.args.group_by_length:
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            if self.args.world_size <= 1:
                return LengthGroupedSampler(
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
                    dataset=self.train_dataset,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
                )
            else:
                return DistributedLengthGroupedSampler(
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
                    dataset=self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    seed=seed,
                )

        else:
            if self.args.world_size <= 1:
                return RandomSampler(self.train_dataset, generator=generator)
            elif (
                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
            ):
                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    batch_size=self.args.per_device_train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=seed,
                )
            else:
                return DistributedSampler(
                    self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=seed,
                )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
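
        Example of a subclass override (a minimal sketch, not used by the library itself):

        ```python
        class MyTrainer(Trainer):
            def get_train_dataloader(self) -> DataLoader:
                # skip the sampler logic entirely and always shuffle with the configured collator
                return DataLoader(
                    self.train_dataset,
                    batch_size=self._train_batch_size,
                    shuffle=True,
                    collate_fn=self.data_collator,
                )
        ```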
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        if isinstance(train_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                train_dataset = IterableDatasetShard(
                    train_dataset,
                    batch_size=self._train_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )

            return DataLoader(
                train_dataset,
                batch_size=self._train_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        train_sampler = self._get_train_sampler()

        return DataLoader(
            train_dataset,
            batch_size=self._train_batch_size,
            sampler=train_sampler,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            worker_init_fn=seed_worker,
        )

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
        # Deprecated code
        if self.args.use_legacy_prediction_loop:
            if is_torch_tpu_available():
                return SequentialDistributedSampler(
                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
                )
            elif is_sagemaker_mp_enabled():
                return SequentialDistributedSampler(
                    eval_dataset,
                    num_replicas=smp.dp_size(),
                    rank=smp.dp_rank(),
                    batch_size=self.args.per_device_eval_batch_size,
                )
            elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                return SequentialDistributedSampler(eval_dataset)
            else:
                return SequentialSampler(eval_dataset)

        if self.args.world_size <= 1:
            return SequentialSampler(eval_dataset)
        else:
            return ShardSampler(
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
            )

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                eval_dataset = IterableDatasetShard(
                    eval_dataset,
                    batch_size=self.args.per_device_eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                eval_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`, *optional*):
                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description="test")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")

        if isinstance(test_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                test_dataset = IterableDatasetShard(
                    test_dataset,
                    batch_size=self.args.eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                test_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
Julien Chaumond's avatar
Julien Chaumond committed
1048
            test_dataset,
1049
            sampler=test_sampler,
Julien Chaumond's avatar
Julien Chaumond committed
1050
            batch_size=self.args.eval_batch_size,
1051
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
        `create_scheduler`) in a subclass.
        """
        self.create_optimizer()
        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
            # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
            optimizer = self.optimizer.optimizer
        else:
            optimizer = self.optimizer
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]
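            # Illustrative example (hypothetical parameter names): for `encoder.layer.0.attention.query.weight`,
            # `encoder.layer.0.attention.query.bias` and `encoder.layer.0.LayerNorm.weight`, only the first tensor
            # lands in the weight-decay group above; biases and LayerNorm weights fall into the second, no-decay group.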

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    skipped = 0
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                            logger.info(f"skipped {module}: {skipped/2**20}M params")
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                    logger.info(f"skipped: {skipped/2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer

    @staticmethod
    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
        """
        Returns the optimizer class and optimizer parameters based on the training arguments.

        Args:
            args (`transformers.training_args.TrainingArguments`):
                The training arguments for the training session.

        """

        # parse args.optim_args
        optim_args = {}
        if args.optim_args:
            for mapping in args.optim_args.replace(" ", "").split(","):
                key, value = mapping.split("=")
                optim_args[key] = value
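        # Illustrative example (hypothetical flag value): `--optim_args "momentum_dtype=bfloat16,variance_dtype=float32"`
        # parses to {"momentum_dtype": "bfloat16", "variance_dtype": "float32"}; values stay strings here and are
        # cast by the optimizer-specific branches below.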

        optimizer_kwargs = {"lr": args.learning_rate}

        adam_kwargs = {
            "betas": (args.adam_beta1, args.adam_beta2),
            "eps": args.adam_epsilon,
        }
        if args.optim == OptimizerNames.ADAFACTOR:
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif args.optim == OptimizerNames.ADAMW_HF:
            from .optimization import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
            from torch.optim import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
            if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
                optimizer_kwargs.update({"fused": True})
        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
            try:
                from torch_xla.amp.syncfree import AdamW

                optimizer_cls = AdamW
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
            try:
                from apex.optimizers import FusedAdam

                optimizer_cls = FusedAdam
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
        elif args.optim in [
            OptimizerNames.ADAMW_BNB,
            OptimizerNames.ADAMW_8BIT,
            OptimizerNames.PAGED_ADAMW,
            OptimizerNames.PAGED_ADAMW_8BIT,
            OptimizerNames.LION,
            OptimizerNames.LION_8BIT,
            OptimizerNames.PAGED_LION,
            OptimizerNames.PAGED_LION_8BIT,
        ]:
            try:
                from bitsandbytes.optim import AdamW, Lion

                is_paged = False
                optim_bits = 32
                optimizer_cls = None
                additional_optim_kwargs = adam_kwargs
                if "paged" in args.optim:
                    is_paged = True
                if "8bit" in args.optim:
                    optim_bits = 8
                if "adam" in args.optim:
                    optimizer_cls = AdamW
                elif "lion" in args.optim:
                    optimizer_cls = Lion
                    additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}

                bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
                optimizer_kwargs.update(additional_optim_kwargs)
                optimizer_kwargs.update(bnb_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!")
        elif args.optim == OptimizerNames.ADAMW_BNB:
            try:
                from bitsandbytes.optim import Adam8bit

                optimizer_cls = Adam8bit
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
        elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
            try:
                from torchdistx.optimizers import AnyPrecisionAdamW

                optimizer_cls = AnyPrecisionAdamW
                optimizer_kwargs.update(adam_kwargs)

                # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
                optimizer_kwargs.update(
                    {
                        "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
                        "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
                        "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
                        "compensation_buffer_dtype": getattr(
                            torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
                        ),
                    }
                )
            except ImportError:
                raise ValueError("Please install https://github.com/pytorch/torchdistx")
        elif args.optim == OptimizerNames.SGD:
            optimizer_cls = torch.optim.SGD
        elif args.optim == OptimizerNames.ADAGRAD:
            optimizer_cls = torch.optim.Adagrad
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
        """
        if self.lr_scheduler is None:
            self.lr_scheduler = get_scheduler(
                self.args.lr_scheduler_type,
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )
        return self.lr_scheduler

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
        dataloader.dataset does not exist or has no length, estimates as best it can.
        """
        try:
            dataset = dataloader.dataset
            # Special case for IterableDatasetShard, we need to dig deeper
            if isinstance(dataset, IterableDatasetShard):
                return len(dataloader.dataset.dataset)
            return len(dataloader.dataset)
        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
            return len(dataloader) * self.args.per_device_train_batch_size

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """HP search setup code"""
        self._trial = trial

        if self.hp_search_backend is None or trial is None:
            return
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            params = self.hp_space(trial)
        elif self.hp_search_backend == HPSearchBackend.RAY:
            params = trial
            params.pop("wandb", None)
        elif self.hp_search_backend == HPSearchBackend.SIGOPT:
            params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
        elif self.hp_search_backend == HPSearchBackend.WANDB:
            params = trial

        for key, value in params.items():
            if not hasattr(self.args, key):
                logger.warning(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
                    " `TrainingArguments`."
                )
                continue
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
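        # Illustrative example: a trial suggesting {"learning_rate": 1e-4, "num_train_epochs": 5} overwrites the
        # matching `TrainingArguments` fields above, casting each value to the type of the existing attribute.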
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info(f"Trial: {trial.params}")
        if self.hp_search_backend == HPSearchBackend.SIGOPT:
            logger.info(f"SigOpt Assignments: {trial.assignments}")
        if self.hp_search_backend == HPSearchBackend.WANDB:
            logger.info(f"W&B Sweep parameters: {trial}")
        if self.is_deepspeed_enabled:
            if self.args.deepspeed is None:
                raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
            # Rebuild the deepspeed config to reflect the updated training parameters
            from accelerate.utils import DeepSpeedPlugin

            from transformers.deepspeed import HfTrainerDeepSpeedConfig

            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
            self.args.hf_deepspeed_config.trainer_config_process(self.args)
            self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
        self.create_accelerator_and_postprocess()

    def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
        if self.hp_search_backend is None or trial is None:
            return
        self.objective = self.compute_objective(metrics.copy())
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            import optuna

            trial.report(self.objective, step)
            if trial.should_prune():
                self.callback_handler.on_train_end(self.args, self.state, self.control)
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            from ray import tune

            if self.control.should_save:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

    def _tune_save_checkpoint(self):
        from ray import tune

        if not self.use_tune_checkpoints:
            return
        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
            self.save_model(output_dir, _internal_call=True)
            if self.args.should_save:
                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

    def call_model_init(self, trial=None):
        model_init_argcount = number_of_arguments(self.model_init)
        if model_init_argcount == 0:
            model = self.model_init()
        elif model_init_argcount == 1:
            model = self.model_init(trial)
        else:
            raise RuntimeError("model_init should have 0 or 1 argument.")

        if model is None:
            raise RuntimeError("model_init should not return None.")

        return model

    def torch_jit_model_eval(self, model, dataloader, training=False):
        if not training:
            if dataloader is None:
                logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
                return model
            example_batch = next(iter(dataloader))
            example_batch = self._prepare_inputs(example_batch)
            try:
                jit_model = model.eval()
                with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]):
                    if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"):
                        if isinstance(example_batch, dict):
                            jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
                        else:
                            jit_model = torch.jit.trace(
                                jit_model,
                                example_kwarg_inputs={key: example_batch[key] for key in example_batch},
                                strict=False,
                            )
                    else:
                        jit_inputs = []
                        for key in example_batch:
                            example_tensor = torch.ones_like(example_batch[key])
                            jit_inputs.append(example_tensor)
                        jit_inputs = tuple(jit_inputs)
                        jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
                jit_model = torch.jit.freeze(jit_model)
                with torch.no_grad():
                    jit_model(**example_batch)
                    jit_model(**example_batch)
                model = jit_model
                self.use_cpu_amp = False
                self.use_cuda_amp = False
            except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
                logger.warning(f"failed to use PyTorch jit mode due to: {e}.")

        return model

    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
        if not is_ipex_available():
            raise ImportError(
                "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
                " to https://github.com/intel/intel-extension-for-pytorch."
            )

        import intel_extension_for_pytorch as ipex

        if not training:
            model.eval()
            dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype
            # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings
            model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
        else:
            if not model.training:
                model.train()
            model, self.optimizer = ipex.optimize(
                model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
            )

        return model

    def _wrap_model(self, model, training=True, dataloader=None):
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # train/eval could be run multiple times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization) / 8bit models do not support DDP
        if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
            model = nn.DataParallel(model)

        if self.args.jit_mode_eval:
            start_time = time.time()
            model = self.torch_jit_model_eval(model, dataloader, training)
            self.jit_compilation_time = round(time.time() - start_time, 4)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)
        # Distributed training using PyTorch FSDP
        elif self.fsdp is not None and self.args.fsdp_config["xla"]:
            try:
                from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                from torch_xla.distributed.fsdp import checkpoint_module
                from torch_xla.distributed.fsdp.wrap import (
                    size_based_auto_wrap_policy,
                    transformer_auto_wrap_policy,
                )
            except ImportError:
                raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
            auto_wrap_policy = None
            auto_wrapper_callable = None
            if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                auto_wrap_policy = functools.partial(
                    size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                )
            elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                transformer_cls_to_wrap = set()
                for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                    transformer_cls = get_module_class_from_name(model, layer_class)
                    if transformer_cls is None:
                        raise Exception("Could not find the transformer layer class to wrap in the model.")
                    else:
                        transformer_cls_to_wrap.add(transformer_cls)
                auto_wrap_policy = functools.partial(
                    transformer_auto_wrap_policy,
                    # Transformer layer class to wrap
                    transformer_layer_cls=transformer_cls_to_wrap,
                )
            fsdp_kwargs = self.args.xla_fsdp_config
            if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                # Apply gradient checkpointing to auto-wrapped sub-modules if specified
                def auto_wrapper_callable(m, *args, **kwargs):
                    return FSDP(checkpoint_module(m), *args, **kwargs)

            # Wrap the base model with an outer FSDP wrapper
            self.model = model = FSDP(
                model,
                auto_wrap_policy=auto_wrap_policy,
                auto_wrapper_callable=auto_wrapper_callable,
                **fsdp_kwargs,
            )

            # Patch `xm.optimizer_step` so that it does not reduce gradients in this case,
            # as FSDP does not need gradient reduction over sharded parameters.
            def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
                loss = optimizer.step(**optimizer_args)
                if barrier:
                    xm.mark_step()
                return loss

            xm.optimizer_step = patched_optimizer_step
        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            if is_torch_neuroncore_available():
                return model
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb

            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

        return model

    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Main training entry point.

        Args:
            resume_from_checkpoint (`str` or `bool`, *optional*):
                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
            ignore_keys_for_eval (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions for evaluation during the training.
            kwargs:
                Additional keyword arguments used to hide deprecated arguments
        """
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        args = self.args

        self.is_in_train = True

        # do_train is not a reliable argument, as it might not be set and .train() still called, so
        # the following is a workaround:
        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
            self._move_model_to_device(self.model, args.device)

        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
                "instead.",
                FutureWarning,
            )
        if len(kwargs) > 0:
            raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)
        self._train_batch_size = self.args.train_batch_size

        # Model re-init
        model_reloaded = False
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
            self.model = self.call_model_init(trial)
            model_reloaded = True
            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Load potential model checkpoint
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:
            self._load_from_checkpoint(resume_from_checkpoint)

        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
            if self.place_model_on_device:
                self._move_model_to_device(self.model, args.device)
            self.model_wrapped = self.model

        inner_training_loop = find_executable_batch_size(
            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
        )
        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        self.accelerator.free_memory()
        self._train_batch_size = batch_size
        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
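        # Illustrative example: train_batch_size=8, gradient_accumulation_steps=4 and world_size=2 give an
        # effective total_train_batch_size of 8 * 4 * 2 = 64 samples per optimizer update.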

        len_dataloader = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            num_examples = self.num_examples(train_dataloader)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
                )
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
            else:
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            max_steps = args.max_steps
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
            )

        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps and args.logging_steps < 1:
            args.logging_steps = math.ceil(max_steps * args.logging_steps)
        if args.eval_steps and args.eval_steps < 1:
            args.eval_steps = math.ceil(max_steps * args.eval_steps)
        if args.save_steps and args.save_steps < 1:
            args.save_steps = math.ceil(max_steps * args.save_steps)
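        # Illustrative example: with max_steps=1_000 and `--logging_steps 0.1`, logging_steps becomes
        # ceil(1_000 * 0.1) = 100 absolute steps; eval_steps and save_steps are converted the same way.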

        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torch.distributed.launch)."
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

        delay_optimizer_creation = (
            self.sharded_ddp is not None
            and self.sharded_ddp != ShardedDDPOption.SIMPLE
            or is_sagemaker_mp_enabled()
            or self.fsdp is not None
        )
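        # Optimizer creation is deferred here for setups where the optimizer has to be built from the wrapped or
        # sharded parameters (fairscale zero_dp_2/zero_dp_3, SageMaker MP, FSDP) rather than from the raw model.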

        if self.is_deepspeed_enabled:
            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

        if not delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None

        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        model = self._wrap_model(self.model_wrapped)

        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint, model)

        # as the model is wrapped, don't use `accelerator.prepare`
        # this is for unhandled cases such as
        # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
        use_accelerator_prepare = True if model is self.model else False

        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        # prepare using `accelerator` prepare
        if use_accelerator_prepare:
            model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

        if self.is_fsdp_enabled:
            self.model = model

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

        # backward compatibility
        if self.is_deepspeed_enabled:
            self.deepspeed = self.model_wrapped

        # deepspeed ckpt loading
        if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
            deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

        # important: at this point:
        # self.model         is the Transformers Model
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.

        # Train!
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples:,}")
        logger.info(f"  Num Epochs = {num_train_epochs:,}")
        logger.info(f"  Instantaneous batch size per device = {self._train_batch_size:,}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps:,}")
        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        steps_trained_progress_bar = None

        # Check if continuing training from a checkpoint
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                if skip_first_batches is None:
                    logger.info(
                        f"  Will skip the first {epochs_trained} epochs then the first"
                        f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
                        " you can install the latest version of Accelerate with `pip install -U accelerate`.You can"
                        " also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
                        " training on data already seen by your model."
                    )
                else:
                    logger.info(
                        f"  Will skip the first {epochs_trained} epochs then the first"
                        f" {steps_trained_in_current_epoch} batches in the first epoch."
                    )
                if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
                    steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                    steps_trained_progress_bar.set_description("Skipping the first batches")

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        if self.hp_name is not None and self._trial is not None:
            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
            # parameter to Train when using DDP.
            self.state.trial_name = self.hp_name(self._trial)
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
        # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()

        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not args.ignore_data_skip:
            for epoch in range(epochs_trained):
                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
                    train_dataloader.sampler, RandomSampler
                )
                if is_torch_less_than_1_11 or not is_random_sampler:
                    # We just need to begin an iteration to create the randomization of the sampler.
                    # That was before PyTorch 1.11 however...
                    for _ in train_dataloader:
                        break
                else:
                    # Otherwise we need to iterate over the whole sampler because there is some random operation added
                    # AT THE VERY END!
                    _ = list(train_dataloader.sampler)
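                # Illustrative example: resuming at global_step 300 with 100 update steps per epoch means
                # epochs_trained == 3, so the sampler is replayed for 3 epochs here purely to restore its
                # random state; no forward/backward passes happen in this loop.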

        total_batched_samples = 0
        for epoch in range(epochs_trained, num_train_epochs):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
            elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
                train_dataloader.dataset.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if args.past_index >= 0:
                self._past = None

            steps_in_epoch = (
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

            rng_to_sync = False
            steps_skipped = 0
            if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
                steps_skipped = steps_trained_in_current_epoch
                steps_trained_in_current_epoch = 0
                rng_to_sync = True

            step = -1
            for step, inputs in enumerate(epoch_iterator):
                total_batched_samples += 1
                if rng_to_sync:
                    self._load_rng_state(resume_from_checkpoint)
                    rng_to_sync = False

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
                    continue
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None

                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                with self.accelerator.accumulate(model):
                    tr_loss_step = self.training_step(model, inputs)

                if (
                    args.logging_nan_inf_filter
                    and not is_torch_tpu_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                else:
                    tr_loss += tr_loss_step
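                # Illustrative example: if `tr_loss_step` comes back as NaN/inf with logging_nan_inf_filter enabled,
                # `tr_loss` above grows by its own per-step average since the last log instead, keeping the
                # reported training loss finite.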

                self.current_flos += float(self.floating_point_ops(inputs))

                # should this be under the accumulate context manager?
                # the `or` condition of `steps_in_epoch <= args.gradient_accumulation_steps` is not covered
                # in accelerate
                if total_batched_samples % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    steps_in_epoch <= args.gradient_accumulation_steps
                    and (step + 1) == steps_in_epoch
                ):
                    # Gradient clipping
                    if args.max_grad_norm is not None and args.max_grad_norm > 0:
                        # deepspeed does its own clipping

                        if self.do_grad_scaling:
                            # Reduce gradients first for XLA
                            if is_torch_tpu_available():
                                gradients = xm._fetch_gradients(self.optimizer)
                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)

                        if is_sagemaker_mp_enabled() and args.fp16:
                            self.optimizer.clip_master_grads(args.max_grad_norm)
                        elif hasattr(self.optimizer, "clip_grad_norm"):
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
                            model.clip_grad_norm_(args.max_grad_norm)
                        elif self.use_apex:
                            # Revert to normal clipping otherwise, handling Apex or full precision
                            nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer),
                                args.max_grad_norm,
                            )
                        else:
                            self.accelerator.clip_grad_norm_(
                                model.parameters(),
                                args.max_grad_norm,
                            )

                    # Optimizer step
                    optimizer_was_run = True
                    if is_torch_tpu_available():
                        if self.do_grad_scaling:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            xm.optimizer_step(self.optimizer)
                    elif self.do_grad_scaling:
                        scale_before = self.scaler.get_scale()
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        scale_after = self.scaler.get_scale()
                        optimizer_was_run = scale_before <= scale_after
                    else:
                        self.optimizer.step()
                        optimizer_was_run = not self.accelerator.optimizer_step_was_skipped

                    if optimizer_was_run:
                        # Delay optimizer scheduling until metrics are generated
                        if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                            self.lr_scheduler.step()

                    model.zero_grad()
                    self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break
            if step < 0:
                logger.warning(
                    "There seems to be not a single sample in your epoch_iterator, stopping training at step"
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True

            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break

        if args.past_index and hasattr(self, "_past"):
2039
2040
            # Clean the state at the end of training
            delattr(self, "_past")
Julien Chaumond's avatar
Julien Chaumond committed
2041
2042

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
                dist.barrier()
            elif is_sagemaker_mp_enabled():
                smp.barrier()

            self._load_best_model()

        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        train_loss = self._total_loss_scalar / self.state.global_step

        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
        metrics["train_loss"] = train_loss

        self.is_in_train = False

        self._memory_tracker.stop_and_update_metrics(metrics)

        self.log(metrics)

        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and the process is allowed to save.
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
            for checkpoint in checkpoints_sorted:
                if checkpoint != self.state.best_model_checkpoint:
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint)

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        return TrainOutput(self.state.global_step, train_loss, metrics)
    def _get_output_dir(self, trial):
        if self.hp_search_backend is not None and trial is not None:
            if self.hp_search_backend == HPSearchBackend.OPTUNA:
                run_id = trial.number
            elif self.hp_search_backend == HPSearchBackend.RAY:
                from ray import tune

                run_id = tune.get_trial_id()
            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
                run_id = trial.id
            elif self.hp_search_backend == HPSearchBackend.WANDB:
                import wandb

                run_id = wandb.run.id
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            run_dir = os.path.join(self.args.output_dir, run_name)
        else:
            run_dir = self.args.output_dir
        return run_dir

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        if model is None:
            model = self.model

        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)

        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
        safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
        safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)

        if not any(
            os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file]
        ):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

        logger.info(f"Loading model from {resume_from_checkpoint}.")

        if os.path.isfile(config_file):
            config = PretrainedConfig.from_json_file(config_file)
            checkpoint_version = config.transformers_version
            if checkpoint_version is not None and checkpoint_version != __version__:
                logger.warning(
                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                    f"Transformers but your current version is {__version__}. This is not recommended and could "
                    "lead to errors or unwanted behaviors."
                )

        if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
            # If the model is on the GPU, it still works!
            if is_sagemaker_mp_enabled():
                if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
                    # If the 'user_content.pt' file exists, load with the new smp api.
                    # Checkpoint must have been saved with the new smp api.
                    smp.resume_from_checkpoint(
                        path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
                    )
                else:
                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                    # Checkpoint must have been saved with the old smp api.
                    if hasattr(self.args, "fp16") and self.args.fp16 is True:
                        logger.warning(
                            "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
                        )
                    state_dict = torch.load(weights_file, map_location="cpu")
                    # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                    state_dict["_smp_is_partial"] = False
                    load_result = model.load_state_dict(state_dict, strict=True)
                    # release memory
                    del state_dict
            elif self.is_fsdp_enabled:
                self.accelerator.state.fsdp_plugin.load_model(self.accelerator, model, resume_from_checkpoint)
            else:
                # We load the model state dict on the CPU to avoid an OOM error.
                if self.args.save_safetensors and os.path.isfile(safe_weights_file):
                    state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
                else:
                    state_dict = torch.load(weights_file, map_location="cpu")

                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                # which takes *args instead of **kwargs
                load_result = model.load_state_dict(state_dict, False)
                # release memory
                del state_dict
                self._issue_warnings_after_load(load_result)
        else:
            # We load the sharded checkpoint
            load_result = load_sharded_checkpoint(
                model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
            )
            if not is_sagemaker_mp_enabled():
                self._issue_warnings_after_load(load_result)

    def _load_best_model(self):
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
        best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
            if self.is_deepspeed_enabled:
                deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
            else:
                if is_sagemaker_mp_enabled():
                    if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                        # If the 'user_content.pt' file exists, load with the new smp api.
                        # Checkpoint must have been saved with the new smp api.
                        smp.resume_from_checkpoint(
                            path=self.state.best_model_checkpoint,
                            tag=WEIGHTS_NAME,
                            partial=False,
                            load_optimizer=False,
                        )
                    else:
                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                        # Checkpoint must have been saved with the old smp api.
                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                        else:
                            state_dict = torch.load(best_model_path, map_location="cpu")

                        state_dict["_smp_is_partial"] = False
                        load_result = model.load_state_dict(state_dict, strict=True)
                elif self.is_fsdp_enabled:
                    self.accelerator.state.fsdp_plugin.load_model(
                        self.accelerator, model, self.state.best_model_checkpoint
                    )
                else:
                    if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
                        # When training 8-bit base models with PEFT & LoRA, assume the adapters have been saved properly.
                        if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                            if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
                                model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
                                # load_adapter has no return value, so build an empty result; adjust when appropriate.
                                from torch.nn.modules.module import _IncompatibleKeys

                                load_result = _IncompatibleKeys([], [])
                            else:
                                logger.warning(
                                    "The intermediate checkpoints of PEFT may not be saved correctly, "
                                    "consider using a custom `TrainerCallback` to save adapter_model.bin in the "
                                    "corresponding folders; here are some examples: https://github.com/huggingface/peft/issues/96"
                                )
                        else:
                            # We can't do pure 8bit training using transformers.
                            logger.warning("Could not load a quantized checkpoint.")
                    else:
                        # We load the model state dict on the CPU to avoid an OOM error.
                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                        else:
                            state_dict = torch.load(best_model_path, map_location="cpu")

                        # If the model is on the GPU, it still works!
                        # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                        # which takes *args instead of **kwargs
                        load_result = model.load_state_dict(state_dict, False)
                if not is_sagemaker_mp_enabled():
                    self._issue_warnings_after_load(load_result)
        elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
            load_result = load_sharded_checkpoint(
                model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
            )
            if not is_sagemaker_mp_enabled():
                self._issue_warnings_after_load(load_result)
        else:
            logger.warning(
                f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
                "on multiple nodes, you should activate `--save_on_each_node`."
            )

    def _issue_warnings_after_load(self, load_result):
        if len(load_result.missing_keys) != 0:
            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                self.model._keys_to_ignore_on_save
            ):
                self.model.tie_weights()
            else:
                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
        if len(load_result.unexpected_keys) != 0:
            logger.warning(
                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
            )

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
        if self.control.should_log:
            if is_torch_tpu_available():
                xm.mark_step()

            logs: Dict[str, float] = {}

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            logs["learning_rate"] = self._get_learning_rate()

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            if isinstance(self.eval_dataset, dict):
                metrics = {}
                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
                    dataset_metrics = self.evaluate(
                        eval_dataset=eval_dataset,
                        ignore_keys=ignore_keys_for_eval,
                        metric_key_prefix=f"eval_{eval_dataset_name}",
                    )
                    metrics.update(dataset_metrics)
            else:
                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
            self._report_to_hp_search(trial, self.state.global_step, metrics)

            # Run delayed LR scheduler now that metrics are populated
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                metric_to_check = self.args.metric_for_best_model
                if not metric_to_check.startswith("eval_"):
                    metric_to_check = f"eval_{metric_to_check}"
                self.lr_scheduler.step(metrics[metric_to_check])

        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def _load_rng_state(self, checkpoint):
        # Load RNG states from `checkpoint`
        if checkpoint is None:
            return

        if self.args.world_size > 1:
            process_index = self.args.process_index
            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
            if not os.path.isfile(rng_file):
                logger.info(
                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
                )
                return
        else:
            rng_file = os.path.join(checkpoint, "rng_state.pth")
            if not os.path.isfile(rng_file):
                logger.info(
                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
                    "fashion, reproducibility is not guaranteed."
                )
                return

        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        if torch.cuda.is_available():
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
            else:
                try:
                    torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
                except Exception as e:
                    logger.info(
                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
        if is_torch_tpu_available():
            xm.set_rng_state(checkpoint_rng_state["xla"])

    def _save_checkpoint(self, model, trial, metrics=None):
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir, _internal_call=True)
        if self.is_deepspeed_enabled:
            # under zero3 the model file itself doesn't get saved since it's bogus, unless the deepspeed
            # config `stage3_gather_16bit_weights_on_model_save` is True
            self.model_wrapped.save_checkpoint(output_dir)

        # Save optimizer and scheduler
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            self.optimizer.consolidate_state_dict()

        if self.fsdp:
            # FSDP has a different interface for saving optimizer states.
            # Needs to be called on all ranks to gather all states.
            # full_optim_state_dict will be deprecated after PyTorch 2.2!
            full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)

        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
            if self.args.should_save:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling:
                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
        elif self.args.should_save and not self.is_deepspeed_enabled:
            # deepspeed.save_checkpoint above saves model/optim/sched
            if self.fsdp:
                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
            else:
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.args.should_save:
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        # Save RNG state
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                # In non-distributed training, we save the global CUDA RNG state (this takes care of DataParallel).
                rng_states["cuda"] = torch.cuda.random.get_rng_state()

        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()

        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)

        if self.args.world_size <= 1:
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

    def _load_optimizer_and_scheduler(self, checkpoint):
        """If optimizer and scheduler states exist, load them."""
        if checkpoint is None:
            return

        if self.is_deepspeed_enabled:
            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
            return

        checkpoint_file_exists = (
            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
            if is_sagemaker_mp_enabled()
            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
        )
        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
            # Load in optimizer and scheduler states
            if is_torch_tpu_available():
                # On TPU we have to take some extra precautions to properly load the states on the right device.
                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                with warnings.catch_warnings(record=True) as caught_warnings:
                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
                reissue_pt_warnings(caught_warnings)

                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

                self.optimizer.load_state_dict(optimizer_state)
                self.lr_scheduler.load_state_dict(lr_scheduler_state)
            else:
                if is_sagemaker_mp_enabled():
                    if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
                        # Optimizer checkpoint was saved with smp >= 1.10
                        def opt_load_hook(mod, opt):
                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))

                    else:
                        # Optimizer checkpoint was saved with smp < 1.10
                        def opt_load_hook(mod, opt):
                            if IS_SAGEMAKER_MP_POST_1_10:
                                opt.load_state_dict(
                                    smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
                                )
                            else:
                                opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))

                    self.model_wrapped.register_post_step_hook(opt_load_hook)
                else:
                    # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
                    # In distributed training however, we load directly on each GPU and accept the GPU OOM risk,
                    # since loading on CPU would be even more likely to OOM (the optimizer state would be loaded
                    # num_gpu times on CPU).
                    map_location = self.args.device if self.args.world_size > 1 else "cpu"
                    if self.fsdp:
                        full_osd = None
                        # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
                        if self.args.process_index == 0:
                            full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))
                        # call scatter_full_optim_state_dict on all ranks
                        sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)
                        self.optimizer.load_state_dict(sharded_osd)
                    else:
                        self.optimizer.load_state_dict(
                            torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
                        )
                with warnings.catch_warnings(record=True) as caught_warnings:
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
                    self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))

    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
        **kwargs,
    ) -> BestRun:
        """
        Launch a hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined
        by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
        the sum of all metrics otherwise.

        <Tip warning={true}>

        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
        optimizer/scheduler.

        </Tip>

        Args:
            hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
                A function that defines the hyperparameter search space. Will default to
                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
                method. Will default to [`~trainer_utils.default_compute_objective`].
            n_trials (`int`, *optional*, defaults to 20):
                The number of trial runs to test.
            direction (`str`, *optional*, defaults to `"minimize"`):
                Whether to minimize or maximize the objective. Can be `"minimize"` or `"maximize"`; pick `"minimize"`
                when optimizing the validation loss and `"maximize"` when optimizing one or several metrics.
            backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
                on which one is installed. If all are installed, will default to optuna.
            hp_name (`Callable[["optuna.Trial"], str]`, *optional*):
                A function that defines the trial/run name. Will default to None.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more
                information see:

                - the documentation of
                  [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
                - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
                - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)

        Returns:
            [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in
            `run_summary` attribute for Ray backend.
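
        Example (an illustrative sketch, assuming `optuna` is installed and the `Trainer` was created with a
        `model_init` function; `my_hp_space` is just an example search space over two `TrainingArguments` fields):

        ```python
        def my_hp_space(trial):
            return {
                "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
                "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
            }


        best_run = trainer.hyperparameter_search(hp_space=my_hp_space, n_trials=20, direction="minimize")
        print(best_run.hyperparameters)
        ```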
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray should be installed. "
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`. "
                    "To install sigopt run `pip install sigopt`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_tune_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
            raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
        if backend == HPSearchBackend.WANDB and not is_wandb_available():
            raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
        self.hp_search_backend = backend
        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.hp_name = hp_name
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        backend_dict = {
            HPSearchBackend.OPTUNA: run_hp_search_optuna,
            HPSearchBackend.RAY: run_hp_search_ray,
            HPSearchBackend.SIGOPT: run_hp_search_sigopt,
            HPSearchBackend.WANDB: run_hp_search_wandb,
        }
        best_run = backend_dict[backend](self, n_trials, direction, **kwargs)

        self.hp_search_backend = None
        return best_run

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
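
        Example (an illustrative sketch; `my_custom_metric` is a hypothetical key):

        ```python
        trainer.log({"my_custom_metric": 0.5})
        ```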
        """
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)

        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)

    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
        """
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            kwargs = {"device": self.args.device}
            if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
                # NLP models inputs are int/uint and those get adjusted to the right dtype of the
                # embedding. Other models such as wav2vec2's inputs are already float and thus
                # may need special handling to match the dtypes of the model
                kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
            return data.to(**kwargs)
        return data

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

    def compute_loss_context_manager(self):
        """
        A helper wrapper to group together context managers.
        """
        return self.autocast_smart_context_manager()

    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
        """
        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
        arguments, depending on the situation.
        """
        if self.use_cuda_amp or self.use_cpu_amp:
            if is_torch_greater_or_equal_than_1_10:
                ctx_manager = (
                    torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
                    if self.use_cpu_amp
                    else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
                )
            else:
                ctx_manager = torch.cuda.amp.autocast()
        else:
            ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()

        return ctx_manager

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)

        return loss.detach() / self.args.gradient_accumulation_steps
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
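
        Example (an illustrative sketch of overriding this method in a `Trainer` subclass; `MyTrainer` is a
        hypothetical name):

        ```python
        class MyTrainer(Trainer):
            def compute_loss(self, model, inputs, return_outputs=False):
                outputs = model(**inputs)
                # mirror the default behavior: take the loss from the model output
                loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
                return (loss, outputs) if return_outputs else loss
        ```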
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
        machines) main process.
        """
        return self.args.local_process_index == 0
    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be `True` for one process).
        """
        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
        # process index.
        if is_sagemaker_mp_enabled():
            return smp.rank() == 0
        else:
            return self.args.process_index == 0

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Will save the model, so you can reload it using `from_pretrained()`.
        Will only save from the main process.
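
        Example (an illustrative sketch; the path is hypothetical):

        ```python
        trainer.save_model("./my-finetuned-model")
        # the saved model can then be reloaded with `from_pretrained("./my-finetuned-model")`
        ```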
        """

        if output_dir is None:
            output_dir = self.args.output_dir

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif is_sagemaker_mp_enabled():
            # Calling the state_dict needs to be done on the wrapped model and on all processes.
            os.makedirs(output_dir, exist_ok=True)
            state_dict = self.model_wrapped.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
            if IS_SAGEMAKER_MP_POST_1_10:
                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
                Path(os.path.join(output_dir, "user_content.pt")).touch()
        elif (
            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
            or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
            or self.fsdp is not None
            or self.is_fsdp_enabled
        ):
            if self.is_fsdp_enabled:
                self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
            else:
                state_dict = self.model.state_dict()

                if self.args.should_save:
                    self._save(output_dir, state_dict=state_dict)
        elif self.is_deepspeed_enabled:
            # this takes care of everything as long as we aren't under zero3
            if self.args.should_save:
                self._save(output_dir)

            if is_deepspeed_zero3_enabled():
                # It's too complicated to try to override different places where the weights dump gets
                # saved, so since under zero3 the file is bogus, simply delete it. The user should
                # either use the deepspeed checkpoint to resume, or recover the full weights with the
                # zero_to_fp32.py script stored in the checkpoint.
                if self.args.should_save:
                    file = os.path.join(output_dir, WEIGHTS_NAME)
                    if os.path.isfile(file):
                        # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
                        os.remove(file)

                # now save the real model if stage3_gather_16bit_weights_on_model_save=True
                # if false it will not be saved.
                # This must be called on all ranks
                if not self.model_wrapped.save_16bit_model(output_dir, WEIGHTS_NAME):
                    logger.warning(
                        "deepspeed.save_16bit_model didn't save the model, since"
                        " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                        " zero_to_fp32.py to recover weights"
                    )
                    self.model_wrapped.save_checkpoint(output_dir)

        elif self.args.should_save:
            self._save(output_dir)

        # Push to the Hub when `save_model` is called by the user.
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info(f"Saving model checkpoint to {output_dir}")

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if not isinstance(self.model, PreTrainedModel):
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                unwrap_model(self.model).save_pretrained(
                    output_dir,
                    is_main_process=self.args.should_save,
                    state_dict=self.model.state_dict(),
                    save_function=xm.save,
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                state_dict = self.model.state_dict()
                xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
        if self.tokenizer is not None and self.args.should_save:
            self.tokenizer.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            if isinstance(unwrap_model(self.model), supported_classes):
                unwrap_model(self.model).save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def store_flos(self):
        # Storing the number of floating-point operations that went into the model
        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            self.state.total_flos += (
                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
            )
            self.current_flos = 0
        else:
            self.state.total_flos += self.current_flos
            self.current_flos = 0

    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match is not None and regex_match.groups() is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        # Make sure we don't delete the best model.
        if self.state.best_model_checkpoint is not None:
            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
            for i in range(best_model_index, len(checkpoints_sorted) - 2):
                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
2970
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
        # we don't do to allow resuming.
        save_total_limit = self.args.save_total_limit
        if (
            self.state.best_model_checkpoint is not None
            and self.args.save_total_limit == 1
            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
        ):
            save_total_limit = 2

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
            shutil.rmtree(checkpoint, ignore_errors=True)

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default)

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
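
        Example (an illustrative sketch, assuming a `Trainer` instance named `trainer` has already been built):

        ```python
        metrics = trainer.evaluate()
        # Metric keys are prefixed with `metric_key_prefix`, e.g. "eval_loss".
        print(metrics["eval_loss"])
        ```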
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        start_time = time.time()

        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.log(output.metrics)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics

    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Run prediction and returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the
                `model.forward()` method are automatically removed. Has to implement the method `__len__`
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "test_bleu" if the prefix is "test" (default)

        <Tip>

        If your predictions or labels have different sequence lengths (for instance because you're doing dynamic padding
        in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
        one array. The padding index is -100.

        </Tip>

        Returns: *NamedTuple* A namedtuple with the following keys:

            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
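
        Example (an illustrative sketch for a simple classification model; `trainer` and `test_dataset` are assumed
        to already exist):

        ```python
        output = trainer.predict(test_dataset)
        # `output.metrics` holds the "test_*" metrics when the dataset contains labels.
        preds = output.predictions.argmax(-1)
        ```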
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        test_dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()

        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )
        total_batch_size = self.args.eval_batch_size * self.args.world_size
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """
        args = self.args

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train, init deepspeed here
        if self.is_deepspeed_enabled and self.model_wrapped is self.model:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
            model = self.accelerator.prepare(self.model)
            self.model_wrapped = self.deepspeed = model

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        inputs_host = None

        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        all_inputs = None
        # Will be useful when we have an iterable dataset whose length we don't know in advance.

        observed_num_examples = 0
        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            if is_torch_tpu_available():
                xm.mark_step()

            # Update containers on host
            if loss is not None:
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if labels is not None:
                labels = self._pad_across_processes(labels)
            if inputs_decode is not None:
                inputs_decode = self._pad_across_processes(inputs_decode)
                inputs_decode = self._nested_gather(inputs_decode)
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            if logits is not None:
                logits = self._pad_across_processes(logits)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                logits = self._nested_gather(logits)
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels = self._nested_gather(labels)
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                if preds_host is not None:
                    logits = nested_numpify(preds_host)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
                if inputs_host is not None:
                    inputs_decode = nested_numpify(inputs_host)
                    all_inputs = (
                        inputs_decode
                        if all_inputs is None
                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                    )
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                    )

                # Set back to None to begin a new accumulation
                losses_host, preds_host, inputs_host, labels_host = None, None, None, None

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
        if inputs_host is not None:
            inputs_decode = nested_numpify(inputs_host)
            all_inputs = (
                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
            )
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples
        if num_samples == 0 and observed_num_examples > 0:
            num_samples = observed_num_examples

        # The number of losses has been rounded to a multiple of batch_size and, in distributed training, the number of
        # samples has been rounded to a multiple of batch_size as well, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
        if all_inputs is not None:
            all_inputs = nested_truncate(all_inputs, num_samples)

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
        if hasattr(self, "jit_compilation_time"):
            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)

    def _nested_gather(self, tensors, name=None):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) from all processes so they can be
        concatenated.
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            if name is None:
                name = "nested_gather"
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
            self.args.distributed_state is None and self.local_rank != -1
        ):
            tensors = distributed_concat(tensors)
        return tensors

    # Copied from Accelerate.
    def _pad_across_processes(self, tensor, pad_index=-100):
        """
        Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
        they can safely be gathered.
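        For example, if one process holds a tensor of shape `(8, 5)` and another one of shape `(8, 7)`, both are
        padded along their second dimension to `(8, 7)` with `pad_index` before being gathered.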
        """
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )

        if len(tensor.shape) < 2:
            return tensor
        # Gather all sizes
        size = torch.tensor(tensor.shape, device=tensor.device)[None]
        sizes = self._nested_gather(size).cpu()

        max_size = max(s[1] for s in sizes)
        # When extracting XLA graphs for compilation, max_size is 0,
        # so use inequality to avoid errors.
        if tensor.shape[1] >= max_size:
            return tensor

        # Then pad to the maximum size
        old_size = tensor.shape
        new_size = list(old_size)
        new_size[1] = max_size
        new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
        new_tensor[:, : old_size[1]] = tensor
        return new_tensor

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
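
        Example (a minimal subclassing sketch; `MyTrainer` is a hypothetical subclass used for illustration):

        ```python
        class MyTrainer(Trainer):
            def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
                loss, logits, labels = super().prediction_step(
                    model, inputs, prediction_loss_only, ignore_keys=ignore_keys
                )
                # Custom post-processing of `logits` could be injected here.
                return loss, logits, labels
        ```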
        """
        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels or loss_without_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels or loss_without_labels:
                    with self.compute_loss_context_manager():
                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                    loss = loss.mean().detach()

                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        logits = outputs[1:]
                else:
                    loss = None
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point
        operations for every backward + forward pass. If using another model, either implement such a method in the
        model or subclass and override this method.

        Args:
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            `int`: The number of floating-point operations.
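
        Example (illustrative only; `trainer` and a prepared batch `inputs` are assumed to already exist):

        ```python
        flos = trainer.floating_point_ops(inputs)  # returns 0 if the model does not expose `floating_point_ops`
        ```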
        """
        if hasattr(self.model, "floating_point_ops"):
            return self.model.floating_point_ops(inputs)
        else:
            return 0

    def init_git_repo(self, at_init: bool = False):
        """
        Initializes a git repo in `self.args.hub_model_id`.

        Args:
            at_init (`bool`, *optional*, defaults to `False`):
                Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
                `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
                out.
        """
        if not self.is_world_process_zero():
            return
        if self.args.hub_model_id is None:
            repo_name = Path(self.args.output_dir).absolute().name
        else:
            repo_name = self.args.hub_model_id
        if "/" not in repo_name:
            repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)

        # Make sure the repo exists.
        create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)
        try:
            self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)
        except EnvironmentError:
            if self.args.overwrite_output_dir and at_init:
                # Try again after wiping output_dir
                shutil.rmtree(self.args.output_dir)
                self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)
            else:
                raise

        self.repo.git_pull()

        # By default, ignore the checkpoint folders
        if (
            not os.path.exists(os.path.join(self.args.output_dir, ".gitignore"))
            and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS
        ):
            with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
                writer.writelines(["checkpoint-*/"])

        # Add "*.sagemaker" to .gitignore if using SageMaker
        if os.environ.get("SM_TRAINING_ENV"):
            self._add_sm_patterns_to_gitignore()

        self.push_in_progress = None

    def create_model_card(
        self,
        language: Optional[str] = None,
        license: Optional[str] = None,
        tags: Union[str, List[str], None] = None,
        model_name: Optional[str] = None,
        finetuned_from: Optional[str] = None,
        tasks: Union[str, List[str], None] = None,
        dataset_tags: Union[str, List[str], None] = None,
        dataset: Union[str, List[str], None] = None,
        dataset_args: Union[str, List[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            language (`str`, *optional*):
                The language of the model (if applicable)
            license (`str`, *optional*):
                The license of the model. Will default to the license of the pretrained model used, if the original
                model given to the `Trainer` comes from a repo on the Hub.
            tags (`str` or `List[str]`, *optional*):
                Some tags to be included in the metadata of the model card.
            model_name (`str`, *optional*):
                The name of the model.
            finetuned_from (`str`, *optional*):
                The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
                of the original model given to the `Trainer` (if it comes from the Hub).
            tasks (`str` or `List[str]`, *optional*):
                One or several task identifiers, to be included in the metadata of the model card.
            dataset_tags (`str` or `List[str]`, *optional*):
                One or several dataset tags, to be included in the metadata of the model card.
            dataset (`str` or `List[str]`, *optional*):
                One or several dataset identifiers, to be included in the metadata of the model card.
            dataset_args (`str` or `List[str]`, *optional*):
               One or several dataset arguments, to be included in the metadata of the model card.
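
        Example (an illustrative sketch with placeholder values):

        ```python
        trainer.create_model_card(
            language="en",
            finetuned_from="bert-base-uncased",
            tasks="text-classification",
        )
        ```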
        """
        if not self.is_world_process_zero():
            return

        training_summary = TrainingSummary.from_trainer(
            self,
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
        )
        model_card = training_summary.to_model_card()
        with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
            f.write(model_card)

    def _push_from_checkpoint(self, checkpoint_folder):
        # Only push from one node.
        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
            return
        # If we haven't finished the last push, we don't do this one.
        if self.push_in_progress is not None and not self.push_in_progress.is_done:
            return

        output_dir = self.args.output_dir
        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
        modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
        for modeling_file in modeling_files:
            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Same for the training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        try:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Temporarily move the checkpoint just saved for the push
                tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
                # subfolder.
                if os.path.isdir(tmp_checkpoint):
                    shutil.rmtree(tmp_checkpoint)
                shutil.move(checkpoint_folder, tmp_checkpoint)

            if self.args.save_strategy == IntervalStrategy.STEPS:
                commit_message = f"Training in progress, step {self.state.global_step}"
            else:
                commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
            push_work = self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
            # Return type of `Repository.push_to_hub` is either None or a tuple.
            if push_work is not None:
                self.push_in_progress = push_work[1]
        except Exception as e:
            logger.error(f"Error when pushing to hub: {e}")
        finally:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Move back the checkpoint to its place
                shutil.move(tmp_checkpoint, checkpoint_folder)

    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.

        Parameters:
            commit_message (`str`, *optional*, defaults to `"End of training"`):
                Message to commit while pushing.
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the `git push` has finished.
            kwargs:
                Additional keyword arguments passed along to [`~Trainer.create_model_card`].

        Returns:
            The url of the commit of your model in the given repository if `blocking=True`, or a tuple with the url of
            the commit and an object to track the progress of the commit if `blocking=False`.
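
        Example (illustrative only; assumes the repo has already been initialized, e.g. via `push_to_hub=True` in the
        training arguments):

        ```python
        trainer.push_to_hub(commit_message="Training complete")
        ```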
        """
        # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
        # it might fail.
        if not hasattr(self, "repo"):
            self.init_git_repo()

        model_name = kwargs.pop("model_name", None)
        if model_name is None and self.args.should_save:
            if self.args.hub_model_id is None:
                model_name = Path(self.args.output_dir).name
            else:
                model_name = self.args.hub_model_id.split("/")[-1]

        # Needs to be executed on all processes for TPU training, but will only save on the process determined by
        # self.args.should_save.
        self.save_model(_internal_call=True)

        # Only push from one node.
        if not self.is_world_process_zero():
            return

        # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
            self.push_in_progress._process.kill()
            self.push_in_progress = None

        git_head_commit_url = self.repo.push_to_hub(
            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
        )
        # push separately the model card to be independent from the rest of the model
        if self.args.should_save:
            self.create_model_card(model_name=model_name, **kwargs)
            try:
                self.repo.push_to_hub(
                    commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
                )
            except EnvironmentError as exc:
                logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")

        return git_head_commit_url

    #
    # Deprecated code
    #

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """
        args = self.args

        if not has_length(dataloader):
            raise ValueError("dataloader must implement a working __len__")

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train, init deepspeed here
        if self.is_deepspeed_enabled and self.model_wrapped is self.model:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
            model = self.accelerator.prepare(self.model)
            self.model_wrapped = self.deepspeed = model

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info(f"***** Running {description} *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Batch size = {batch_size}")
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
        inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = max(1, args.world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
            # a batch size to the sampler)
            make_multiple_of = None
            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
                make_multiple_of = dataloader.sampler.batch_size
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            if loss is not None:
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                if not prediction_loss_only:
                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
                    inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host, inputs_host = None, None, None, None

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
            inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
        inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)

    def _gather_and_numpify(self, tensors, name):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)

    def _add_sm_patterns_to_gitignore(self) -> None:
        """Add SageMaker Checkpointing patterns to .gitignore file."""
        # Make sure we only do this on the main process
        if not self.is_world_process_zero():
            return

        patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]

        # Get current .gitignore content
        if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
            with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f:
                current_content = f.read()
        else:
            current_content = ""

        # Add the patterns to .gitignore
        content = current_content
        for pattern in patterns:
            if pattern not in content:
                if content.endswith("\n"):
                    content += pattern
                else:
                    content += f"\n{pattern}"

        # Write the .gitignore file if it has changed
        if content != current_content:
            with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
                logger.debug(f"Writing .gitignore file. Content: {content}")
                f.write(content)

        self.repo.git_add(".gitignore")

        # avoid race condition with git status
        time.sleep(0.5)

        if not self.repo.is_repo_clean():
            self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
            self.repo.git_push()

    def create_accelerator_and_postprocess(self):
        # create accelerator object
        self.accelerator = Accelerator(
            deepspeed_plugin=self.args.deepspeed_plugin,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        )

        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

        # post accelerator creation setup
        if self.is_fsdp_enabled:
            fsdp_plugin = self.accelerator.state.fsdp_plugin
            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)

        if self.is_deepspeed_enabled:
            if getattr(self.args, "hf_deepspeed_config", None) is None:
                from transformers.deepspeed import HfTrainerDeepSpeedConfig

                ds_plugin = self.accelerator.state.deepspeed_plugin

                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
                ds_plugin.hf_ds_config.trainer_config_process(self.args)