# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""

19
import contextlib
20
import functools
21
import glob
22
import inspect
23
import math
Julien Chaumond's avatar
Julien Chaumond committed
24
import os
25
import random
Julien Chaumond's avatar
Julien Chaumond committed
26
27
import re
import shutil
28
import sys
29
import time
30
import warnings
31
from collections.abc import Mapping
32
from distutils.util import strtobool
Julien Chaumond's avatar
Julien Chaumond committed
33
from pathlib import Path
34
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
Julien Chaumond's avatar
Julien Chaumond committed
35

36
37
from tqdm.auto import tqdm

Julien Chaumond's avatar
Julien Chaumond committed
38

39
# Integrations must be imported before ML frameworks:
40
41
# isort: off
from .integrations import (
42
    default_hp_search_backend,
43
    get_reporting_integration_callbacks,
44
    hp_params,
45
    is_fairscale_available,
46
    is_optuna_available,
47
    is_ray_tune_available,
48
    is_sigopt_available,
49
    is_wandb_available,
50
51
    run_hp_search_optuna,
    run_hp_search_ray,
52
    run_hp_search_sigopt,
53
    run_hp_search_wandb,
54
)
55

56
57
# isort: on

58
59
import numpy as np
import torch
Lai Wei's avatar
Lai Wei committed
60
import torch.distributed as dist
61
from huggingface_hub import Repository, create_repo
62
63
from packaging import version
from torch import nn
64
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
65
66
from torch.utils.data.distributed import DistributedSampler

67
68
from . import __version__
from .configuration_utils import PretrainedConfig
69
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
70
from .debug_utils import DebugOption, DebugUnderflowOverflow
71
from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
72
from .dependency_versions_check import dep_version_check
Sylvain Gugger's avatar
Sylvain Gugger committed
73
from .modelcard import TrainingSummary
74
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
75
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
76
from .optimization import Adafactor, get_scheduler
77
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
78
from .tokenization_utils_base import PreTrainedTokenizerBase
Sylvain Gugger's avatar
Sylvain Gugger committed
79
80
81
82
83
84
85
86
87
88
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
89
    DistributedLengthGroupedSampler,
90
    DistributedSamplerWithLoop,
91
    DistributedTensorGatherer,
92
    IterableDatasetShard,
Sylvain Gugger's avatar
Sylvain Gugger committed
93
    LabelSmoother,
94
    LengthGroupedSampler,
Sylvain Gugger's avatar
Sylvain Gugger committed
95
    SequentialDistributedSampler,
96
    ShardSampler,
Sylvain Gugger's avatar
Sylvain Gugger committed
97
98
    distributed_broadcast_scalars,
    distributed_concat,
99
    find_batch_size,
100
    get_module_class_from_name,
101
    get_parameter_names,
Sylvain Gugger's avatar
Sylvain Gugger committed
102
103
104
    nested_concat,
    nested_detach,
    nested_numpify,
105
    nested_truncate,
Sylvain Gugger's avatar
Sylvain Gugger committed
106
107
108
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
109
110
111
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
112
    EvalLoopOutput,
113
    EvalPrediction,
114
    FSDPOption,
115
    HPSearchBackend,
116
117
    HubStrategy,
    IntervalStrategy,
118
    PredictionOutput,
119
    RemoveColumnsCollator,
120
    ShardedDDPOption,
121
    TrainerMemoryTracker,
122
123
124
    TrainOutput,
    default_compute_objective,
    default_hp_space,
125
    denumpify_detensorize,
126
    enable_full_determinism,
127
    find_executable_batch_size,
128
    get_last_checkpoint,
129
    has_length,
130
    number_of_arguments,
131
    seed_worker,
132
    set_seed,
133
    speed_metrics,
134
)
135
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
136
137
from .utils import (
    CONFIG_NAME,
138
    WEIGHTS_INDEX_NAME,
139
    WEIGHTS_NAME,
140
    can_return_loss,
141
    find_labels,
142
    get_full_repo_name,
143
    is_accelerate_available,
144
145
146
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
147
    is_ipex_available,
148
149
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
150
    is_torch_compile_available,
151
    is_torch_neuroncore_available,
152
153
154
    is_torch_tpu_available,
    logging,
)
155
from .utils.generic import ContextManagers
Julien Chaumond's avatar
Julien Chaumond committed
156
157


158
# Alias of the torch>=1.10 flag; `Trainer.__init__` consults it before selecting the "cpu_amp" half precision
# backend.
_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10

# Callbacks every `Trainer` starts from; reporting-integration and user-supplied callbacks are appended to them.
DEFAULT_CALLBACKS = [DefaultFlowCallback]
# Progress callback added by `Trainer.__init__` unless `args.disable_tqdm` is set (then `PrinterCallback` is used).
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    # Inside a notebook, swap in the notebook-specific progress callback.
    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

# Optional backends: the names below are bound only when the corresponding `is_*_available()` check returns a
# truthy value, so code using them must be guarded by the same check.
if is_apex_available():
    from apex import amp

if is_datasets_available():
    import datasets

if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_fairscale_available():
    # NOTE(review): presumably validates the installed fairscale against the pinned requirement - confirm in
    # `dependency_versions_check`.
    dep_version_check("fairscale")
    import fairscale
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.nn.wrap import auto_wrap
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

# SageMaker Model Parallel support. `IS_SAGEMAKER_MP_POST_1_10` is always defined: it is `True` only when SMP is
# enabled and its version is at least 1.10. `Trainer.__init__` uses it to decide whether `args.fp16` should be
# reconciled with the fp16 setting of the SMP config.
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
    IS_SAGEMAKER_MP_POST_1_10 = False


# `skip_first_batches` defaults to `None` and is rebound to accelerate's implementation only when accelerate is
# available and its version is at least 0.16.
skip_first_batches = None
if is_accelerate_available():
    from accelerate import __version__ as accelerate_version

    # Only import the helper from accelerate releases that satisfy the 0.16 version gate.
    if version.parse(accelerate_version) >= version.parse("0.16"):
        from accelerate import skip_first_batches


208
209
210
if TYPE_CHECKING:
    # Only needed by static type checkers; this branch is not executed at runtime.
    import optuna

# Module-level logger obtained from the library's own `logging` utilities.
logger = logging.get_logger(__name__)


# File names used when checkpointing
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
TRAINER_STATE_NAME = "trainer_state.json"
TRAINING_ARGS_NAME = "training_args.bin"


Julien Chaumond's avatar
Julien Chaumond committed
222
223
class Trainer:
    """
Sylvain Gugger's avatar
Sylvain Gugger committed
224
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
225
226

    Args:
227
228
        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
Sylvain Gugger's avatar
Sylvain Gugger committed
229

230
            <Tip>
Sylvain Gugger's avatar
Sylvain Gugger committed
231

Sylvain Gugger's avatar
Sylvain Gugger committed
232
233
234
            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
            models.
235
236
237
238

            </Tip>

        args ([`TrainingArguments`], *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
239
240
            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
241
        data_collator (`DataCollator`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
242
243
            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
244
245
            [`DataCollatorWithPadding`] otherwise.
        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
246
            The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
247
248
            `model.forward()` method are automatically removed.

Sylvain Gugger's avatar
Sylvain Gugger committed
249
250
251
252
253
            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
            distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
            sets the seed of the RNGs used.
254
        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]]), *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
255
             The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
256
257
             `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
             dataset prepending the dictionary key to the metric name.
258
        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
259
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
260
261
            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
            interrupted training or reuse the fine-tuned model.
262
        model_init (`Callable[[], PreTrainedModel]`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
263
264
            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
            from a new instance of the model as given by this function.
265

266
267
268
            The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
            be able to choose different architectures according to hyper parameters (such as layer count, sizes of
            inner layers, dropout probabilities etc).
269
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
270
271
            The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
            a dictionary string to metric values.
272
        callbacks (List of [`TrainerCallback`], *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
273
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
274
            detailed in [here](callback).
Sylvain Gugger's avatar
Sylvain Gugger committed
275

276
277
            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
Sylvain Gugger's avatar
Sylvain Gugger committed
278
279
            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
280
281
282
283
284
285
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.
286

287
288
    Important attributes:

Sylvain Gugger's avatar
Sylvain Gugger committed
289
290
        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
          subclass.
291
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
292
          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
Sylvain Gugger's avatar
Sylvain Gugger committed
293
294
          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
295
296
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
297
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
298
299
          to `False` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False` .
Sylvain Gugger's avatar
Sylvain Gugger committed
300
301
        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
          in `train`)
302

Julien Chaumond's avatar
Julien Chaumond committed
303
304
    """

305
    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
306

Julien Chaumond's avatar
Julien Chaumond committed
307
308
    def __init__(
        self,
309
        model: Union[PreTrainedModel, nn.Module] = None,
310
        args: TrainingArguments = None,
Julien Chaumond's avatar
Julien Chaumond committed
311
312
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
313
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
314
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
315
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
Julien Chaumond's avatar
Julien Chaumond committed
316
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
Sylvain Gugger's avatar
Sylvain Gugger committed
317
        callbacks: Optional[List[TrainerCallback]] = None,
318
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
319
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
Julien Chaumond's avatar
Julien Chaumond committed
320
    ):
Sylvain Gugger's avatar
Sylvain Gugger committed
321
        if args is None:
322
323
324
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
Sylvain Gugger's avatar
Sylvain Gugger committed
325
326
        self.args = args
        # Seed must be set before instantiating the model when using `model_init`
327
        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
328
        self.hp_name = None
329
        self.deepspeed = None
330
        self.is_in_train = False
331

332
333
334
335
        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

336
        # set the correct log level depending on the node
337
        log_level = args.get_process_log_level()
338
339
        logging.set_verbosity(log_level)

340
341
342
        # force device and distributed setup init explicitly
        args._setup_devices

343
344
345
346
347
348
349
350
351
        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
Sylvain Gugger's avatar
Sylvain Gugger committed
352
353
354
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
                    " overwrite your model when calling the `train` method. This will become a fatal error in the next"
                    " release.",
355
356
357
                    FutureWarning,
                )
            self.model_init = model_init
358

359
360
361
362
363
364
365
366
        if model.__class__.__name__ in MODEL_MAPPING_NAMES:
            raise ValueError(
                f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
                "computes hidden states and does not accept any labels. You should choose a model with a head "
                "suitable for your task like any of the `AutoModelForXxx` listed at "
                "https://huggingface.co/docs/transformers/model_doc/auto."
            )

367
368
369
370
371
        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = False

372
373
        # At this stage the model is already loaded
        if getattr(model, "is_loaded_in_8bit", False):
374
375
376
377
378
379
380
381
382
383
384
385
            if getattr(model, "_is_int8_training_enabled", False):
                logger.info(
                    "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
                    " inside the model such as adapters using `peft` library and freeze the model weights. Please"
                    " check "
                    " the examples in https://github.com/huggingface/peft for more details."
                )
            else:
                raise ValueError(
                    "The model you want to train is loaded in 8-bit precision.  if you want to fine-tune an 8-bit"
                    " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
                )
386

387
388
389
390
391
392
393
        # Setup Sharded DDP training
        self.sharded_ddp = None
        if len(args.sharded_ddp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
394
395
396
397
            if len(args.fsdp) > 0:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
                )
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414

            if args.local_rank == -1:
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
                raise ImportError(
                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
                )
            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.SIMPLE
            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

415
416
417
418
419
420
        self.fsdp = None
        if len(args.fsdp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
421
            if not args.fsdp_config["xla"] and args.local_rank == -1:
422
423
                raise ValueError("Using fsdp only works in distributed training.")

424
425
426
            # dep_version_check("torch>=1.12.0")
            # Would have to update setup.py with torch>=1.12.0
            # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
427
            # below is the current alternative.
428
            if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
429
                raise ValueError("FSDP requires PyTorch >= 1.12.0")
430

431
            from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy
432
433
434
435
436

            if FSDPOption.FULL_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.FULL_SHARD
            elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
                self.fsdp = ShardingStrategy.SHARD_GRAD_OP
437
438
            elif FSDPOption.NO_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.NO_SHARD
439

440
441
442
443
444
            self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
            if "backward_prefetch" in self.args.fsdp_config and "backward_pos" not in self.backward_prefetch:
                self.backward_prefetch = BackwardPrefetch.BACKWARD_POST

            self.forword_prefetch = False
445
            if self.args.fsdp_config.get("forword_prefect", False):
446
447
                self.forword_prefetch = True

448
449
450
451
            self.limit_all_gathers = False
            if self.args.fsdp_config.get("limit_all_gathers", False):
                self.limit_all_gathers = True

452
        # one place to sort out whether to place the model on device or not
453
454
455
456
        # postpone switching model to cuda when:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        #    and we only use deepspeed for training at the moment
457
        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
458
        # 4. Sharded DDP - same as MP
459
        # 5. FSDP - same as MP
460
        self.place_model_on_device = args.place_model_on_device
461
462
        if (
            self.is_model_parallel
463
            or args.deepspeed
464
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
465
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
466
            or (self.fsdp is not None)
467
        ):
468
469
            self.place_model_on_device = False

470
471
        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
Julien Chaumond's avatar
Julien Chaumond committed
472
473
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
474
        self.tokenizer = tokenizer
475

476
        if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False):
Sylvain Gugger's avatar
Sylvain Gugger committed
477
            self._move_model_to_device(model, args.device)
Stas Bekman's avatar
Stas Bekman committed
478
479
480

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
        if self.is_model_parallel:
481
            self.args._n_gpu = 1
482
483
484
485
486

        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
        self.model_wrapped = model
        self.model = model

Julien Chaumond's avatar
Julien Chaumond committed
487
        self.compute_metrics = compute_metrics
488
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
489
        self.optimizer, self.lr_scheduler = optimizers
Sylvain Gugger's avatar
Sylvain Gugger committed
490
491
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
492
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
Sylvain Gugger's avatar
Sylvain Gugger committed
493
494
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
        if is_torch_tpu_available() and self.optimizer is not None:
            for param in self.model.parameters():
                model_device = param.device
                break
            for param_group in self.optimizer.param_groups:
                if len(param_group["params"]) > 0:
                    optimizer_device = param_group["params"][0].device
                    break
            if model_device != optimizer_device:
                raise ValueError(
                    "The model and the optimizer parameters are not on the same device, which probably means you"
                    " created an optimizer around your model **before** putting on the device and passing it to the"
                    " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                    " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
                )
510
        if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
511
512
513
            self.optimizer is not None or self.lr_scheduler is not None
        ):
            raise RuntimeError(
514
                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
515
516
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
517
518
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
519
520
521
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
522
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
Sylvain Gugger's avatar
Sylvain Gugger committed
523

524
525
526
        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

527
528
        # Create clone of distant repo and output directory if needed
        if self.args.push_to_hub:
529
            self.init_git_repo(at_init=True)
530
531
532
533
534
535
            # In case of pull, we need to make sure every process has the latest.
            if is_torch_tpu_available():
                xm.rendezvous("init git repo")
            elif args.local_rank != -1:
                dist.barrier()

536
        if self.args.should_save:
Julien Chaumond's avatar
Julien Chaumond committed
537
            os.makedirs(self.args.output_dir, exist_ok=True)
538

539
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
Sylvain Gugger's avatar
Sylvain Gugger committed
540
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
541

542
543
544
        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

545
        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
546
547
            raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")

548
549
550
551
552
553
554
        if (
            train_dataset is not None
            and isinstance(train_dataset, torch.utils.data.IterableDataset)
            and args.group_by_length
        ):
            raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")

555
        self._signature_columns = None
556

557
558
        # Mixed precision setup
        self.use_apex = False
559
560
        self.use_cuda_amp = False
        self.use_cpu_amp = False
561

562
563
564
565
566
        # Mixed precision setup for SageMaker Model Parallel
        if is_sagemaker_mp_enabled():
            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
            if args.bf16:
                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583

            if IS_SAGEMAKER_MP_POST_1_10:
                # When there's mismatch between SMP config and trainer argument, use SMP config as truth
                if args.fp16 != smp.state.cfg.fp16:
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
                        f"but FP16 provided in trainer argument is {args.fp16},"
                        f"setting to {smp.state.cfg.fp16}"
                    )
                    args.fp16 = smp.state.cfg.fp16
            else:
                # smp < 1.10 does not support fp16 in trainer.
                if hasattr(smp.state.cfg, "fp16"):
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
                    )
584

585
586
        if args.fp16 or args.bf16:
            if args.half_precision_backend == "auto":
587
588
589
590
591
592
593
                if args.device == torch.device("cpu"):
                    if args.fp16:
                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                    elif _is_native_cpu_amp_available:
                        args.half_precision_backend = "cpu_amp"
                    else:
                        raise ValueError("Tried to use cpu amp but native cpu amp is not available")
594
                else:
595
                    args.half_precision_backend = "cuda_amp"
596

597
            logger.info(f"Using {args.half_precision_backend} half precision backend")
598

599
        self.do_grad_scaling = False
600
        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()):
601
            # deepspeed and SageMaker Model Parallel manage their own half precision
602
603
            if args.half_precision_backend == "cuda_amp":
                self.use_cuda_amp = True
604
                self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
605
606
607
608
609
610
                #  bf16 does not need grad scaling
                self.do_grad_scaling = self.amp_dtype == torch.float16
                if self.do_grad_scaling:
                    if self.sharded_ddp is not None:
                        self.scaler = ShardedGradScaler()
                    elif self.fsdp is not None:
611
612
613
                        from torch.distributed.fsdp.sharded_grad_scaler import (
                            ShardedGradScaler as FSDPShardedGradScaler,
                        )
614

615
                        self.scaler = FSDPShardedGradScaler()
616
617
                    elif is_torch_tpu_available():
                        from torch_xla.amp import GradScaler
618

619
620
621
                        self.scaler = GradScaler()
                    else:
                        self.scaler = torch.cuda.amp.GradScaler()
622
623
624
            elif args.half_precision_backend == "cpu_amp":
                self.use_cpu_amp = True
                self.amp_dtype = torch.bfloat16
625
626
627
            else:
                if not is_apex_available():
                    raise ImportError(
Sylvain Gugger's avatar
Sylvain Gugger committed
628
629
                        "Using FP16 with APEX but APEX is not installed, please refer to"
                        " https://www.github.com/nvidia/apex."
630
631
632
                    )
                self.use_apex = True

633
634
635
636
637
638
639
640
641
642
643
644
        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
        if (
            is_sagemaker_mp_enabled()
            and self.use_cuda_amp
            and args.max_grad_norm is not None
            and args.max_grad_norm > 0
        ):
            raise ValueError(
                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
                "along 'max_grad_norm': 0 in your hyperparameters."
            )

Sylvain Gugger's avatar
Sylvain Gugger committed
645
646
647
648
649
650
        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

651
652
653
654
655
        self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
        )

Sylvain Gugger's avatar
Sylvain Gugger committed
656
        self.control = TrainerControl()
657
658
659
        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
660
        self.hp_search_backend = None
661
        self.use_tune_checkpoints = False
662
        default_label_names = find_labels(self.model.__class__)
663
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
664
        self.can_return_loss = can_return_loss(self.model.__class__)
Sylvain Gugger's avatar
Sylvain Gugger committed
665
666
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

667
668
669
        # Internal variables to keep track of the original batch size
        self._train_batch_size = args.train_batch_size

670
671
672
        # very last
        self._memory_tracker.stop_and_update_metrics()

673
674
        # torch.compile
        if args.torch_compile and not is_torch_compile_available():
675
            raise RuntimeError("Using torch.compile requires a nightly install of PyTorch.")
676

Sylvain Gugger's avatar
Sylvain Gugger committed
677
678
    def add_callback(self, callback):
        """
        Add a callback to the current list of [`~transformers.TrainerCallback`].

        The call is forwarded unchanged to `self.callback_handler.add_callback`.

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`].
               In the first case, will instantiate a member of that class.
        """
        self.callback_handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it.

        If the callback is not found, returns `None` (and no error is raised). The call is forwarded unchanged to
        `self.callback_handler.pop_callback`.

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`].
               In the first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            [`~transformers.TrainerCallback`]: The callback removed, if found.
        """
        return self.callback_handler.pop_callback(callback)

    def remove_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`].

        The call is forwarded unchanged to `self.callback_handler.remove_callback`; nothing is returned.

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`].
               In the first case, will remove the first member of that class found in the list of callbacks.
        """
        self.callback_handler.remove_callback(callback)
Julien Chaumond's avatar
Julien Chaumond committed
714

Sylvain Gugger's avatar
Sylvain Gugger committed
715
716
717
718
719
720
    def _move_model_to_device(self, model, device):
        """
        Move `model` to `device`, re-tying its weights when running in TPU parallel mode.

        Args:
            model: Object exposing `.to(device)` and, optionally, `tie_weights()`.
            device: Target device handed straight to `model.to`.
        """
        model = model.to(device)
        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
        needs_retie = self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights")
        if needs_retie:
            model.tie_weights()

721
    def _set_signature_columns_if_needed(self):
722
723
724
725
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
726
727
            # Labels may be named label or label_ids, the default data collator handles that.
            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
728

729
730
731
732
    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return dataset
        self._set_signature_columns_if_needed()
733
        signature_columns = self._signature_columns
734
735

        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
736
        if len(ignored_columns) > 0:
737
            dset_description = "" if description is None else f"in the {description} set"
738
739
740
            logger.info(
                f"The following columns {dset_description} don't have a corresponding argument in "
                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
741
                f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
742
                " you can safely ignore this message."
743
            )
744

745
        columns = [k for k in signature_columns if k in dataset.column_names]
746

747
748
749
750
751
752
753
        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            dataset.set_format(
                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
            )
            return dataset
        else:
            return dataset.remove_columns(ignored_columns)
754

755
756
757
758
759
760
761
    def _get_collator_with_removed_columns(
        self, data_collator: Callable, description: Optional[str] = None
    ) -> Callable:
        """Wrap the data collator in a callable removing unused columns."""
        if not self.args.remove_unused_columns:
            return data_collator
        self._set_signature_columns_if_needed()
762
        signature_columns = self._signature_columns
763
764
765
766
767
768
769
770
771
772

        remove_columns_collator = RemoveColumnsCollator(
            data_collator=data_collator,
            signature_columns=signature_columns,
            logger=logger,
            description=description,
            model_name=self.model.__class__.__name__,
        )
        return remove_columns_collator

773
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
774
        if self.train_dataset is None or not has_length(self.train_dataset):
775
            return None
776

777
        generator = None
778
        if self.args.world_size <= 1:
779
            generator = torch.Generator()
780
781
782
783
784
785
786
787
788
789
            # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
            # `args.seed`) if data_seed isn't provided.
            # Further on in this method, we default to `args.seed` instead.
            if self.args.data_seed is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
            else:
                seed = self.args.data_seed
            generator.manual_seed(seed)

        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
790

791
792
        # Build the sampler.
        if self.args.group_by_length:
793
794
795
796
797
798
799
800
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
801
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
802
            if self.args.world_size <= 1:
803
                return LengthGroupedSampler(
804
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
805
                    dataset=self.train_dataset,
806
807
808
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
809
                )
810
811
            else:
                return DistributedLengthGroupedSampler(
812
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
813
                    dataset=self.train_dataset,
814
815
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
816
                    lengths=lengths,
817
                    model_input_name=model_input_name,
818
                    seed=seed,
819
820
821
                )

        else:
822
            if self.args.world_size <= 1:
823
                return RandomSampler(self.train_dataset, generator=generator)
Sylvain Gugger's avatar
Sylvain Gugger committed
824
825
826
827
            elif (
                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
            ):
828
829
830
831
832
833
                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    batch_size=self.args.per_device_train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
834
                    seed=seed,
835
                )
836
            else:
837
                return DistributedSampler(
838
839
840
                    self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
841
                    seed=seed,
842
                )
843
844
845

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Raises:
            ValueError: If `self.train_dataset` is `None`.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        # `datasets.Dataset` objects get their columns pruned directly; anything else is pruned at collate time.
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        if isinstance(train_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                # Wrap the stream in a per-process shard when several processes are running.
                train_dataset = IterableDatasetShard(
                    train_dataset,
                    batch_size=self._train_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )

            return DataLoader(
                train_dataset,
                batch_size=self._train_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        train_sampler = self._get_train_sampler()

        return DataLoader(
            train_dataset,
            batch_size=self._train_batch_size,
            sampler=train_sampler,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            worker_init_fn=seed_worker,
        )

894
    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
        # Deprecated code
        if self.args.use_legacy_prediction_loop:
            if is_torch_tpu_available():
                return SequentialDistributedSampler(
                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
                )
            elif is_sagemaker_mp_enabled():
                return SequentialDistributedSampler(
                    eval_dataset,
                    num_replicas=smp.dp_size(),
                    rank=smp.dp_rank(),
                    batch_size=self.args.per_device_eval_batch_size,
                )
            elif self.args.local_rank != -1:
                return SequentialDistributedSampler(eval_dataset)
            else:
                return SequentialSampler(eval_dataset)

        if self.args.world_size <= 1:
            return SequentialSampler(eval_dataset)
        else:
            return ShardSampler(
Sylvain Gugger's avatar
Sylvain Gugger committed
917
918
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
919
920
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
Sylvain Gugger's avatar
Sylvain Gugger committed
921
            )
Lysandre Debut's avatar
Lysandre Debut committed
922

Julien Chaumond's avatar
Julien Chaumond committed
923
    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed.

        Raises:
            ValueError: If neither `eval_dataset` nor `self.eval_dataset` is set.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.data_collator

        # `datasets.Dataset` objects get their columns pruned directly; anything else is pruned at collate time.
        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                eval_dataset = IterableDatasetShard(
                    eval_dataset,
                    batch_size=self.args.per_device_eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                eval_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`):
                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed.
        """
        data_collator = self.data_collator

        # `datasets.Dataset` objects get their columns pruned directly; anything else is pruned at collate time.
        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description="test")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")

        if isinstance(test_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                test_dataset = IterableDatasetShard(
                    test_dataset,
                    batch_size=self.args.eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                test_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )
Lysandre Debut's avatar
Lysandre Debut committed
1020

1021
    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
        `create_scheduler`).

        Args:
            num_training_steps (`int`):
                The number of training steps, forwarded to `create_scheduler`.
        """
        self.create_optimizer()
        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
            # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
            optimizer = self.optimizer.optimizer
        else:
            optimizer = self.optimizer
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
1036
1037
1038
1039
1040

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method.

        Returns:
            The optimizer stored in `self.optimizer`. It is created here only when `self.optimizer` is `None`, and is
            wrapped in `smp.DistributedOptimizer` when SageMaker model parallelism is enabled.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            # Parameters returned by `get_parameter_names` whose name does not contain "bias" get weight decay.
            # A set keeps the per-parameter membership checks below O(1).
            decay_parameters = {
                name for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) if "bias" not in name
            }
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    # Register a 32-bit override for the `weight` of every embedding module and keep a running
                    # count of the affected parameters (keyed by `data_ptr()` so shared storage counts once per
                    # module).
                    skipped = 0
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                            logger.info(f"skipped {module}: {skipped/2**20}M params")
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                    logger.info(f"skipped: {skipped/2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer

1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
    @staticmethod
    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
        """
        Returns the optimizer class and optimizer parameters based on the training arguments.

        Args:
            args (`transformers.training_args.TrainingArguments`):
                The training arguments for the training session.

        Returns:
            `Tuple[Any, Any]`: The optimizer class selected by `args.optim` and the keyword arguments to build it
            with (always containing `lr`).

        Raises:
            ValueError: If `args.optim` is not one of the handled optimizers, or if the package providing the
            requested optimizer cannot be imported.
        """

        # parse args.optim_args: a comma-separated list of `key=value` pairs, spaces ignored
        optim_args = {}
        if args.optim_args:
            for mapping in args.optim_args.replace(" ", "").split(","):
                key, value = mapping.split("=")
                optim_args[key] = value

        optimizer_kwargs = {"lr": args.learning_rate}

        # Extra keyword arguments shared by all the Adam variants below.
        adam_kwargs = {
            "betas": (args.adam_beta1, args.adam_beta2),
            "eps": args.adam_epsilon,
        }
        if args.optim == OptimizerNames.ADAFACTOR:
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif args.optim == OptimizerNames.ADAMW_HF:
            from .optimization import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim == OptimizerNames.ADAMW_TORCH:
            from torch.optim import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
            try:
                from torch_xla.amp.syncfree import AdamW

                optimizer_cls = AdamW
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
            try:
                from apex.optimizers import FusedAdam

                optimizer_cls = FusedAdam
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
        elif args.optim == OptimizerNames.ADAMW_BNB:
            try:
                from bitsandbytes.optim import Adam8bit

                optimizer_cls = Adam8bit
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
        elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
            try:
                from torchdistx.optimizers import AnyPrecisionAdamW

                optimizer_cls = AnyPrecisionAdamW
                optimizer_kwargs.update(adam_kwargs)

                # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
                optimizer_kwargs.update(
                    {
                        "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
                        "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
                        "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
                        "compensation_buffer_dtype": getattr(
                            torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
                        ),
                    }
                )
            except ImportError:
                raise ValueError("Please install https://github.com/pytorch/torchdistx")
        elif args.optim == OptimizerNames.SGD:
            optimizer_cls = torch.optim.SGD
        elif args.optim == OptimizerNames.ADAGRAD:
            optimizer_cls = torch.optim.Adagrad
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs

1182
    def create_scheduler(self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
            optimizer (`torch.optim.Optimizer`, *optional*):
                The optimizer to schedule. Falls back to `self.optimizer` when not provided.

        Returns:
            The scheduler stored in `self.lr_scheduler`; a new one is only built when that attribute is `None`.
        """
        if self.lr_scheduler is None:
            self.lr_scheduler = get_scheduler(
                self.args.lr_scheduler_type,
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )
        return self.lr_scheduler
Julien Chaumond's avatar
Julien Chaumond committed
1198

1199
    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Return the number of samples behind a [`~torch.utils.data.DataLoader`] by looking at its dataset.

        If `dataloader.dataset` is missing or has no length, the count is estimated as the number of batches times
        `self.args.per_device_train_batch_size`.
        """
        try:
            source = dataloader.dataset
            if not isinstance(source, IterableDatasetShard):
                return len(dataloader.dataset)
            # An IterableDatasetShard wraps the real dataset, so measure the wrapped one.
            return len(dataloader.dataset.dataset)
        except (NameError, AttributeError, TypeError):
            # No dataset or no usable length: fall back to an estimate based on the dataloader length.
            return len(dataloader) * self.args.per_device_train_batch_size
1212

1213
    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """
        Apply the hyperparameters of `trial` to `self.args` before a hyperparameter-search run.

        The trial is always stored on `self._trial`. Nothing else happens when no search backend is configured or
        `trial` is `None`. Otherwise the parameters are extracted in a backend-specific way, written onto the matching
        fields of `self.args`, and logged. When DeepSpeed is enabled, its config object is rebuilt from the updated
        arguments.

        Args:
            trial: The backend trial object or the hyperparameter dictionary for this run.
        """
        self._trial = trial

        backend = self.hp_search_backend
        if backend is None or trial is None:
            return

        # Each backend exposes the sampled hyperparameters differently.
        if backend == HPSearchBackend.OPTUNA:
            search_params = self.hp_space(trial)
        elif backend == HPSearchBackend.RAY:
            search_params = trial
            # Drop the "wandb" entry (if any) so it is not treated as a training argument.
            search_params.pop("wandb", None)
        elif backend == HPSearchBackend.SIGOPT:
            search_params = {}
            for name, raw in trial.assignments.items():
                # String assignments are converted to int; other values pass through untouched.
                search_params[name] = int(raw) if isinstance(raw, str) else raw
        elif backend == HPSearchBackend.WANDB:
            search_params = trial

        for key, value in search_params.items():
            if hasattr(self.args, key):
                current = getattr(self.args, key, None)
                # Cast the new value to the type of the current one, unless the current one is None.
                if current is not None:
                    value = type(current)(value)
                setattr(self.args, key, value)
            else:
                logger.warning(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
                    " `TrainingArguments`."
                )

        if backend == HPSearchBackend.OPTUNA:
            logger.info(f"Trial: {trial.params}")
        elif backend == HPSearchBackend.SIGOPT:
            logger.info(f"SigOpt Assignments: {trial.assignments}")
        elif backend == HPSearchBackend.WANDB:
            logger.info(f"W&B Sweep parameters: {trial}")

        if not self.args.deepspeed:
            return

        # Rebuild the deepspeed config to reflect the updated training parameters
        from transformers.deepspeed import HfTrainerDeepSpeedConfig

        self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
        self.args.hf_deepspeed_config.trainer_config_process(self.args)
1253

1254
    def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
        """
        Compute the search objective from `metrics` and report it to the active hyperparameter-search backend.

        Does nothing when no backend is configured or `trial` is `None`. With Optuna, the objective is reported for
        `step`; if the trial should be pruned, the `on_train_end` callbacks run and `optuna.TrialPruned` is raised.
        With Ray Tune, a checkpoint is saved first when `self.control.should_save` is set, then the objective and
        metrics are reported.

        Args:
            trial: The backend trial object or the hyperparameter dictionary for this run.
            step (int): The step the metrics belong to.
            metrics (Dict[str, float]): The metrics to derive the objective from.
        """
        backend = self.hp_search_backend
        if backend is None or trial is None:
            return

        # compute_objective receives a copy, so the caller's metrics dict is not handed over directly.
        self.objective = self.compute_objective(metrics.copy())

        if backend == HPSearchBackend.OPTUNA:
            import optuna

            trial.report(self.objective, step)
            if not trial.should_prune():
                return
            self.callback_handler.on_train_end(self.args, self.state, self.control)
            raise optuna.TrialPruned()

        if backend == HPSearchBackend.RAY:
            from ray import tune

            if self.control.should_save:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

1272
    def _tune_save_checkpoint(self):
        """
        Save a checkpoint inside the directory provided by Ray Tune for the current global step.

        Does nothing (beyond importing `ray.tune`) when `self.use_tune_checkpoints` is falsy. The model is always
        saved; the trainer state, optimizer state and scheduler state are written only when `self.args.should_save`
        is set.
        """
        from ray import tune

        if self.use_tune_checkpoints:
            with tune.checkpoint_dir(step=self.state.global_step) as tune_dir:
                target_dir = os.path.join(tune_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
                self.save_model(target_dir, _internal_call=True)
                if not self.args.should_save:
                    return
                self.state.save_to_json(os.path.join(target_dir, TRAINER_STATE_NAME))
                for component, filename in ((self.optimizer, OPTIMIZER_NAME), (self.lr_scheduler, SCHEDULER_NAME)):
                    torch.save(component.state_dict(), os.path.join(target_dir, filename))
1284

1285
    def call_model_init(self, trial=None):
        """
        Instantiate a model through `self.model_init`.

        `self.model_init` is called without arguments when it takes none, and with `trial` when it takes exactly one.

        Args:
            trial: The current hyperparameter-search trial, forwarded to a one-argument `model_init`.

        Returns:
            The model returned by `self.model_init`.

        Raises:
            RuntimeError: If `model_init` takes more than one argument or returns `None`.
        """
        n_args = number_of_arguments(self.model_init)
        if n_args not in (0, 1):
            raise RuntimeError("model_init should have 0 or 1 argument.")

        model = self.model_init(trial) if n_args == 1 else self.model_init()
        if model is None:
            raise RuntimeError("model_init should not return None.")
        return model

1299
1300
1301
1302
1303
1304
    def torch_jit_model_eval(self, model, dataloader, training=False):
        """
        Try to replace `model` with a traced and frozen TorchScript version for evaluation.

        The model is returned unchanged when `training` is truthy, when `dataloader` is `None`, or when tracing fails
        with one of the caught exception types (a warning is logged in the last two cases). On success, the AMP flags
        `self.use_cpu_amp` and `self.use_cuda_amp` are switched off.

        Args:
            model: The model to trace.
            dataloader: Source of the example batch used for tracing; its first batch is consumed.
            training (bool): Tracing is only attempted when this is falsy.

        Returns:
            The traced model on success, otherwise the original `model`.
        """
        if training:
            return model

        if dataloader is None:
            logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
            return model

        example_batch = self._prepare_inputs(next(iter(dataloader)))
        try:
            jit_model = model.eval()
            with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]):
                if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"):
                    # Trace with keyword inputs; non-dict batches are first turned into a plain dict.
                    if isinstance(example_batch, dict):
                        kwarg_inputs = example_batch
                    else:
                        kwarg_inputs = {key: example_batch[key] for key in example_batch}
                    jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=kwarg_inputs, strict=False)
                else:
                    # Trace with positional inputs built from all-ones tensors shaped like the batch entries.
                    jit_inputs = tuple([torch.ones_like(example_batch[key]) for key in example_batch])
                    jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
            jit_model = torch.jit.freeze(jit_model)
            # Run the frozen model twice on the example batch before handing it back.
            with torch.no_grad():
                for _ in range(2):
                    jit_model(**example_batch)
            model = jit_model
            self.use_cpu_amp = False
            self.use_cuda_amp = False
        except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
            logger.warning(f"failed to use PyTorch jit mode due to: {e}.")

        return model

1337
1338
1339
    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
        """
        Run `ipex.optimize` on `model`, for either training or evaluation.

        In training mode the model is put in train mode if needed and optimized in place together with
        `self.optimizer`, which is replaced by the optimizer `ipex.optimize` returns. In evaluation mode the model is
        put in eval mode; `dtype` is overridden with `torch.bfloat16` when the trainer is not inside `train()` and
        `self.args.bf16_full_eval` is set.

        Args:
            model: The model to optimize.
            training (bool): Whether the model is optimized for training.
            dtype: The dtype passed to `ipex.optimize`.

        Returns:
            The optimized model.

        Raises:
            ImportError: If `is_ipex_available()` reports that IPEX cannot be used.
        """
        if not is_ipex_available():
            raise ImportError(
                "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
                " to https://github.com/intel/intel-extension-for-pytorch."
            )

        import intel_extension_for_pytorch as ipex

        if training:
            if not model.training:
                model.train()
            optimized = ipex.optimize(model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1")
            model, self.optimizer = optimized
            return model

        model.eval()
        if not self.is_in_train and self.args.bf16_full_eval:
            dtype = torch.bfloat16
        # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings
        model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
        return model

1360
    def _wrap_model(self, model, training=True, dataloader=None):
        """
        Apply the wrappers requested by the current configuration to `model` and return the result.

        The steps run in this order: `torch.compile`, IPEX optimization, SageMaker model parallelism (returns early),
        DeepSpeed (returns early), already-wrapped check (returns early), apex AMP initialization (training only),
        `nn.DataParallel` when `args.n_gpu > 1`, and TorchScript tracing when `args.jit_mode_eval` is set. When
        `training` is falsy the model is returned at that point. Otherwise exactly one distributed wrapper is chosen:
        fairscale sharded DDP, FSDP (PyTorch or XLA), SageMaker data parallelism, or `DistributedDataParallel`.

        Besides returning the wrapped model, this method may reassign `self.optimizer` (apex), `self.model` (fully
        sharded wrappers), set `self.jit_compilation_time`, and replace `xm.optimizer_step` (XLA FSDP).

        Args:
            model: The model to wrap.
            training (bool): Whether the model is being wrapped for training.
            dataloader: Forwarded to `torch_jit_model_eval` when `args.jit_mode_eval` is set.

        Returns:
            The wrapped model, or an already-existing wrapper (`self.model_wrapped` / `self.deepspeed`).
        """
        if self.args.torch_compile:
            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)

        if self.args.use_ipex:
            # bfloat16 is only requested from IPEX when CPU AMP is in use.
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # already initialized its own DDP and AMP
        if self.deepspeed:
            return self.deepspeed

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = nn.DataParallel(model)

        if self.args.jit_mode_eval:
            # Time the tracing attempt; the duration (in seconds, 4 decimals) is kept on the trainer.
            start_time = time.time()
            model = self.torch_jit_model_eval(model, dataloader, training)
            self.jit_compilation_time = round(time.time() - start_time, 4)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)
        # Distributed training using PyTorch FSDP
        elif self.fsdp is not None:
            if not self.args.fsdp_config["xla"]:
                # PyTorch FSDP!
                from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
                from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
                from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

                if FSDPOption.OFFLOAD in self.args.fsdp:
                    cpu_offload = CPUOffload(offload_params=True)
                else:
                    cpu_offload = CPUOffload(offload_params=False)

                auto_wrap_policy = None

                # With auto-wrap enabled, a size-based policy takes precedence over a layer-class based one.
                if FSDPOption.AUTO_WRAP in self.args.fsdp:
                    if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                        auto_wrap_policy = functools.partial(
                            size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                        )
                    elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                        # Resolve every configured layer class name to the class found in the model.
                        transformer_cls_to_wrap = set()
                        for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                            transformer_cls = get_module_class_from_name(model, layer_class)
                            if transformer_cls is None:
                                raise Exception("Could not find the transformer layer class to wrap in the model.")
                            else:
                                transformer_cls_to_wrap.add(transformer_cls)
                        auto_wrap_policy = functools.partial(
                            transformer_auto_wrap_policy,
                            # Transformer layer class to wrap
                            transformer_layer_cls=transformer_cls_to_wrap,
                        )
                # A mixed precision policy is only built when fp16 or bf16 is requested; fp16 wins if both are set.
                mixed_precision_policy = None
                dtype = None
                if self.args.fp16:
                    dtype = torch.float16
                elif self.args.bf16:
                    dtype = torch.bfloat16
                if dtype is not None:
                    mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
                # Exact type check: only a model whose type is FSDP itself is left unwrapped.
                if type(model) != FSDP:
                    # XXX: Breaking the self.model convention but I see no way around it for now.
                    # NOTE(review): the attribute is spelled `forword_prefetch`; presumably this matches the name
                    # assigned elsewhere in the class - confirm before renaming.
                    self.model = model = FSDP(
                        model,
                        sharding_strategy=self.fsdp,
                        cpu_offload=cpu_offload,
                        auto_wrap_policy=auto_wrap_policy,
                        mixed_precision=mixed_precision_policy,
                        device_id=self.args.device,
                        backward_prefetch=self.backward_prefetch,
                        forward_prefetch=self.forword_prefetch,
                        limit_all_gathers=self.limit_all_gathers,
                    )
            else:
                # XLA FSDP
                try:
                    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                    from torch_xla.distributed.fsdp import checkpoint_module
                    from torch_xla.distributed.fsdp.wrap import (
                        size_based_auto_wrap_policy,
                        transformer_auto_wrap_policy,
                    )
                except ImportError:
                    raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
                auto_wrap_policy = None
                auto_wrapper_callable = None
                # Unlike the PyTorch FSDP branch, the policy here does not depend on FSDPOption.AUTO_WRAP.
                if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                    auto_wrap_policy = functools.partial(
                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                    )
                elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                    transformer_cls_to_wrap = set()
                    for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                        transformer_cls = get_module_class_from_name(model, layer_class)
                        if transformer_cls is None:
                            raise Exception("Could not find the transformer layer class to wrap in the model.")
                        else:
                            transformer_cls_to_wrap.add(transformer_cls)
                    auto_wrap_policy = functools.partial(
                        transformer_auto_wrap_policy,
                        # Transformer layer class to wrap
                        transformer_layer_cls=transformer_cls_to_wrap,
                    )
                fsdp_kwargs = self.args.xla_fsdp_config
                if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                    # Apply gradient checkpointing to auto-wrapped sub-modules if specified
                    def auto_wrapper_callable(m, *args, **kwargs):
                        return FSDP(checkpoint_module(m), *args, **kwargs)

                # Wrap the base model with an outer FSDP wrapper
                self.model = model = FSDP(
                    model,
                    auto_wrap_policy=auto_wrap_policy,
                    auto_wrapper_callable=auto_wrapper_callable,
                    **fsdp_kwargs,
                )

                # Patch `xm.optimizer_step` should not reduce gradients in this case,
                # as FSDP does not need gradient reduction over sharded parameters.
                # The default `optimizer_args` dict is only unpacked, never mutated, inside the function.
                def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
                    loss = optimizer.step(**optimizer_args)
                    if barrier:
                        xm.mark_step()
                    return loss

                # This replaces the function on the `xm` module itself, so it affects every later caller.
                xm.optimizer_step = patched_optimizer_step
        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.local_rank != -1:
            kwargs = {}
            # An explicit `ddp_find_unused_parameters` setting always wins over the defaults below.
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
            # When neuron cores are available the model is returned without a DDP wrapper.
            if is_torch_neuroncore_available():
                return model
            # Device ids are only passed when `args._n_gpu` is non-zero.
            model = nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
                **kwargs,
            )

        return model

1552
1553
    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Main training entry point.

        Args:
            resume_from_checkpoint (`str` or `bool`, *optional*):
                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
            ignore_keys_for_eval (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions for evaluation during the training.
            kwargs:
                Additional keyword arguments used to hide deprecated arguments. Only the deprecated `model_path` is
                accepted; it is used as `resume_from_checkpoint`.

        Returns:
            Whatever the inner training loop returns.

        Raises:
            TypeError: If any keyword argument other than the deprecated `model_path` is passed.
            ValueError: If `resume_from_checkpoint` is `True` and no checkpoint is found in `args.output_dir`.
        """
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        args = self.args

        self.is_in_train = True

        # do_train is not a reliable argument, as it might not be set and .train() still called, so
        # the following is a workaround:
        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
            self._move_model_to_device(self.model, args.device)

        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
                "instead.",
                FutureWarning,
            )
        if kwargs:
            raise TypeError(f"train() received unexpected keyword arguments: {', '.join(kwargs)}.")

        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)
        self._train_batch_size = self.args.train_batch_size

        # Model re-init
        model_reloaded = False
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            if self.args.full_determinism:
                enable_full_determinism(self.args.seed)
            else:
                set_seed(self.args.seed)
            self.model = self.call_model_init(trial)
            model_reloaded = True
            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Load potential model checkpoint
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

        # With SageMaker MP or DeepSpeed the checkpoint is not loaded here.
        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and args.deepspeed is None:
            self._load_from_checkpoint(resume_from_checkpoint)

        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
            if self.place_model_on_device:
                self._move_model_to_device(self.model, args.device)
            self.model_wrapped = self.model

        # The inner loop is wrapped so that the batch size can be searched for when
        # `args.auto_find_batch_size` is set.
        inner_training_loop = find_executable_batch_size(
            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
        )
        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        self._train_batch_size = batch_size
1642
        # Data loader and number of training steps
Julien Chaumond's avatar
Julien Chaumond committed
1643
        train_dataloader = self.get_train_dataloader()
1644
1645
1646
1647
1648

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
1649
        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
1650
1651
1652
1653
1654

        len_dataloader = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
1655
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
1656
            num_examples = self.num_examples(train_dataloader)
1657
1658
1659
1660
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
1661
                )
1662
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
1663
1664
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
1665
            else:
1666
1667
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
1668
1669
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
1670
            max_steps = args.max_steps
1671
1672
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
1673
            num_update_steps_per_epoch = max_steps
1674
            num_examples = total_train_batch_size * args.max_steps
1675
            num_train_samples = args.max_steps * total_train_batch_size
1676
1677
        else:
            raise ValueError(
Sylvain Gugger's avatar
Sylvain Gugger committed
1678
1679
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
1680
            )
Julien Chaumond's avatar
Julien Chaumond committed
1681

1682
        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
1683
1684
1685
1686
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
Sylvain Gugger's avatar
Sylvain Gugger committed
1687
1688
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torch.distributed.launch)."
1689
1690
1691
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
1692

1693
        delay_optimizer_creation = (
1694
1695
1696
1697
            self.sharded_ddp is not None
            and self.sharded_ddp != ShardedDDPOption.SIMPLE
            or is_sagemaker_mp_enabled()
            or self.fsdp is not None
1698
        )
1699
        if args.deepspeed:
1700
            deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
1701
1702
                self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
            )
1703
1704
1705
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
1706
1707
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler
1708
        elif not delay_optimizer_creation:
1709
1710
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

1711
        self.state = TrainerState()
1712
        self.state.is_hyper_param_search = trial is not None
Julien Chaumond's avatar
Julien Chaumond committed
1713

1714
1715
1716
1717
        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

1718
        model = self._wrap_model(self.model_wrapped)
Julien Chaumond's avatar
Julien Chaumond committed
1719

1720
1721
1722
        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint, model)

1723
1724
1725
1726
        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

1727
1728
1729
        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

1730
1731
1732
        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

1733
1734
        # important: at this point:
        # self.model         is the Transformers Model
1735
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
1736

Julien Chaumond's avatar
Julien Chaumond committed
1737
1738
        # Train!
        logger.info("***** Running training *****")
1739
1740
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
1741
        logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
1742
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
1743
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1744
        logger.info(f"  Total optimization steps = {max_steps}")
1745
1746
1747
        logger.info(
            f"  Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
        )
Julien Chaumond's avatar
Julien Chaumond committed
1748

1749
        self.state.epoch = 0
1750
        start_time = time.time()
Julien Chaumond's avatar
Julien Chaumond committed
1751
1752
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
1753
        steps_trained_progress_bar = None
1754

Julien Chaumond's avatar
Julien Chaumond committed
1755
        # Check if continuing training from a checkpoint
1756
        if resume_from_checkpoint is not None and os.path.isfile(
1757
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
1758
        ):
1759
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
1760
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
1761
            if not args.ignore_data_skip:
1762
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
1763
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
1764
1765
            else:
                steps_trained_in_current_epoch = 0
1766
1767

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
1768
1769
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
1770
            if not args.ignore_data_skip:
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
                if skip_first_batches is None:
                    logger.info(
                        f"  Will skip the first {epochs_trained} epochs then the first"
                        f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
                        " you can install the latest version of Accelerate with `pip install -U accelerate`.You can"
                        " also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
                        " training on data already seen by your model."
                    )
                else:
                    logger.info(
                        f"  Will skip the first {epochs_trained} epochs then the first"
                        f" {steps_trained_in_current_epoch} batches in the first epoch."
                    )
                if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
1785
1786
                    steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                    steps_trained_progress_bar.set_description("Skipping the first batches")
1787

Sylvain Gugger's avatar
Sylvain Gugger committed
1788
1789
1790
1791
1792
        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
1793
1794
1795
1796
        if self.hp_name is not None and self._trial is not None:
            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
            # parameter to Train when using DDP.
            self.state.trial_name = self.hp_name(self._trial)
1797
1798
1799
1800
1801
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
1802
1803
1804
1805
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
Sylvain Gugger's avatar
Sylvain Gugger committed
1806
1807
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()
Julien Chaumond's avatar
Julien Chaumond committed
1808

1809
        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
1810
        tr_loss = torch.tensor(0.0).to(args.device)
1811
1812
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
1813
        self._globalstep_last_logged = self.state.global_step
Julien Chaumond's avatar
Julien Chaumond committed
1814
        model.zero_grad()
Sylvain Gugger's avatar
Sylvain Gugger committed
1815

1816
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1817

1818
        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
1819
        if not args.ignore_data_skip:
1820
            for epoch in range(epochs_trained):
1821
1822
1823
                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
                    train_dataloader.sampler, RandomSampler
                )
1824
                if is_torch_less_than_1_11 or not is_random_sampler:
1825
1826
1827
1828
1829
1830
1831
1832
                    # We just need to begin an iteration to create the randomization of the sampler.
                    # That was before PyTorch 1.11 however...
                    for _ in train_dataloader:
                        break
                else:
                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
                    # AT THE VERY END!
                    _ = list(train_dataloader.sampler)
1833

1834
        for epoch in range(epochs_trained, num_train_epochs):
1835
1836
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
1837
            elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
1838
                train_dataloader.dataset.set_epoch(epoch)
1839

1840
            if is_torch_tpu_available():
1841
                parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
1842
                epoch_iterator = parallel_loader
1843
            else:
1844
                epoch_iterator = train_dataloader
1845

1846
            # Reset the past mems state at the beginning of each epoch if necessary.
1847
            if args.past_index >= 0:
1848
1849
                self._past = None

1850
            steps_in_epoch = (
1851
1852
1853
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
1854
            )
1855
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1856

1857
1858
1859
            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

1860
            rng_to_sync = False
1861
            steps_skipped = 0
1862
1863
            if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
1864
                steps_skipped = steps_trained_in_current_epoch
1865
1866
1867
                steps_trained_in_current_epoch = 0
                rng_to_sync = True

1868
            step = -1
Julien Chaumond's avatar
Julien Chaumond committed
1869
            for step, inputs in enumerate(epoch_iterator):
1870
1871
1872
                if rng_to_sync:
                    self._load_rng_state(resume_from_checkpoint)
                    rng_to_sync = False
Julien Chaumond's avatar
Julien Chaumond committed
1873
1874
1875
1876

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
1877
1878
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
1879
1880
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
Julien Chaumond's avatar
Julien Chaumond committed
1881
                    continue
1882
1883
1884
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None
Julien Chaumond's avatar
Julien Chaumond committed
1885

1886
1887
                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1888

1889
                if (
1890
1891
1892
                    ((step + 1) % args.gradient_accumulation_steps != 0)
                    and args.local_rank != -1
                    and args._no_sync_in_gradient_accumulation
1893
                ):
1894
                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
1895
                    with model.no_sync():
1896
                        tr_loss_step = self.training_step(model, inputs)
1897
                else:
1898
1899
                    tr_loss_step = self.training_step(model, inputs)

1900
1901
1902
1903
1904
1905
1906
                if (
                    args.logging_nan_inf_filter
                    and not is_torch_tpu_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
1907
1908
1909
                else:
                    tr_loss += tr_loss_step

1910
                self.current_flos += float(self.floating_point_ops(inputs))
Julien Chaumond's avatar
Julien Chaumond committed
1911

1912
1913
1914
1915
                # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
                if self.deepspeed:
                    self.deepspeed.step()

1916
                if (step + 1) % args.gradient_accumulation_steps == 0 or (
Julien Chaumond's avatar
Julien Chaumond committed
1917
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
1918
                    steps_in_epoch <= args.gradient_accumulation_steps
1919
                    and (step + 1) == steps_in_epoch
Julien Chaumond's avatar
Julien Chaumond committed
1920
                ):
1921
                    # Gradient clipping
1922
                    if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
1923
1924
                        # deepspeed does its own clipping

1925
                        if self.do_grad_scaling:
1926
1927
1928
1929
                            # Reduce gradients first for XLA
                            if is_torch_tpu_available():
                                gradients = xm._fetch_gradients(self.optimizer)
                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
1930
1931
1932
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)

1933
1934
1935
                        if is_sagemaker_mp_enabled() and args.fp16:
                            self.optimizer.clip_master_grads(args.max_grad_norm)
                        elif hasattr(self.optimizer, "clip_grad_norm"):
1936
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
1937
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
1938
1939
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
1940
                            model.clip_grad_norm_(args.max_grad_norm)
1941
1942
                        else:
                            # Revert to normal clipping otherwise, handling Apex or full precision
1943
                            nn.utils.clip_grad_norm_(
1944
                                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
1945
                                args.max_grad_norm,
1946
1947
1948
                            )

                    # Optimizer step
1949
                    optimizer_was_run = True
Stas Bekman's avatar
Stas Bekman committed
1950
                    if self.deepspeed:
1951
                        pass  # called outside the loop
Stas Bekman's avatar
Stas Bekman committed
1952
                    elif is_torch_tpu_available():
1953
1954
1955
1956
1957
                        if self.do_grad_scaling:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            xm.optimizer_step(self.optimizer)
1958
                    elif self.do_grad_scaling:
1959
                        scale_before = self.scaler.get_scale()
1960
                        self.scaler.step(self.optimizer)
1961
                        self.scaler.update()
1962
1963
                        scale_after = self.scaler.get_scale()
                        optimizer_was_run = scale_before <= scale_after
Lysandre Debut's avatar
Lysandre Debut committed
1964
                    else:
1965
                        self.optimizer.step()
Lysandre Debut's avatar
Lysandre Debut committed
1966

1967
                    if optimizer_was_run and not self.deepspeed:
1968
1969
                        self.lr_scheduler.step()

Julien Chaumond's avatar
Julien Chaumond committed
1970
                    model.zero_grad()
1971
                    self.state.global_step += 1
1972
                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
1973
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1974

1975
                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
wulu473's avatar
wulu473 committed
1976
1977
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
Julien Chaumond's avatar
Julien Chaumond committed
1978

Sylvain Gugger's avatar
Sylvain Gugger committed
1979
                if self.control.should_epoch_stop or self.control.should_training_stop:
Julien Chaumond's avatar
Julien Chaumond committed
1980
                    break
1981
1982
            if step < 0:
                logger.warning(
Sylvain Gugger's avatar
Sylvain Gugger committed
1983
                    "There seems to be not a single sample in your epoch_iterator, stopping training at step"
1984
1985
1986
1987
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True
1988

1989
            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
1990
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
1991

1992
            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
1993
1994
1995
1996
1997
1998
1999
2000
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
Sylvain Gugger's avatar
Sylvain Gugger committed
2001
            if self.control.should_training_stop:
2002
                break
Julien Chaumond's avatar
Julien Chaumond committed
2003

2004
        if args.past_index and hasattr(self, "_past"):
2005
2006
            # Clean the state at the end of training
            delattr(self, "_past")
Julien Chaumond's avatar
Julien Chaumond committed
2007
2008

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
2009
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
2010
2011
2012
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
2013
            elif args.local_rank != -1:
2014
                dist.barrier()
2015
2016
            elif is_sagemaker_mp_enabled():
                smp.barrier()
2017

2018
            self._load_best_model()
2019

2020
2021
2022
2023
        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        train_loss = self._total_loss_scalar / self.state.global_step

2024
        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
2025
2026
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
2027
        metrics["train_loss"] = train_loss
Sylvain Gugger's avatar
Sylvain Gugger committed
2028

2029
        self.is_in_train = False
2030

2031
2032
        self._memory_tracker.stop_and_update_metrics(metrics)

2033
2034
        self.log(metrics)

raghavanone's avatar
raghavanone committed
2035
2036
2037
        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

2038
2039
        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
raghavanone's avatar
raghavanone committed
2040
2041
2042
2043
2044
            for checkpoint in checkpoints_sorted:
                if checkpoint != self.state.best_model_checkpoint:
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint)

2045
2046
2047
        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        return TrainOutput(self.state.global_step, train_loss, metrics)
2048

raghavanone's avatar
raghavanone committed
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
    def _get_output_dir(self, trial):
        if self.hp_search_backend is not None and trial is not None:
            if self.hp_search_backend == HPSearchBackend.OPTUNA:
                run_id = trial.number
            elif self.hp_search_backend == HPSearchBackend.RAY:
                from ray import tune

                run_id = tune.get_trial_id()
            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
                run_id = trial.id
            elif self.hp_search_backend == HPSearchBackend.WANDB:
                import wandb

                run_id = wandb.run.id
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            run_dir = os.path.join(self.args.output_dir, run_name)
        else:
            run_dir = self.args.output_dir
        return run_dir

2069
2070
2071
2072
    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        """
        Load the model weights stored in the `resume_from_checkpoint` folder into `model`.

        Args:
            resume_from_checkpoint: Folder holding either a full weights file (`WEIGHTS_NAME`) or the index of a
                sharded checkpoint (`WEIGHTS_INDEX_NAME`).
            model: Module that receives the weights. Defaults to `self.model`.

        Raises:
            ValueError: If neither of the two files is present in `resume_from_checkpoint`.
        """
        if model is None:
            model = self.model

        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
        has_full_weights = os.path.isfile(weights_file)
        if not has_full_weights and not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

        logger.info(f"Loading model from {resume_from_checkpoint}.")

        # Warn when the checkpoint was written by a different version of the library.
        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
        if os.path.isfile(config_file):
            config = PretrainedConfig.from_json_file(config_file)
            checkpoint_version = config.transformers_version
            if checkpoint_version is not None and checkpoint_version != __version__:
                logger.warning(
                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                    f"Transformers but your current version is {__version__}. This is not recommended and could "
                    "yield to errors or unwanted behaviors."
                )

        if not has_full_weights:
            # Only the index file exists: we load the sharded checkpoint.
            load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled())
            if not is_sagemaker_mp_enabled():
                self._issue_warnings_after_load(load_result)
            return

        if not is_sagemaker_mp_enabled():
            # We load the model state dict on the CPU to avoid an OOM error.
            # If the model is on the GPU, it still works!
            state_dict = torch.load(weights_file, map_location="cpu")
            # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
            # which takes *args instead of **kwargs
            load_result = model.load_state_dict(state_dict, False)
            # release memory
            del state_dict
            self._issue_warnings_after_load(load_result)
            return

        if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
            # If the 'user_content.pt' file exists, load with the new smp api.
            # Checkpoint must have been saved with the new smp api.
            smp.resume_from_checkpoint(
                path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
            )
            return

        # If the 'user_content.pt' file does NOT exist, load with the old smp api.
        # Checkpoint must have been saved with the old smp api.
        if hasattr(self.args, "fp16") and self.args.fp16 is True:
            logger.warning("Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported.")
        state_dict = torch.load(weights_file, map_location="cpu")
        # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
        state_dict["_smp_is_partial"] = False
        model.load_state_dict(state_dict, strict=True)
        # release memory
        del state_dict
2126
2127
2128
2129

    def _load_best_model(self):
        """
        Reload the weights of `self.state.best_model_checkpoint` into the model being trained.

        A full weights file (`WEIGHTS_NAME`) is looked for first, then a sharded checkpoint index
        (`WEIGHTS_INDEX_NAME`); a warning is logged when neither exists. With DeepSpeed and a full weights file, the
        engine, optimizer and scheduler are rebuilt from the checkpoint instead of loading a state dict.
        """
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if not os.path.exists(best_model_path):
            # No full weights file: fall back to a sharded checkpoint, or warn if nothing can be found.
            if os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
                load_result = load_sharded_checkpoint(
                    model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
                )
                if not is_sagemaker_mp_enabled():
                    self._issue_warnings_after_load(load_result)
            else:
                logger.warning(
                    f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
                    "on multiple nodes, you should activate `--save_on_each_node`."
                )
            return

        if self.deepspeed:
            if self.model_wrapped is not None:
                # this removes the pre-hooks from the previous engine
                self.model_wrapped.destroy()
                self.model_wrapped = None

            # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
            engine, optimizer, lr_scheduler = deepspeed_init(
                self,
                num_training_steps=self.args.max_steps,
                resume_from_checkpoint=self.state.best_model_checkpoint,
            )
            self.model = engine.module
            self.model_wrapped = engine
            self.deepspeed = engine
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler
            return

        if is_sagemaker_mp_enabled():
            if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                # If the 'user_content.pt' file exists, load with the new smp api.
                # Checkpoint must have been saved with the new smp api.
                smp.resume_from_checkpoint(
                    path=self.state.best_model_checkpoint,
                    tag=WEIGHTS_NAME,
                    partial=False,
                    load_optimizer=False,
                )
            else:
                # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                # Checkpoint must have been saved with the old smp api.
                state_dict = torch.load(best_model_path, map_location="cpu")
                state_dict["_smp_is_partial"] = False
                model.load_state_dict(state_dict, strict=True)
            return

        # We load the model state dict on the CPU to avoid an OOM error.
        # If the model is on the GPU, it still works!
        state_dict = torch.load(best_model_path, map_location="cpu")
        # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
        # which takes *args instead of **kwargs
        load_result = model.load_state_dict(state_dict, False)
        self._issue_warnings_after_load(load_result)

2187
    def _issue_warnings_after_load(self, load_result):
2188
        if len(load_result.missing_keys) != 0:
2189
2190
2191
            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                self.model._keys_to_ignore_on_save
            ):
2192
2193
                self.model.tie_weights()
            else:
2194
                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
2195
        if len(load_result.unexpected_keys) != 0:
2196
2197
2198
            logger.warning(
                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
            )
2199

2200
    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
Sylvain Gugger's avatar
Sylvain Gugger committed
2201
        if self.control.should_log:
2202
2203
2204
            if is_torch_tpu_available():
                xm.mark_step()

Sylvain Gugger's avatar
Sylvain Gugger committed
2205
            logs: Dict[str, float] = {}
2206
2207
2208
2209

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

2210
2211
2212
            # reset tr_loss to zero
            tr_loss -= tr_loss

2213
            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
2214
            logs["learning_rate"] = self._get_learning_rate()
2215

2216
            self._total_loss_scalar += tr_loss_scalar
2217
            self._globalstep_last_logged = self.state.global_step
Teven's avatar
Teven committed
2218
            self.store_flos()
Sylvain Gugger's avatar
Sylvain Gugger committed
2219
2220
2221
2222
2223

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
2224
2225
2226
2227
2228
2229
2230
2231
2232
            if isinstance(self.eval_dataset, dict):
                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
                    metrics = self.evaluate(
                        eval_dataset=eval_dataset,
                        ignore_keys=ignore_keys_for_eval,
                        metric_key_prefix=f"eval_{eval_dataset_name}",
                    )
            else:
                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
2233
            self._report_to_hp_search(trial, self.state.global_step, metrics)
2234

Sylvain Gugger's avatar
Sylvain Gugger committed
2235
2236
2237
2238
        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

2239
2240
2241
2242
2243
    def _load_rng_state(self, checkpoint):
        """
        Restore the Python, NumPy and torch RNG states stored in `checkpoint`.

        Args:
            checkpoint (`str`, *optional*):
                Path of the checkpoint folder. Nothing is done when it is `None` or when the expected RNG file is
                missing from it (an info message is logged in the latter case).
        """
        if checkpoint is None:
            return

        if self.args.world_size > 1:
            # One RNG file per process when several processes are used.
            process_index = self.args.process_index
            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
            if not os.path.isfile(rng_file):
                logger.info(
                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
                )
                return
        else:
            rng_file = os.path.join(checkpoint, "rng_state.pth")
            if not os.path.isfile(rng_file):
                logger.info(
                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
                    "fashion, reproducibility is not guaranteed."
                )
                return

        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        if torch.cuda.is_available():
            if self.args.local_rank != -1:
                # A single CUDA state was stored for this process.
                torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
            else:
                # The states of all visible devices were stored; restoring them is best effort.
                try:
                    torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
                except Exception as e:
                    logger.info(
                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
        if is_torch_tpu_available():
            xm.set_rng_state(checkpoint_rng_state["xla"])

Sylvain Gugger's avatar
Sylvain Gugger committed
2280
    def _save_checkpoint(self, model, trial, metrics=None):
        """
        Write a full training checkpoint for the current global step.

        Saves the model, the optimizer / scheduler (and grad scaler when used), the trainer state and the RNG
        states into `<run_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>`, updates the best metric / best checkpoint
        bookkeeping, optionally pushes to the Hub and rotates old checkpoints.

        Args:
            model: The model being trained (not used directly here; `self.save_model` is called instead).
            trial: The current hyperparameter-search trial, or `None`.
            metrics (`Dict[str, float]`, *optional*):
                Evaluation metrics used to track the best model when `args.metric_for_best_model` is set.

        Raises:
            `KeyError`: If `args.metric_for_best_model` is not one of the keys of `metrics`.
        """
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir, _internal_call=True)
        if self.deepspeed:
            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
            # config `stage3_gather_16bit_weights_on_model_save` is True
            self.deepspeed.save_checkpoint(output_dir)

        # Save optimizer and scheduler
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            self.optimizer.consolidate_state_dict()

        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            # Called after leaving the recording context, like the other branches do, so that anything it
            # re-emits is not recorded again.
            reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
            if self.args.should_save:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling:
                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
        elif self.args.should_save and not self.deepspeed:
            # deepspeed.save_checkpoint above saves model/optim/sched
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            try:
                metric_value = metrics[metric_to_check]
            except KeyError as exc:
                # Same exception type as before, but say which key was expected and what is available.
                raise KeyError(
                    f"The `metric_for_best_model` training argument resolves to '{metric_to_check}', which was not "
                    f"found in the evaluation metrics. Available metrics: {list(metrics.keys())}."
                ) from exc

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.args.should_save:
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        # Collect the RNG states of this process (saved below for both single- and multi-process runs)
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
            if self.args.local_rank == -1:
                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                rng_states["cuda"] = torch.cuda.random.get_rng_state()

        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()

        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)

        if self.args.world_size <= 1:
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
            # One file per process so that each process can restore its own state.
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

2386
    def _load_optimizer_and_scheduler(self, checkpoint):
        """
        If optimizer and scheduler states exist, load them.

        Args:
            checkpoint (`str`, *optional*):
                Path of the checkpoint folder. Nothing is done when it is `None`, when deepspeed is in use, or when
                the optimizer or scheduler file is missing from the folder.
        """
        if checkpoint is None:
            return

        if self.deepspeed:
            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
            return

        optimizer_path = os.path.join(checkpoint, OPTIMIZER_NAME)
        scheduler_path = os.path.join(checkpoint, SCHEDULER_NAME)
        scaler_path = os.path.join(checkpoint, SCALER_NAME)

        # Under SageMaker MP the optimizer state is looked up as `<OPTIMIZER_NAME>_*` files.
        checkpoint_file_exists = (
            glob.glob(optimizer_path + "_*") if is_sagemaker_mp_enabled() else os.path.isfile(optimizer_path)
        )
        if not (checkpoint_file_exists and os.path.isfile(scheduler_path)):
            return

        # Load in optimizer and scheduler states
        if is_torch_tpu_available():
            # On TPU we have to take some extra precautions to properly load the states on the right device.
            optimizer_state = torch.load(optimizer_path, map_location="cpu")
            with warnings.catch_warnings(record=True) as caught_warnings:
                lr_scheduler_state = torch.load(scheduler_path, map_location="cpu")
            reissue_pt_warnings(caught_warnings)

            # NOTE(review): the return values are discarded here; confirm `send_cpu_data_to_device` updates the
            # passed objects in place.
            xm.send_cpu_data_to_device(optimizer_state, self.args.device)
            xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

            self.optimizer.load_state_dict(optimizer_state)
            self.lr_scheduler.load_state_dict(lr_scheduler_state)
            return

        map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
        if is_sagemaker_mp_enabled():
            if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
                # Optimizer checkpoint was saved with smp >= 1.10
                def opt_load_hook(mod, opt):
                    opt.load_state_dict(smp.load(optimizer_path, partial=True))

            else:
                # Optimizer checkpoint was saved with smp < 1.10
                def opt_load_hook(mod, opt):
                    if IS_SAGEMAKER_MP_POST_1_10:
                        opt.load_state_dict(smp.load(optimizer_path, partial=True, back_compat=True))
                    else:
                        opt.load_state_dict(smp.load(optimizer_path, partial=True))

            # The optimizer state is loaded by the hook, not immediately.
            self.model_wrapped.register_post_step_hook(opt_load_hook)
        else:
            self.optimizer.load_state_dict(torch.load(optimizer_path, map_location=map_location))
        with warnings.catch_warnings(record=True) as caught_warnings:
            self.lr_scheduler.load_state_dict(torch.load(scheduler_path))
        reissue_pt_warnings(caught_warnings)
        if self.do_grad_scaling and os.path.isfile(scaler_path):
            self.scaler.load_state_dict(torch.load(scaler_path))
Sylvain Gugger's avatar
Sylvain Gugger committed
2442

2443
2444
2445
2446
2447
2448
2449
    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
        **kwargs,
    ) -> BestRun:
        """
        Launch a hyperparameter search using `optuna`, `Ray Tune`, `SigOpt` or `wandb`. The optimized quantity is
        determined by `compute_objective`, which defaults to a function returning the evaluation loss when no metric
        is provided, the sum of all metrics otherwise.

        <Tip warning={true}>

        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
        optimizer/scheduler.

        </Tip>

        Args:
            hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
                A function that defines the hyperparameter search space. Will default to
                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
                method. Will default to [`~trainer_utils.default_compute_objective`].
            n_trials (`int`, *optional*, defaults to 20):
                The number of trial runs to test.
            direction (`str`, *optional*, defaults to `"minimize"`):
                Whether to optimize greater or lower objectives. Can be `"minimize"` or `"maximize"`, you should pick
                `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
            backend (`str` or [`~trainer_utils.HPSearchBackend`], *optional*):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
                on which one is installed. If all are installed, will default to optuna.
            hp_name (`Callable[["optuna.Trial"], str]]`, *optional*):
                A function that defines the trial/run name. Will default to None.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more
                information see:

                - the documentation of
                  [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
                - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
                - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)

        Returns:
            [`trainer_utils.BestRun`]: All the information about the best run.

        Raises:
            `RuntimeError`: If no backend is available, if the selected backend is not installed, or if the trainer
            was created without a `model_init` function.
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna, ray or sigopt should be installed. "
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`. "
                    "To install sigopt run `pip install sigopt`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_tune_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
            raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
        if backend == HPSearchBackend.WANDB and not is_wandb_available():
            raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
        # Validate before touching `self.hp_search_backend` so a failed call leaves the trainer untouched.
        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        backend_dict = {
            HPSearchBackend.OPTUNA: run_hp_search_optuna,
            HPSearchBackend.RAY: run_hp_search_ray,
            HPSearchBackend.SIGOPT: run_hp_search_sigopt,
            HPSearchBackend.WANDB: run_hp_search_wandb,
        }

        self.hp_search_backend = backend
        try:
            self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
            self.hp_name = hp_name
            self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

            best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
        finally:
            # Always leave search mode, even when the search raises.
            self.hp_search_backend = None
        return best_run

Sylvain Gugger's avatar
Sylvain Gugger committed
2538
    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log. When the current epoch is known, an `"epoch"` entry (rounded to two decimals) is
                added to this dictionary in place.
        """
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)

        # The history entry additionally records the global step; `logs` itself does not get a "step" key.
        output = {**logs, "step": self.state.global_step}
        self.state.log_history.append(output)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
Julien Chaumond's avatar
Julien Chaumond committed
2554

2555
2556
    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.

        Mappings, lists and tuples (including namedtuples) are rebuilt with the same type after preparing each
        element; tensors are moved to `self.args.device`; anything else is returned unchanged.
        """
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            prepared = (self._prepare_input(v) for v in data)
            if isinstance(data, tuple) and hasattr(data, "_fields"):
                # namedtuple constructors take one positional argument per field, not a single iterable.
                return type(data)(*prepared)
            return type(data)(prepared)
        elif isinstance(data, torch.Tensor):
            kwargs = {"device": self.args.device}
            if self.deepspeed and (torch.is_floating_point(data) or torch.is_complex(data)):
                # NLP models inputs are int/uint and those get adjusted to the right dtype of the
                # embedding. Other models such as wav2vec2's inputs are already float and thus
                # may need special handling to match the dtypes of the model
                kwargs.update({"dtype": self.args.hf_deepspeed_config.dtype()})
            return data.to(**kwargs)
        return data

sgugger's avatar
Fix CI  
sgugger committed
2573
    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.

        Raises:
            `ValueError`: If the prepared batch is empty.
        """
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            # Guard against the column list not having been populated, so the explanatory ValueError below is
            # not replaced by a TypeError raised from `join`.
            expected_columns = self._signature_columns or []
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(expected_columns)}."
            )
        if self.args.past_index >= 0 and self._past is not None:
            # Feed the state saved from the previous step back to the model.
            inputs["mems"] = self._past

        return inputs

2589
2590
2591
2592
    def compute_loss_context_manager(self):
        """
        A helper wrapper to group together context managers.

        Currently returns the context manager built by `autocast_smart_context_manager`.
        """
        return self.autocast_smart_context_manager()
2594

2595
    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
        """
        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
        arguments, depending on the situation.

        Args:
            cache_enabled (`bool`, *optional*, defaults to `True`):
                Forwarded to `autocast` when the installed torch version accepts it.

        Returns:
            An `autocast` context manager when CUDA or CPU AMP is in use, a no-op context manager otherwise.
        """
        if self.use_cuda_amp or self.use_cpu_amp:
            if is_torch_greater_or_equal_than_1_10:
                ctx_manager = (
                    torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
                    if self.use_cpu_amp
                    else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
                )
            else:
                # Older torch: plain CUDA autocast, no `cache_enabled` / `dtype` arguments are passed.
                ctx_manager = torch.cuda.amp.autocast()
        else:
            # No mixed precision: `suppress()` with no arguments stands in for `nullcontext` before Python 3.7.
            ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()

        return ctx_manager

2614
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            # SageMaker MP runs forward and backward together and returns per-microbatch losses.
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
            loss = loss / self.args.gradient_accumulation_steps

        # Backward pass, using whichever scaling mechanism is active.
        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            # loss gets scaled under gradient_accumulation_steps in deepspeed
            loss = self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()
Julien Chaumond's avatar
Julien Chaumond committed
2661

2662
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.

        Args:
            model: The model to call with `**inputs`.
            inputs: The keyword inputs of the model. When a label smoother is configured, the `"labels"` entry is
                removed from this dictionary before the forward pass.
            return_outputs (`bool`, *optional*, defaults to `False`):
                Whether to return the model outputs along with the loss.

        Returns:
            The loss, or a `(loss, outputs)` tuple when `return_outputs=True`.

        Raises:
            `ValueError`: If the model returns a dictionary without a `"loss"` key and no label smoother is used.
        """
        if self.label_smoother is not None and "labels" in inputs:
            # The label smoother computes the loss itself, so the labels are kept out of the forward pass.
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss
Sylvain Gugger's avatar
Sylvain Gugger committed
2693

2694
2695
    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
        machines) main process.

        Returns:
            `bool`: `True` when `self.args.local_process_index` is 0.
        """
        return self.args.local_process_index == 0
Lysandre Debut's avatar
Lysandre Debut committed
2700

2701
2702
    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be `True` for one process).
        """
        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
        # process index.
        if is_sagemaker_mp_enabled():
            return smp.rank() == 0
        return self.args.process_index == 0
Julien Chaumond's avatar
Julien Chaumond committed
2712

Sylvain Gugger's avatar
Sylvain Gugger committed
2713
    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Will save the model, so you can reload it using `from_pretrained()`.

        Will only save from the main process.

        Args:
            output_dir (`str`, *optional*):
                Folder to save to. Defaults to `self.args.output_dir`.
            _internal_call (`bool`, *optional*, defaults to `False`):
                Set by the Trainer's own checkpointing code; when `True` the automatic push to the Hub is skipped.
        """

        if output_dir is None:
            output_dir = self.args.output_dir

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif is_sagemaker_mp_enabled():
            # Calling the state_dict needs to be done on the wrapped model and on all processes.
            os.makedirs(output_dir, exist_ok=True)
            state_dict = self.model_wrapped.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
            if IS_SAGEMAKER_MP_POST_1_10:
                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
                Path(os.path.join(output_dir, "user_content.pt")).touch()
        elif (
            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
            or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
            or self.fsdp is not None
        ):
            # The state dict is computed outside the `should_save` check, i.e. on every process.
            state_dict = self.model.state_dict()

            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
        elif self.deepspeed:
            # this takes care of everything as long as we aren't under zero3
            if self.args.should_save:
                self._save(output_dir)

            if is_deepspeed_zero3_enabled():
                # It's too complicated to try to override different places where the weights dump gets
                # saved, so since under zero3 the file is bogus, simply delete it. The user should
                # either use the deepspeed checkpoint to resume, or recover the full weights with the
                # zero_to_fp32.py script stored in the checkpoint.
                if self.args.should_save:
                    weights_file = os.path.join(output_dir, WEIGHTS_NAME)
                    if os.path.isfile(weights_file):
                        # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
                        os.remove(weights_file)

                # now save the real model if stage3_gather_16bit_weights_on_model_save=True
                # if false it will not be saved.
                # This must be called on all ranks
                if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME):
                    logger.warning(
                        "deepspeed.save_16bit_model didn't save the model, since"
                        " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                        " zero_to_fp32.py to recover weights"
                    )
                    self.deepspeed.save_checkpoint(output_dir)

        elif self.args.should_save:
            self._save(output_dir)

        # Push to the Hub when `save_model` is called by the user.
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")

2777
2778
    def _save_tpu(self, output_dir: Optional[str] = None):
        """
        Save the training arguments, the model and (when set) the tokenizer to `output_dir` on TPU.

        Args:
            output_dir (`str`, *optional*):
                Folder to save to. Defaults to `self.args.output_dir`.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info(f"Saving model checkpoint to {output_dir}")

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if isinstance(self.model, PreTrainedModel):
            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
        elif isinstance(unwrap_model(self.model), PreTrainedModel):
            # Wrapped model: save through the unwrapped model, with the state dict of `self.model`.
            unwrap_model(self.model).save_pretrained(
                output_dir,
                is_main_process=self.args.should_save,
                state_dict=self.model.state_dict(),
                save_function=xm.save,
            )
        else:
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            state_dict = self.model.state_dict()
            xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        if self.tokenizer is not None and self.args.should_save:
            self.tokenizer.save_pretrained(output_dir)
2804

2805
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """
        Write the model, the tokenizer (if one is set) and the training arguments to `output_dir`.

        Args:
            output_dir (`str`, *optional*):
                Destination folder, created if missing. When `None`, `self.args.output_dir` is used.
            state_dict (*optional*):
                State dict to save. When `None` and `self.model` is not itself a `PreTrainedModel`, the state dict
                of `self.model` is used instead.
        """
        # If we are executing this function, we are the process zero, so we don't check for that.
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model = self.model
        if isinstance(model, PreTrainedModel):
            model.save_pretrained(output_dir, state_dict=state_dict)
        else:
            unwrapped = unwrap_model(model)
            unwrapped_is_pretrained = isinstance(unwrapped, PreTrainedModel)
            if not unwrapped_is_pretrained:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            if state_dict is None:
                state_dict = model.state_dict()
            if unwrapped_is_pretrained:
                unwrapped.save_pretrained(output_dir, state_dict=state_dict)
            else:
                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))

        tokenizer = self.tokenizer
        if tokenizer is not None:
            tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
2829

2830
    def store_flos(self):
        """
        Add the pending `self.current_flos` to `self.state.total_flos` and reset the pending counter to 0.

        When `self.args.local_rank` is not -1, the pending value is first passed through
        `distributed_broadcast_scalars` and summed.
        """
        # Storing the number of floating-point operations that went into the model
        pending_flos = self.current_flos
        if self.args.local_rank != -1:
            pending_flos = distributed_broadcast_scalars([pending_flos], device=self.args.device).sum().item()
        self.state.total_flos += pending_flos
        self.current_flos = 0
Julien Chaumond's avatar
Julien Chaumond committed
2840

2841
2842
2843
    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        """
        List the checkpoint folders found in `output_dir`, ordered from first to last.

        Args:
            output_dir (`str`, *optional*):
                Folder that contains the checkpoint sub-folders. When `None`, `self.args.output_dir` is used.
            checkpoint_prefix (`str`, *optional*, defaults to `PREFIX_CHECKPOINT_DIR`):
                Only sub-folders named `{checkpoint_prefix}-*` are considered.
            use_mtime (`bool`, *optional*, defaults to `False`):
                If `True`, order by modification time. Otherwise order by the integer that follows the prefix in
                the folder name; folders without such an integer are skipped.

        Returns:
            `List[str]`: The checkpoint paths. If `self.state.best_model_checkpoint` is one of them, it is moved
            towards the end of the list (at most to the second-to-last position).
        """
        if output_dir is None:
            output_dir = self.args.output_dir

        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]

        if use_mtime:
            ordering_and_checkpoint_path = [(os.path.getmtime(path), path) for path in glob_checkpoints]
        else:
            ordering_and_checkpoint_path = []
            step_pattern = re.compile(f".*{checkpoint_prefix}-([0-9]+)")
            for path in glob_checkpoints:
                regex_match = step_pattern.match(path)
                if regex_match is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.group(1)), path))

        checkpoints_sorted = [path for _, path in sorted(ordering_and_checkpoint_path)]

        # Make sure we don't delete the best model.
        if self.state.best_model_checkpoint is not None:
            # Same `str(Path(...))` form as the globbed entries above.
            best_model_checkpoint = str(Path(self.state.best_model_checkpoint))
            # The best checkpoint may be absent (removed, other folder, name without a step number): in that case
            # there is nothing to reorder, and `list.index` would raise a `ValueError`.
            if best_model_checkpoint in checkpoints_sorted:
                best_model_index = checkpoints_sorted.index(best_model_checkpoint)
                for i in range(best_model_index, len(checkpoints_sorted) - 2):
                    checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
        return checkpoints_sorted

2865
    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
        """
        Delete the oldest checkpoint folders so that at most `self.args.save_total_limit` of them remain.

        Does nothing when `save_total_limit` is `None` or not positive, or when the limit is not exceeded.

        Args:
            use_mtime (`bool`, *optional*, defaults to `False`):
                Forwarded to `_sorted_checkpoints` to choose the ordering of the checkpoints.
            output_dir (`str`, *optional*):
                Forwarded to `_sorted_checkpoints` as the folder to scan.
        """
        save_total_limit = self.args.save_total_limit
        if save_total_limit is None or save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
        if len(checkpoints_sorted) <= save_total_limit:
            return

        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
        # we don't do to allow resuming.
        best_model_checkpoint = self.state.best_model_checkpoint
        if (
            best_model_checkpoint is not None
            and save_total_limit == 1
            # Compare in the `str(Path(...))` form that `_sorted_checkpoints` also uses for the best checkpoint, so
            # that spellings such as "./out/checkpoint-5" and "out/checkpoint-5" are recognised as the same folder.
            and checkpoints_sorted[-1] != str(Path(best_model_checkpoint))
        ):
            save_total_limit = 2

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
        for checkpoint in checkpoints_sorted[:number_of_checkpoints_to_delete]:
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
            shutil.rmtree(checkpoint, ignore_errors=True)
Julien Chaumond's avatar
Julien Chaumond committed
2889

2890
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and return the metrics.

        Metrics are task-dependent, so the calling script has to provide the function computing them (through the
        `compute_metrics` argument of the init).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default)

        Returns:
            `Dict[str, float]`: The metrics dictionary produced by the evaluation loop, updated with the speed metrics
            before being logged and passed to the `on_evaluate` callbacks.
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        start_time = time.time()

        if self.args.use_legacy_prediction_loop:
            run_loop = self.prediction_loop
        else:
            run_loop = self.evaluation_loop

        # No point gathering the predictions if there are no metrics, otherwise we defer to
        # self.args.prediction_loss_only
        loss_only = True if self.compute_metrics is None else None
        output = run_loop(
            eval_dataloader,
            description="Evaluation",
            prediction_loss_only=loss_only,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )
        metrics = output.metrics

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        # Shift the start time by the reported JIT compilation time (presumably so that it is not counted in the
        # speed metrics).
        jit_time_key = f"{metric_key_prefix}_jit_compilation_time"
        if jit_time_key in metrics:
            start_time += metrics[jit_time_key]
        throughput = speed_metrics(
            metric_key_prefix,
            start_time,
            num_samples=output.num_samples,
            num_steps=math.ceil(output.num_samples / total_batch_size),
        )
        metrics.update(throughput)

        self.log(metrics)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)

        self._memory_tracker.stop_and_update_metrics(metrics)

        return metrics

2961
    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Run prediction and return the predictions together with the potential metrics.

        Your test dataset may or may not contain labels, depending on the dataset and your use case. When it does,
        this method also returns metrics, like in `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the
                `model.forward()` method are automatically removed. Has to implement the method `__len__`
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "test_bleu" if the prefix is "test" (default)

        <Tip>

        If your predictions or labels have different sequence length (for instance because you're doing dynamic padding
        in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
        one array. The padding index is -100.

        </Tip>

        Returns: *NamedTuple* A namedtuple with the following keys:

            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        test_dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()

        if self.args.use_legacy_prediction_loop:
            run_loop = self.prediction_loop
        else:
            run_loop = self.evaluation_loop
        output = run_loop(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )
        metrics = output.metrics

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        # Shift the start time by the reported JIT compilation time (presumably so that it is not counted in the
        # speed metrics).
        jit_time_key = f"{metric_key_prefix}_jit_compilation_time"
        if jit_time_key in metrics:
            start_time += metrics[jit_time_key]
        throughput = speed_metrics(
            metric_key_prefix,
            start_time,
            num_samples=output.num_samples,
            num_steps=math.ceil(output.num_samples / total_batch_size),
        )
        metrics.update(throughput)

        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, metrics)
        self._memory_tracker.stop_and_update_metrics(metrics)

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=metrics)
Julien Chaumond's avatar
Julien Chaumond committed
3022

3023
    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.

        Args:
            dataloader (`DataLoader`):
                The dataloader to iterate over.
            description (`str`):
                Name of the run, only used in the log message emitted before the loop.
            prediction_loss_only (`bool`, *optional*):
                Forwarded to `prediction_step`. When `None`, `self.args.prediction_loss_only` is used.
            ignore_keys (`List[str]`, *optional*):
                Forwarded to `prediction_step`.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                Prefix added (followed by `_`) to every metric key that does not already start with it.

        Returns:
            `EvalLoopOutput`: The accumulated predictions and label ids (each may be `None`), the metrics dictionary
            and the number of samples.
        """
        args = self.args

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train init deepspeed here
        if args.deepspeed and not self.deepspeed:
            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(
                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        inputs_host = None

        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        all_inputs = None

        # Will be useful when we have an iterable dataset so don't know its length.
        observed_num_examples = 0
        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            if is_torch_tpu_available():
                xm.mark_step()

            # Update containers on host
            if loss is not None:
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if labels is not None:
                labels = self._pad_across_processes(labels)
                labels = self._nested_gather(labels)
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_decode = self._pad_across_processes(inputs_decode)
                inputs_decode = self._nested_gather(inputs_decode)
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            if logits is not None:
                logits = self._pad_across_processes(logits)
                logits = self._nested_gather(logits)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                if preds_host is not None:
                    logits = nested_numpify(preds_host)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
                if inputs_host is not None:
                    inputs_decode = nested_numpify(inputs_host)
                    all_inputs = (
                        inputs_decode
                        if all_inputs is None
                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                    )
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                    )

                # Set back to None to begin a new accumulation
                losses_host, preds_host, inputs_host, labels_host = None, None, None, None

        # Use the same `>= 0` test as the initialisation before the loop: a plain truthiness check would skip the
        # cleanup when `past_index` is 0.
        if args.past_index >= 0 and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
        if inputs_host is not None:
            inputs_decode = nested_numpify(inputs_host)
            all_inputs = (
                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
            )
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples
        if num_samples == 0 and observed_num_examples > 0:
            num_samples = observed_num_examples

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
        # samplers has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
        if all_inputs is not None:
            all_inputs = nested_truncate(all_inputs, num_samples)

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
        if hasattr(self, "jit_compilation_time"):
            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
3233

3234
    def _nested_gather(self, tensors, name=None):
        """
        Gather `tensors` (a tensor or list/tuple of nested tensors) with the helper matching the current setup: XLA
        mesh reduce on TPU, `smp_gather` under SageMaker model parallelism, `distributed_concat` when
        `self.args.local_rank` is not -1.

        Returns `None` when `tensors` is `None`, and `tensors` unchanged when none of these setups applies. `name` is
        only used on TPU, where it defaults to `"nested_gather"`.
        """
        if tensors is None:
            return None
        if is_torch_tpu_available():
            reduce_name = "nested_gather" if name is None else name
            return nested_xla_mesh_reduce(tensors, reduce_name)
        if is_sagemaker_mp_enabled():
            return smp_gather(tensors)
        if self.args.local_rank != -1:
            return distributed_concat(tensors)
        return tensors
3250

3251
3252
3253
3254
3255
3256
3257
3258
3259
3260
3261
3262
3263
3264
3265
3266
3267
3268
3269
3270
3271
3272
    # Copied from Accelerate.
    def _pad_across_processes(self, tensor, pad_index=-100):
        """
        Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
        they can safely be gathered.
        """
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )

        if len(tensor.shape) < 2:
            return tensor
        # Gather all sizes
        size = torch.tensor(tensor.shape, device=tensor.device)[None]
        sizes = self._nested_gather(size).cpu()

        max_size = max(s[1] for s in sizes)
3273
3274
3275
        # When extracting XLA graphs for compilation, max_size is 0,
        # so use inequality to avoid errors.
        if tensor.shape[1] >= max_size:
3276
3277
3278
3279
3280
3281
3282
3283
3284
            return tensor

        # Then pad to the maximum size
        old_size = tensor.shape
        new_size = list(old_size)
        new_size[1] = max_size
        new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
        new_tensor[:, : old_size[1]] = tensor
        return new_tensor
3285

3286
    def prediction_step(
3287
3288
3289
3290
3291
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
3292
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
3293
        """
Stas Bekman's avatar
Stas Bekman committed
3294
        Perform an evaluation step on `model` using `inputs`.
3295
3296
3297
3298

        Subclass and override to inject custom behavior.

        Args:
3299
            model (`nn.Module`):
3300
                The model to evaluate.
3301
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
3302
3303
3304
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
3305
3306
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
3307
                Whether or not to return the loss only.
3308
            ignore_keys (`Lst[str]`, *optional*):
3309
3310
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
3311
3312

        Return:
3313
3314
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
3315
        """
3316
        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
3317
3318
3319
3320
3321
3322
3323
3324
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False

3325
        inputs = self._prepare_inputs(inputs)
3326
3327
3328
3329
3330
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []
3331

3332
        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
3333
        if has_labels or loss_without_labels:
3334
3335
3336
3337
3338
3339
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

3340
        with torch.no_grad():
Sylvain Gugger's avatar
Sylvain Gugger committed
3341
3342
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
3343
                if has_labels or loss_without_labels:
Sylvain Gugger's avatar
Sylvain Gugger committed
3344
3345
3346
3347
3348
3349
3350
3351
3352
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
3353
                else:
Sylvain Gugger's avatar
Sylvain Gugger committed
3354
3355
3356
3357
3358
3359
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
3360
            else:
3361
                if has_labels or loss_without_labels:
3362
                    with self.compute_loss_context_manager():
3363
                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
Sylvain Gugger's avatar
Sylvain Gugger committed
3364
                    loss = loss.mean().detach()
3365

Sylvain Gugger's avatar
Sylvain Gugger committed
3366
3367
3368
3369
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        logits = outputs[1:]
3370
                else:
Sylvain Gugger's avatar
Sylvain Gugger committed
3371
                    loss = None
3372
                    with self.compute_loss_context_manager():
Sylvain Gugger's avatar
Sylvain Gugger committed
3373
3374
3375
3376
3377
3378
3379
3380
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]
3381
3382
3383
3384

        if prediction_loss_only:
            return (loss, None, None)

3385
        logits = nested_detach(logits)
Sylvain Gugger's avatar
Sylvain Gugger committed
3386
3387
3388
3389
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)
3390
3391
3392

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        Report the number of floating-point operations needed for one backward + forward pass on `inputs`.

        The count is delegated to the model's own `floating_point_ops` method, which models inheriting from
        [`PreTrainedModel`] are expected to provide. For any other model, either implement such a method on the
        model or subclass the trainer and override this method.

        Args:
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            `int`: The number of floating-point operations (`0` when the model has no `floating_point_ops`).
        """
        # Guard clause: a model without the hook contributes nothing to the count.
        if not hasattr(self.model, "floating_point_ops"):
            return 0
        return self.model.floating_point_ops(inputs)
3408

3409
    def init_git_repo(self, at_init: bool = False):
        """
        Initializes a git repo in `self.args.hub_model_id`.

        The repo is created on the Hub if needed, cloned into `self.args.output_dir`, pulled, and given a default
        `.gitignore`. Nothing happens on processes other than the main one.

        Args:
            at_init (`bool`, *optional*, defaults to `False`):
                Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
                `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
                out.
        """
        if not self.is_world_process_zero():
            return

        output_dir = self.args.output_dir
        hub_token = self.args.hub_token

        # Fall back to the output folder name when no explicit repo id was given.
        repo_name = self.args.hub_model_id
        if repo_name is None:
            repo_name = Path(output_dir).absolute().name
        if "/" not in repo_name:
            repo_name = get_full_repo_name(repo_name, token=hub_token)

        # Make sure the repo exists.
        create_repo(repo_name, token=hub_token, private=self.args.hub_private_repo, exist_ok=True)
        try:
            self.repo = Repository(output_dir, clone_from=repo_name, token=hub_token)
        except EnvironmentError:
            if not (self.args.overwrite_output_dir and at_init):
                raise
            # Try again after wiping output_dir
            shutil.rmtree(output_dir)
            self.repo = Repository(output_dir, clone_from=repo_name, token=hub_token)

        self.repo.git_pull()

        # By default, ignore the checkpoint folders
        gitignore_path = os.path.join(output_dir, ".gitignore")
        ignore_checkpoints = self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS
        if ignore_checkpoints and not os.path.exists(gitignore_path):
            with open(gitignore_path, "w", encoding="utf-8") as writer:
                writer.write("checkpoint-*/")

        # Add "*.sagemaker" to .gitignore if using SageMaker
        if os.environ.get("SM_TRAINING_ENV"):
            self._add_sm_patterns_to_gitignore()

        # No asynchronous push has been started yet.
        self.push_in_progress = None

Sylvain Gugger's avatar
Sylvain Gugger committed
3456
3457
3458
3459
    def create_model_card(
        self,
        language: Optional[str] = None,
        license: Optional[str] = None,
        tags: Union[str, List[str], None] = None,
        model_name: Optional[str] = None,
        finetuned_from: Optional[str] = None,
        tasks: Union[str, List[str], None] = None,
        dataset_tags: Union[str, List[str], None] = None,
        dataset: Union[str, List[str], None] = None,
        dataset_args: Union[str, List[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        The card is written to `README.md` inside `self.args.output_dir`. Nothing is written on processes other than
        the main one.

        Args:
            language (`str`, *optional*):
                The language of the model (if applicable)
            license (`str`, *optional*):
                The license of the model. Will default to the license of the pretrained model used, if the original
                model given to the `Trainer` comes from a repo on the Hub.
            tags (`str` or `List[str]`, *optional*):
                Some tags to be included in the metadata of the model card.
            model_name (`str`, *optional*):
                The name of the model.
            finetuned_from (`str`, *optional*):
                The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
                of the original model given to the `Trainer` (if it comes from the Hub).
            tasks (`str` or `List[str]`, *optional*):
                One or several task identifiers, to be included in the metadata of the model card.
            dataset_tags (`str` or `List[str]`, *optional*):
                One or several dataset tags, to be included in the metadata of the model card.
            dataset (`str` or `List[str]`, *optional*):
                One or several dataset identifiers, to be included in the metadata of the model card.
            dataset_args (`str` or `List[str]`, *optional*):
                One or several dataset arguments, to be included in the metadata of the model card.
        """
        if not self.is_world_process_zero():
            return

        training_summary = TrainingSummary.from_trainer(
            self,
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
        )
        model_card = training_summary.to_model_card()
        # Pin the encoding: the card may contain non-ASCII text, and the platform default encoding is not
        # guaranteed to be able to represent it.
        with open(os.path.join(self.args.output_dir, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card)

3512
3513
3514
3515
3516
3517
3518
3519
3520
3521
3522
3523
3524
3525
3526
3527
3528
3529
3530
3531
3532
3533
3534
3535
3536
3537
3538
3539
3540
3541
3542
3543
3544
3545
    def _push_from_checkpoint(self, checkpoint_folder):
        # Only push from one node.
        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
            return
        # If we haven't finished the last push, we don't do this one.
        if self.push_in_progress is not None and not self.push_in_progress.is_done:
            return

        output_dir = self.args.output_dir
        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
        modeling_files = [CONFIG_NAME, WEIGHTS_NAME]
        for modeling_file in modeling_files:
            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Same for the training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        try:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Temporarily move the checkpoint just saved for the push
                tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
                # subfolder.
                if os.path.isdir(tmp_checkpoint):
                    shutil.rmtree(tmp_checkpoint)
                shutil.move(checkpoint_folder, tmp_checkpoint)

            if self.args.save_strategy == IntervalStrategy.STEPS:
                commit_message = f"Training in progress, step {self.state.global_step}"
            else:
                commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
3546
3547
3548
            _, self.push_in_progress = self.repo.push_to_hub(
                commit_message=commit_message, blocking=False, auto_lfs_prune=True
            )
3549
3550
3551
3552
3553
3554
        finally:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Move back the checkpoint to its place
                shutil.move(tmp_checkpoint, checkpoint_folder)

    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.

        Parameters:
            commit_message (`str`, *optional*, defaults to `"End of training"`):
                Message to commit while pushing.
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the `git push` has finished.
            kwargs:
                Additional keyword arguments passed along to [`~Trainer.create_model_card`].

        Returns:
            The url of the commit of your model in the given repository if `blocking=True`, a tuple with the url of
            the commit and an object to track the progress of the commit if `blocking=False`. Processes other than
            the main one return `None`.
        """
        # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
        # it might fail.
        if not hasattr(self, "repo"):
            self.init_git_repo()

        model_name = kwargs.pop("model_name", None)
        if model_name is None and self.args.should_save:
            if self.args.hub_model_id is None:
                model_name = Path(self.args.output_dir).name
            else:
                model_name = self.args.hub_model_id.split("/")[-1]

        # Needs to be executed on all processes for TPU training, but will only save on the process determined by
        # self.args.should_save.
        self.save_model(_internal_call=True)

        # Only push from one node.
        if not self.is_world_process_zero():
            return

        # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
            self.push_in_progress._process.kill()
            self.push_in_progress = None

        git_head_commit_url = self.repo.push_to_hub(
            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
        )
        # push separately the model card to be independent from the rest of the model
        if self.args.should_save:
            self.create_model_card(model_name=model_name, **kwargs)
            try:
                self.repo.push_to_hub(
                    commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
                )
            except EnvironmentError as exc:
                # A failed model card push is logged rather than raised: the model itself was already pushed.
                logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")

        return git_head_commit_url
Sylvain Gugger's avatar
Sylvain Gugger committed
3609

3610
3611
3612
3613
3614
3615
3616
3617
3618
3619
3620
    #
    # Deprecated code
    #

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.

        Args:
            dataloader (`DataLoader`):
                The dataloader to iterate over. It must implement a working `__len__`.
            description (`str`):
                Name of the run, used in the log messages.
            prediction_loss_only (`bool`, *optional*):
                Whether to only gather the loss. Defaults to `self.args.prediction_loss_only` when `None`.
            ignore_keys (`List[str]`, *optional*):
                Forwarded as-is to `self.prediction_step`.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                Prefix prepended (followed by `_`) to every metric key that does not already start with it.

        Returns:
            `EvalLoopOutput`: The gathered predictions, label ids, metrics and number of examples.

        Raises:
            `ValueError`: If `dataloader` does not implement a working `__len__`.
        """
        args = self.args

        if not has_length(dataloader):
            raise ValueError("dataloader must implement a working __len__")

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train init deepspeed here
        if args.deepspeed and not self.deepspeed:
            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
            # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
            # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
            # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
            deepspeed_engine.optimizer.optimizer = None
            deepspeed_engine.lr_scheduler = None

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info(f"***** Running {description} *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Batch size = {batch_size}")
        # Per-process accumulators, flushed into the gatherers every `eval_accumulation_steps` steps.
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
        inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = max(1, args.world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
            # a batch size to the sampler)
            make_multiple_of = None
            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
                make_multiple_of = dataloader.sampler.batch_size
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            if loss is not None:
                # One loss entry per example of the batch.
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                if not prediction_loss_only:
                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
                    inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host, inputs_host = None, None, None, None

        # Same test as the one that reset `_past` before the loop. A bare truthiness check would be true for `-1`
        # and false for `0`, leaving `_past` behind when `past_index == 0`.
        if args.past_index >= 0 and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
            inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
        inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
3759
3760
3761
3762
3763
3764
3765
3766
3767
3768
3769
3770
3771
3772
3773
3774

    def _gather_and_numpify(self, tensors, name):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif self.args.local_rank != -1:
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)
3775
3776
3777
3778
3779
3780
3781
3782
3783
3784
3785
3786
3787
3788
3789
3790
3791
3792
3793
3794
3795
3796
3797
3798
3799
3800
3801
3802
3803
3804
3805
3806
3807
3808
3809
3810
3811
3812
3813

    def _add_sm_patterns_to_gitignore(self) -> None:
        """Add SageMaker Checkpointing patterns to .gitignore file."""
        # Make sure we only do this on the main process
        if not self.is_world_process_zero():
            return

        patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]

        # Get current .gitignore content
        if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
            with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f:
                current_content = f.read()
        else:
            current_content = ""

        # Add the patterns to .gitignore
        content = current_content
        for pattern in patterns:
            if pattern not in content:
                if content.endswith("\n"):
                    content += pattern
                else:
                    content += f"\n{pattern}"

        # Write the .gitignore file if it has changed
        if content != current_content:
            with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
                logger.debug(f"Writing .gitignore file. Content: {content}")
                f.write(content)

        self.repo.git_add(".gitignore")

        # avoid race condition with git status
        time.sleep(0.5)

        if not self.repo.is_repo_clean():
            self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
            self.repo.git_push()