trainer.py 182 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
"""

19
import contextlib
20
import functools
21
import glob
22
import inspect
23
import math
Julien Chaumond's avatar
Julien Chaumond committed
24
import os
25
import random
Julien Chaumond's avatar
Julien Chaumond committed
26
27
import re
import shutil
28
import sys
29
import time
30
import warnings
31
from collections.abc import Mapping
Julien Chaumond's avatar
Julien Chaumond committed
32
from pathlib import Path
33
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
Julien Chaumond's avatar
Julien Chaumond committed
34
35


36
# Integrations must be imported before ML frameworks:
37
38
# isort: off
from .integrations import (
39
    default_hp_search_backend,
40
    get_reporting_integration_callbacks,
41
    hp_params,
42
    is_fairscale_available,
43
    is_optuna_available,
44
    is_ray_tune_available,
45
    is_sigopt_available,
46
    is_wandb_available,
47
48
    run_hp_search_optuna,
    run_hp_search_ray,
49
    run_hp_search_sigopt,
50
    run_hp_search_wandb,
51
)
52

53
54
# isort: on

55
56
import numpy as np
import torch
Lai Wei's avatar
Lai Wei committed
57
import torch.distributed as dist
58
from huggingface_hub import Repository, create_repo
59
60
from packaging import version
from torch import nn
61
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
62

63
64
from . import __version__
from .configuration_utils import PretrainedConfig
65
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
66
from .debug_utils import DebugOption, DebugUnderflowOverflow
67
from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
68
from .dependency_versions_check import dep_version_check
Sylvain Gugger's avatar
Sylvain Gugger committed
69
from .modelcard import TrainingSummary
70
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
71
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
72
from .optimization import Adafactor, get_scheduler
73
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10
74
from .tokenization_utils_base import PreTrainedTokenizerBase
Sylvain Gugger's avatar
Sylvain Gugger committed
75
76
77
78
79
80
81
82
83
84
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
85
    DistributedTensorGatherer,
86
    IterableDatasetShard,
Sylvain Gugger's avatar
Sylvain Gugger committed
87
    LabelSmoother,
88
    LengthGroupedSampler,
Sylvain Gugger's avatar
Sylvain Gugger committed
89
90
91
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
    distributed_concat,
92
    find_batch_size,
93
    get_model_param_count,
94
    get_module_class_from_name,
95
    get_parameter_names,
Sylvain Gugger's avatar
Sylvain Gugger committed
96
97
98
99
100
101
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
102
103
104
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
105
    EvalLoopOutput,
106
    EvalPrediction,
107
    FSDPOption,
108
    HPSearchBackend,
109
110
    HubStrategy,
    IntervalStrategy,
111
    PredictionOutput,
112
    RemoveColumnsCollator,
113
    ShardedDDPOption,
114
    TrainerMemoryTracker,
115
116
117
    TrainOutput,
    default_compute_objective,
    default_hp_space,
118
    denumpify_detensorize,
119
    enable_full_determinism,
120
    find_executable_batch_size,
121
    get_last_checkpoint,
122
    has_length,
123
    number_of_arguments,
124
    seed_worker,
125
    set_seed,
126
    speed_metrics,
127
)
128
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
129
from .utils import (
130
131
    ADAPTER_SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
132
    CONFIG_NAME,
133
134
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
135
    WEIGHTS_INDEX_NAME,
136
    WEIGHTS_NAME,
137
    can_return_loss,
138
    find_labels,
139
    get_full_repo_name,
140
    is_accelerate_available,
141
142
143
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
144
    is_ipex_available,
145
    is_peft_available,
146
    is_safetensors_available,
147
148
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
149
    is_torch_compile_available,
150
    is_torch_neuroncore_available,
151
152
    is_torch_tpu_available,
    logging,
153
    strtobool,
154
)
155
from .utils.generic import ContextManagers
Julien Chaumond's avatar
Julien Chaumond committed
156
157


158
# Native CPU AMP availability is tied directly to the `is_torch_greater_or_equal_than_1_10` flag imported
# from `.pytorch_utils` (presumably "torch >= 1.10", as the flag name says). `Trainer.__init__` checks this
# alias before selecting the "cpu_amp" half-precision backend.
_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10
159

Sylvain Gugger's avatar
Sylvain Gugger committed
160
# Callback classes every `Trainer` starts with. `Trainer.__init__` concatenates this list with the
# reporting-integration callbacks for `args.report_to` and then with any user-supplied `callbacks`.
DEFAULT_CALLBACKS = [DefaultFlowCallback]
161
# Progress callback that `Trainer.__init__` adds when `args.disable_tqdm` is falsy
# (otherwise `PrinterCallback` is added instead).
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
Sylvain Gugger's avatar
Sylvain Gugger committed
162

163
164
165
166
# When `is_in_notebook()` reports a notebook environment, swap the default progress callback for the
# notebook-specific one. The import lives inside the guard, so `.utils.notebook` is only imported
# in that case.
if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
167

168
169
if is_apex_available():
    from apex import amp
170

171
172
if is_datasets_available():
    import datasets
Julien Chaumond's avatar
Julien Chaumond committed
173

174
if is_torch_tpu_available(check_device=False):
Lysandre Debut's avatar
Lysandre Debut committed
175
176
177
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

178
if is_fairscale_available():
179
    dep_version_check("fairscale")
180
    import fairscale
181
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
182
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
183
    from fairscale.nn.wrap import auto_wrap
184
185
186
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

187

Sylvain Gugger's avatar
Sylvain Gugger committed
188
189
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
190
191
192
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
Sylvain Gugger's avatar
Sylvain Gugger committed
193
194

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
195
196
else:
    IS_SAGEMAKER_MP_POST_1_10 = False
Sylvain Gugger's avatar
Sylvain Gugger committed
197

198

199
200
201
202
if is_safetensors_available():
    import safetensors.torch


203
204
205
206
if is_peft_available():
    from peft import PeftModel


207
if is_accelerate_available():
208
    from accelerate import Accelerator, skip_first_batches
209
    from accelerate import __version__ as accelerate_version
210
    from accelerate.utils import DistributedDataParallelKwargs
211

212
213
214
215
216
217
218
219
    if version.parse(accelerate_version) > version.parse("0.20.3"):
        from accelerate.utils import (
            load_fsdp_model,
            load_fsdp_optimizer,
            save_fsdp_model,
            save_fsdp_optimizer,
        )

220

221
222
223
if TYPE_CHECKING:
    import optuna

Lysandre Debut's avatar
Lysandre Debut committed
224
# Module-level logger. Note that `logging` here is the helper imported from `.utils` above,
# not the standard-library `logging` module.
logger = logging.get_logger(__name__)
Julien Chaumond's avatar
Julien Chaumond committed
225
226


227
228
229
230
231
232
233
234
# Names of the files used for checkpointing.
# NOTE(review): the per-file descriptions below are inferred from the constant names; the code that
# reads/writes these files is not visible in this part of the module — confirm before relying on them.
TRAINING_ARGS_NAME = "training_args.bin"  # presumably the serialized training arguments
TRAINER_STATE_NAME = "trainer_state.json"  # presumably the trainer state, as JSON
OPTIMIZER_NAME = "optimizer.pt"  # presumably the optimizer state
SCHEDULER_NAME = "scheduler.pt"  # presumably the LR scheduler state
SCALER_NAME = "scaler.pt"  # presumably the gradient scaler state


Julien Chaumond's avatar
Julien Chaumond committed
235
236
class Trainer:
    """
Sylvain Gugger's avatar
Sylvain Gugger committed
237
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
238
239

    Args:
240
241
        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
Sylvain Gugger's avatar
Sylvain Gugger committed
242

243
            <Tip>
Sylvain Gugger's avatar
Sylvain Gugger committed
244

Sylvain Gugger's avatar
Sylvain Gugger committed
245
246
247
            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
            models.
248
249
250
251

            </Tip>

        args ([`TrainingArguments`], *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
252
253
            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
254
        data_collator (`DataCollator`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
255
256
            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
257
258
            [`DataCollatorWithPadding`] otherwise.
        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
259
            The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
260
261
            `model.forward()` method are automatically removed.

Sylvain Gugger's avatar
Sylvain Gugger committed
262
263
264
265
266
            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
            distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
            sets the seed of the RNGs used.
267
        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]], *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
268
             The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
269
270
             `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
             dataset prepending the dictionary key to the metric name.
271
        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
272
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
273
274
            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
            interrupted training or reuse the fine-tuned model.
275
        model_init (`Callable[[], PreTrainedModel]`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
276
277
            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
            from a new instance of the model as given by this function.
278

279
280
281
            The function may have zero arguments, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
            be able to choose different architectures according to hyper parameters (such as layer count, sizes of
            inner layers, dropout probabilities etc).
282
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
283
284
            The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`] and return
            a dictionary mapping metric names (strings) to metric values.
285
        callbacks (List of [`TrainerCallback`], *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
286
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
287
            detailed [here](callback).
Sylvain Gugger's avatar
Sylvain Gugger committed
288

289
290
            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
Sylvain Gugger's avatar
Sylvain Gugger committed
291
292
            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
293
294
295
296
297
298
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.
299

300
301
    Important attributes:

Sylvain Gugger's avatar
Sylvain Gugger committed
302
303
        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
          subclass.
304
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
305
          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
Sylvain Gugger's avatar
Sylvain Gugger committed
306
307
          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
308
309
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
310
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
311
312
          to `False` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False` .
Sylvain Gugger's avatar
Sylvain Gugger committed
313
314
        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
          in `train`)
315

Julien Chaumond's avatar
Julien Chaumond committed
316
317
    """

318
    # Those are used as methods of the Trainer in examples.
319
    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
320

Julien Chaumond's avatar
Julien Chaumond committed
321
322
    def __init__(
        self,
323
        model: Union[PreTrainedModel, nn.Module] = None,
324
        args: TrainingArguments = None,
Julien Chaumond's avatar
Julien Chaumond committed
325
326
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
327
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
328
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
329
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
Julien Chaumond's avatar
Julien Chaumond committed
330
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
Sylvain Gugger's avatar
Sylvain Gugger committed
331
        callbacks: Optional[List[TrainerCallback]] = None,
332
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
333
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
Julien Chaumond's avatar
Julien Chaumond committed
334
    ):
Sylvain Gugger's avatar
Sylvain Gugger committed
335
        if args is None:
336
337
338
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
Sylvain Gugger's avatar
Sylvain Gugger committed
339
340
        self.args = args
        # Seed must be set before instantiating the model when using `model_init`.
341
        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
342
        self.hp_name = None
343
        self.deepspeed = None
344
        self.is_in_train = False
345

346
        self.create_accelerator_and_postprocess()
347

348
349
350
351
        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

352
        # set the correct log level depending on the node
353
        log_level = args.get_process_log_level()
354
355
        logging.set_verbosity(log_level)

356
357
358
        # force device and distributed setup init explicitly
        args._setup_devices

359
360
361
362
363
364
365
366
367
        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
Sylvain Gugger's avatar
Sylvain Gugger committed
368
369
370
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
                    " overwrite your model when calling the `train` method. This will become a fatal error in the next"
                    " release.",
371
372
373
                    FutureWarning,
                )
            self.model_init = model_init
374

375
376
377
378
379
380
381
382
        if model.__class__.__name__ in MODEL_MAPPING_NAMES:
            raise ValueError(
                f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
                "computes hidden states and does not accept any labels. You should choose a model with a head "
                "suitable for your task like any of the `AutoModelForXxx` listed at "
                "https://huggingface.co/docs/transformers/model_doc/auto."
            )

383
384
385
386
387
        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = False

388
389
390
391
392
393
        if getattr(model, "hf_device_map", None) is not None:
            devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
            if len(devices) > 1:
                self.is_model_parallel = True
            else:
                self.is_model_parallel = self.args.device != torch.device(devices[0])
394
395
396
397
398
399
400

            # warn users
            logger.info(
                "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
                " to `True` to avoid any unexpected behavior such as device placement mismatching."
            )

401
        # At this stage the model is already loaded
402
403
        if getattr(model, "is_quantized", False):
            if getattr(model, "_is_quantized_training_enabled", False):
404
405
406
407
408
409
410
411
412
413
414
                logger.info(
                    "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
                    " inside the model such as adapters using `peft` library and freeze the model weights. Please"
                    " check "
                    " the examples in https://github.com/huggingface/peft for more details."
                )
            else:
                raise ValueError(
                    "The model you want to train is loaded in 8-bit precision.  if you want to fine-tune an 8-bit"
                    " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
                )
415

416
417
418
        # Setup Sharded DDP training
        self.sharded_ddp = None
        if len(args.sharded_ddp) > 0:
419
            if self.is_deepspeed_enabled:
420
421
422
                raise ValueError(
                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
423
424
425
426
            if len(args.fsdp) > 0:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
                )
427
            if args.parallel_mode != ParallelMode.DISTRIBUTED:
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
                raise ImportError(
                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
                )
            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.SIMPLE
            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

443
444
        self.fsdp = None
        if len(args.fsdp) > 0:
445
            if self.is_deepspeed_enabled:
446
447
448
                raise ValueError(
                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
449
            if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
450
451
                raise ValueError("Using fsdp only works in distributed training.")

452
453
454
            # dep_version_check("torch>=1.12.0")
            # Would have to update setup.py with torch>=1.12.0
            # which isn't ideal given that it will force people not using FSDP to also use torch>=1.12.0
455
            # below is the current alternative.
456
            if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
457
                raise ValueError("FSDP requires PyTorch >= 1.12.0")
458

459
            from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy
460
461
462
463
464

            if FSDPOption.FULL_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.FULL_SHARD
            elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
                self.fsdp = ShardingStrategy.SHARD_GRAD_OP
465
466
            elif FSDPOption.NO_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.NO_SHARD
467

468
            self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
469
            if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get(
470
471
                "backward_prefetch", []
            ):
472
473
                self.backward_prefetch = BackwardPrefetch.BACKWARD_POST

Seung-Moo Yang's avatar
Seung-Moo Yang committed
474
475
476
            self.forward_prefetch = False
            if self.args.fsdp_config.get("forward_prefect", False):
                self.forward_prefetch = True
477

478
479
480
481
            self.limit_all_gathers = False
            if self.args.fsdp_config.get("limit_all_gathers", False):
                self.limit_all_gathers = True

482
        # one place to sort out whether to place the model on device or not
483
484
485
486
        # postpone switching model to cuda when:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        #    and we only use deepspeed for training at the moment
487
        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
488
        # 4. Sharded DDP - same as MP
489
        # 5. FSDP - same as MP
490
        self.place_model_on_device = args.place_model_on_device
491
492
        if (
            self.is_model_parallel
493
            or self.is_deepspeed_enabled
494
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
495
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
496
            or (self.fsdp is not None)
497
            or self.is_fsdp_enabled
498
        ):
499
500
            self.place_model_on_device = False

501
502
        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
Julien Chaumond's avatar
Julien Chaumond committed
503
504
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
505
        self.tokenizer = tokenizer
506

507
        if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False):
Sylvain Gugger's avatar
Sylvain Gugger committed
508
            self._move_model_to_device(model, args.device)
Stas Bekman's avatar
Stas Bekman committed
509
510
511

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
        if self.is_model_parallel:
512
            self.args._n_gpu = 1
513
514
515
516
517

        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
        self.model_wrapped = model
        self.model = model

Julien Chaumond's avatar
Julien Chaumond committed
518
        self.compute_metrics = compute_metrics
519
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
520
        self.optimizer, self.lr_scheduler = optimizers
Sylvain Gugger's avatar
Sylvain Gugger committed
521
522
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
523
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
Sylvain Gugger's avatar
Sylvain Gugger committed
524
525
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
        if is_torch_tpu_available() and self.optimizer is not None:
            for param in self.model.parameters():
                model_device = param.device
                break
            for param_group in self.optimizer.param_groups:
                if len(param_group["params"]) > 0:
                    optimizer_device = param_group["params"][0].device
                    break
            if model_device != optimizer_device:
                raise ValueError(
                    "The model and the optimizer parameters are not on the same device, which probably means you"
                    " created an optimizer around your model **before** putting on the device and passing it to the"
                    " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                    " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
                )
541
        if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
542
543
544
            self.optimizer is not None or self.lr_scheduler is not None
        ):
            raise RuntimeError(
545
                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
546
547
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
548
549
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
550
551
552
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
553
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
Sylvain Gugger's avatar
Sylvain Gugger committed
554

555
556
557
        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

558
559
        # Create clone of distant repo and output directory if needed
        if self.args.push_to_hub:
560
            self.init_git_repo(at_init=True)
561
562
563
            # In case of pull, we need to make sure every process has the latest.
            if is_torch_tpu_available():
                xm.rendezvous("init git repo")
564
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
565
566
                dist.barrier()

567
        if self.args.should_save:
Julien Chaumond's avatar
Julien Chaumond committed
568
            os.makedirs(self.args.output_dir, exist_ok=True)
569

570
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
Sylvain Gugger's avatar
Sylvain Gugger committed
571
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
572

573
574
575
        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

576
        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
577
578
579
580
            raise ValueError(
                "The train_dataset does not implement __len__, max_steps has to be specified. "
                "The number of steps needs to be known in advance for the learning rate scheduler."
            )
581

582
583
584
585
586
587
588
        if (
            train_dataset is not None
            and isinstance(train_dataset, torch.utils.data.IterableDataset)
            and args.group_by_length
        ):
            raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")

589
        self._signature_columns = None
590

591
592
        # Mixed precision setup
        self.use_apex = False
593
594
        self.use_cuda_amp = False
        self.use_cpu_amp = False
595

596
597
598
599
600
        # Mixed precision setup for SageMaker Model Parallel
        if is_sagemaker_mp_enabled():
            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
            if args.bf16:
                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617

            if IS_SAGEMAKER_MP_POST_1_10:
                # When there's mismatch between SMP config and trainer argument, use SMP config as truth
                if args.fp16 != smp.state.cfg.fp16:
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
                        f"but FP16 provided in trainer argument is {args.fp16},"
                        f"setting to {smp.state.cfg.fp16}"
                    )
                    args.fp16 = smp.state.cfg.fp16
            else:
                # smp < 1.10 does not support fp16 in trainer.
                if hasattr(smp.state.cfg, "fp16"):
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
                    )
618

619
        if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
620
            if args.half_precision_backend == "auto":
621
                if args.device == torch.device("cpu"):
622
623
624
625
626
627
                    if args.fp16:
                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                    elif _is_native_cpu_amp_available:
                        args.half_precision_backend = "cpu_amp"
                    else:
                        raise ValueError("Tried to use cpu amp but native cpu amp is not available")
628
                else:
629
                    args.half_precision_backend = "cuda_amp"
630

631
            logger.info(f"Using {args.half_precision_backend} half precision backend")
632

633
        self.do_grad_scaling = False
634
        if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
635
            # deepspeed and SageMaker Model Parallel manage their own half precision
636
637
638
639
640
641
642
643
644
645
646
647
648
            if self.sharded_ddp is not None:
                if args.half_precision_backend == "cuda_amp":
                    self.use_cuda_amp = True
                    self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
                    #  bf16 does not need grad scaling
                    self.do_grad_scaling = self.amp_dtype == torch.float16
                    if self.do_grad_scaling:
                        if self.sharded_ddp is not None:
                            self.scaler = ShardedGradScaler()
                        elif self.fsdp is not None:
                            from torch.distributed.fsdp.sharded_grad_scaler import (
                                ShardedGradScaler as FSDPShardedGradScaler,
                            )
649

650
651
652
                            self.scaler = FSDPShardedGradScaler()
                        elif is_torch_tpu_available():
                            from torch_xla.amp import GradScaler
653

654
655
656
657
658
659
660
                            self.scaler = GradScaler()
                        else:
                            self.scaler = torch.cuda.amp.GradScaler()
                elif args.half_precision_backend == "cpu_amp":
                    self.use_cpu_amp = True
                    self.amp_dtype = torch.bfloat16
            elif args.half_precision_backend == "apex":
661
662
                if not is_apex_available():
                    raise ImportError(
Sylvain Gugger's avatar
Sylvain Gugger committed
663
664
                        "Using FP16 with APEX but APEX is not installed, please refer to"
                        " https://www.github.com/nvidia/apex."
665
666
667
                    )
                self.use_apex = True

668
669
670
671
672
673
674
675
676
677
678
679
        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
        if (
            is_sagemaker_mp_enabled()
            and self.use_cuda_amp
            and args.max_grad_norm is not None
            and args.max_grad_norm > 0
        ):
            raise ValueError(
                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
                "along 'max_grad_norm': 0 in your hyperparameters."
            )

Sylvain Gugger's avatar
Sylvain Gugger committed
680
681
682
683
684
685
        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

686
687
688
689
690
        self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
        )

Sylvain Gugger's avatar
Sylvain Gugger committed
691
        self.control = TrainerControl()
692
693
694
        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
695
        self.hp_search_backend = None
696
        self.use_tune_checkpoints = False
697
        default_label_names = find_labels(self.model.__class__)
698
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
699
        self.can_return_loss = can_return_loss(self.model.__class__)
Sylvain Gugger's avatar
Sylvain Gugger committed
700
701
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

702
703
704
        # Internal variables to keep track of the original batch size
        self._train_batch_size = args.train_batch_size

705
706
707
        # very last
        self._memory_tracker.stop_and_update_metrics()

708
709
        # torch.compile
        if args.torch_compile and not is_torch_compile_available():
710
            raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
711

Sylvain Gugger's avatar
Sylvain Gugger committed
712
713
    def add_callback(self, callback):
        """
        Register an extra callback on this trainer by forwarding it to the callback handler.

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In
               the first case, will instantiate a member of that class.
        """
        handler = self.callback_handler
        handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Detach a callback from the current list of [`~transformers.TrainerCallback`] and hand it back.

        The lookup is delegated to the callback handler. If the callback is not found, returns `None` (and no error
        is raised).

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In
               the first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            [`~transformers.TrainerCallback`]: The callback removed, if found.
        """
        removed = self.callback_handler.pop_callback(callback)
        return removed

    def remove_callback(self, callback):
        """
        Drop a callback from the current list of [`~transformers.TrainerCallback`] without returning it.

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In
               the first case, will remove the first member of that class found in the list of callbacks.
        """
        handler = self.callback_handler
        handler.remove_callback(callback)
Julien Chaumond's avatar
Julien Chaumond committed
749

Sylvain Gugger's avatar
Sylvain Gugger committed
750
751
752
753
754
755
    def _move_model_to_device(self, model, device):
        """Move `model` to `device`, re-tying its weights afterwards when the parallel mode is TPU."""
        model = model.to(device)
        running_on_tpu = self.args.parallel_mode == ParallelMode.TPU
        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
        if running_on_tpu and hasattr(model, "tie_weights"):
            model.tie_weights()

756
    def _set_signature_columns_if_needed(self):
757
758
759
760
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
761
762
            # Labels may be named label or label_ids, the default data collator handles that.
            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
763

764
765
766
767
    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return dataset
        self._set_signature_columns_if_needed()
768
        signature_columns = self._signature_columns
769
770

        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
771
        if len(ignored_columns) > 0:
772
            dset_description = "" if description is None else f"in the {description} set"
773
774
775
            logger.info(
                f"The following columns {dset_description} don't have a corresponding argument in "
                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
776
                f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
777
                " you can safely ignore this message."
778
            )
779

780
        columns = [k for k in signature_columns if k in dataset.column_names]
781

782
783
784
785
786
787
788
        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            dataset.set_format(
                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
            )
            return dataset
        else:
            return dataset.remove_columns(ignored_columns)
789

790
791
792
793
794
795
796
    def _get_collator_with_removed_columns(
        self, data_collator: Callable, description: Optional[str] = None
    ) -> Callable:
        """Wrap the data collator in a callable removing unused columns."""
        if not self.args.remove_unused_columns:
            return data_collator
        self._set_signature_columns_if_needed()
797
        signature_columns = self._signature_columns
798
799
800
801
802
803
804
805
806
807

        remove_columns_collator = RemoveColumnsCollator(
            data_collator=data_collator,
            signature_columns=signature_columns,
            logger=logger,
            description=description,
            model_name=self.model.__class__.__name__,
        )
        return remove_columns_collator

808
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
809
        if self.train_dataset is None or not has_length(self.train_dataset):
810
            return None
811
812
813

        # Build the sampler.
        if self.args.group_by_length:
814
815
816
817
818
819
820
821
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
822
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
823
824
825
826
827
828
            return LengthGroupedSampler(
                self.args.train_batch_size * self.args.gradient_accumulation_steps,
                dataset=self.train_dataset,
                lengths=lengths,
                model_input_name=model_input_name,
            )
829
830

        else:
831
            return RandomSampler(self.train_dataset)
832
833
834

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        For datasets that are not `IterableDataset`s, the sampler comes from `self._get_train_sampler()`; iterable
        datasets get no sampler. The resulting dataloader is passed through `self.accelerator.prepare`.

        Subclass and override this method if you want to inject some custom behavior.

        Raises:
            ValueError: If the trainer has no `train_dataset`.
        """
        train_dataset = self.train_dataset
        if train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        data_collator = self.data_collator
        # Unused columns are dropped from the dataset itself when it is a `datasets.Dataset`,
        # and inside the collator otherwise.
        is_hf_dataset = is_datasets_available() and isinstance(train_dataset, datasets.Dataset)
        if is_hf_dataset:
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        dataloader_params = dict(
            batch_size=self._train_batch_size,
            collate_fn=data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params.update(
                {
                    "sampler": self._get_train_sampler(),
                    "drop_last": self.args.dataloader_drop_last,
                    "worker_init_fn": seed_worker,
                }
            )

        train_dataloader = DataLoader(train_dataset, **dataloader_params)
        return self.accelerator.prepare(train_dataloader)
Julien Chaumond's avatar
Julien Chaumond committed
865

866
    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
        # Deprecated code
        if self.args.use_legacy_prediction_loop:
            if is_torch_tpu_available():
                return SequentialDistributedSampler(
                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
                )
            elif is_sagemaker_mp_enabled():
                return SequentialDistributedSampler(
                    eval_dataset,
                    num_replicas=smp.dp_size(),
                    rank=smp.dp_rank(),
                    batch_size=self.args.per_device_eval_batch_size,
                )
            else:
                return SequentialSampler(eval_dataset)

        if self.args.world_size <= 1:
            return SequentialSampler(eval_dataset)
        else:
886
            return None
Lysandre Debut's avatar
Lysandre Debut committed
887

Julien Chaumond's avatar
Julien Chaumond committed
888
    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.

        Raises:
            ValueError: If neither `eval_dataset` nor `self.eval_dataset` is set.
        """
        if eval_dataset is None:
            if self.eval_dataset is None:
                raise ValueError("Trainer: evaluation requires an eval_dataset.")
            eval_dataset = self.eval_dataset

        data_collator = self.data_collator
        # Unused columns are dropped from the dataset itself when it is a `datasets.Dataset`,
        # and inside the collator otherwise.
        is_hf_dataset = is_datasets_available() and isinstance(eval_dataset, datasets.Dataset)
        if is_hf_dataset:
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        dataloader_params = dict(
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params.update(
                {
                    "sampler": self._get_eval_sampler(eval_dataset),
                    "drop_last": self.args.dataloader_drop_last,
                }
            )

        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
        return self.accelerator.prepare(eval_dataloader)
Julien Chaumond's avatar
Julien Chaumond committed
921
922

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`):
                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        data_collator = self.data_collator

        # Unused columns are dropped from the dataset itself when it is a `datasets.Dataset`,
        # and inside the collator otherwise.
        is_hf_dataset = is_datasets_available() and isinstance(test_dataset, datasets.Dataset)
        if is_hf_dataset:
            test_dataset = self._remove_unused_columns(test_dataset, description="test")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")

        # We use the same batch_size as for eval.
        dataloader_params = dict(
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

        if not isinstance(test_dataset, torch.utils.data.IterableDataset):
            dataloader_params.update(
                {
                    "sampler": self._get_eval_sampler(test_dataset),
                    "drop_last": self.args.dataloader_drop_last,
                }
            )

        test_dataloader = DataLoader(test_dataset, **dataloader_params)
        return self.accelerator.prepare(test_dataloader)
Lysandre Debut's avatar
Lysandre Debut committed
953

954
    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
        `create_scheduler`) in a subclass.

        Args:
            num_training_steps (int): The number of training steps, forwarded to `create_scheduler`.
        """
        self.create_optimizer()

        # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
        needs_unwrap = IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16
        optimizer = self.optimizer.optimizer if needs_unwrap else self.optimizer

        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
969
970
971
972
973

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.

        When `self.optimizer` is already set, it is left as is (apart from the SageMaker wrapping below). Otherwise
        the trainable parameters are split into a weight-decay group and a no-decay group, and the optimizer class
        returned by `Trainer.get_optimizer_cls_and_kwargs` is instantiated on them.

        Returns:
            The optimizer stored in `self.optimizer`.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            # Use a set so the per-parameter membership test below is O(1) rather than a list scan.
            decay_parameters = {
                name for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) if "bias" not in name
            }

            # Single pass over the trainable parameters; the order within each group follows `named_parameters()`.
            decay_params, no_decay_params = [], []
            for name, param in opt_model.named_parameters():
                if not param.requires_grad:
                    continue
                if name in decay_parameters:
                    decay_params.append(param)
                else:
                    no_decay_params.append(param)

            optimizer_grouped_parameters = [
                {
                    "params": decay_params,
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": no_decay_params,
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                # NOTE(review): `get_optimizer_cls_and_kwargs` appears to return bitsandbytes `AdamW`/`Lion` for the
                # bnb options, so this name check may only match optimizer classes supplied by a subclass — confirm.
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    # Embedding weights are registered for 32-bit optimization; `skipped` is a running total
                    # of their parameter counts (deduplicated per module by data pointer).
                    skipped = 0
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                            logger.info(f"skipped {module}: {skipped/2**20}M params")
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                    logger.info(f"skipped: {skipped/2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer

1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
    @staticmethod
    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
        """
        Returns the optimizer class and optimizer parameters based on the training arguments.

        Args:
            args (`transformers.training_args.TrainingArguments`):
                The training arguments for the training session.

        Returns:
            `Tuple[Any, Any]`: The optimizer class selected by `args.optim` and the dict of keyword arguments to
            instantiate it with (always containing `lr`).

        Raises:
            ValueError: If `args.optim` is not a supported optimizer, or if the package backing the selected
                optimizer cannot be imported.
        """

        # parse args.optim_args: a comma-separated list of `key=value` pairs, values kept as strings
        optim_args = {}
        if args.optim_args:
            for mapping in args.optim_args.replace(" ", "").split(","):
                key, value = mapping.split("=")
                optim_args[key] = value

        optimizer_kwargs = {"lr": args.learning_rate}

        adam_kwargs = {
            "betas": (args.adam_beta1, args.adam_beta2),
            "eps": args.adam_epsilon,
        }
        if args.optim == OptimizerNames.ADAFACTOR:
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif args.optim == OptimizerNames.ADAMW_HF:
            from .optimization import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
            from torch.optim import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
            if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
                optimizer_kwargs.update({"fused": True})
        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
            try:
                from torch_xla.amp.syncfree import AdamW

                optimizer_cls = AdamW
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
            try:
                from apex.optimizers import FusedAdam

                optimizer_cls = FusedAdam
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
        elif args.optim in [
            OptimizerNames.ADAMW_BNB,
            OptimizerNames.ADAMW_8BIT,
            OptimizerNames.PAGED_ADAMW,
            OptimizerNames.PAGED_ADAMW_8BIT,
            OptimizerNames.LION,
            OptimizerNames.LION_8BIT,
            OptimizerNames.PAGED_LION,
            OptimizerNames.PAGED_LION_8BIT,
        ]:
            # All bitsandbytes optimizers, including `ADAMW_BNB`, are configured here; the variant (paged,
            # 8-bit, Adam vs. Lion) is derived from substrings of the optimizer name.
            try:
                from bitsandbytes.optim import AdamW, Lion

                is_paged = False
                optim_bits = 32
                optimizer_cls = None
                additional_optim_kwargs = adam_kwargs
                if "paged" in args.optim:
                    is_paged = True
                if "8bit" in args.optim:
                    optim_bits = 8
                if "adam" in args.optim:
                    optimizer_cls = AdamW
                elif "lion" in args.optim:
                    optimizer_cls = Lion
                    # Lion only receives the betas, not `eps`.
                    additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}

                bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
                optimizer_kwargs.update(additional_optim_kwargs)
                optimizer_kwargs.update(bnb_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!")
        elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
            try:
                from torchdistx.optimizers import AnyPrecisionAdamW

                optimizer_cls = AnyPrecisionAdamW
                optimizer_kwargs.update(adam_kwargs)

                # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
                optimizer_kwargs.update(
                    {
                        "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
                        "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
                        "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
                        "compensation_buffer_dtype": getattr(
                            torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
                        ),
                    }
                )
            except ImportError:
                raise ValueError("Please install https://github.com/pytorch/torchdistx")
        elif args.optim == OptimizerNames.SGD:
            optimizer_cls = torch.optim.SGD
        elif args.optim == OptimizerNames.ADAGRAD:
            optimizer_cls = torch.optim.Adagrad
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs

1149
    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        An already existing `self.lr_scheduler` is returned as is.

        Args:
            num_training_steps (int): The number of training steps to do.
            optimizer (`torch.optim.Optimizer`, *optional*):
                The optimizer to schedule; defaults to `self.optimizer` when not given.

        Returns:
            The scheduler stored in `self.lr_scheduler`.
        """
        if self.lr_scheduler is not None:
            return self.lr_scheduler

        scheduled_optimizer = optimizer if optimizer is not None else self.optimizer
        warmup_steps = self.args.get_warmup_steps(num_training_steps)
        self.lr_scheduler = get_scheduler(
            self.args.lr_scheduler_type,
            optimizer=scheduled_optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
        )
        return self.lr_scheduler
Julien Chaumond's avatar
Julien Chaumond committed
1165

1166
    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
        dataloader.dataset does not exist or has no length, estimates as best it can: the number of batches times
        `args.per_device_train_batch_size`.
        """
        try:
            dataset = dataloader.dataset
            # Special case for IterableDatasetShard, we need to dig deeper
            sized = dataset.dataset if isinstance(dataset, IterableDatasetShard) else dataset
            return len(sized)
        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
            batch_count = len(dataloader)
            return batch_count * self.args.per_device_train_batch_size
1179

1180
    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
Patrick von Platen's avatar
Patrick von Platen committed
1181
        """HP search setup code"""
1182
1183
        self._trial = trial

1184
1185
        if self.hp_search_backend is None or trial is None:
            return
1186
1187
1188
1189
1190
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            params = self.hp_space(trial)
        elif self.hp_search_backend == HPSearchBackend.RAY:
            params = trial
            params.pop("wandb", None)
1191
1192
        elif self.hp_search_backend == HPSearchBackend.SIGOPT:
            params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
1193
1194
        elif self.hp_search_backend == HPSearchBackend.WANDB:
            params = trial
1195

1196
1197
        for key, value in params.items():
            if not hasattr(self.args, key):
1198
                logger.warning(
Sylvain Gugger's avatar
Sylvain Gugger committed
1199
1200
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
                    " `TrainingArguments`."
1201
                )
1202
                continue
1203
1204
1205
1206
1207
1208
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
1209
            logger.info(f"Trial: {trial.params}")
1210
1211
        if self.hp_search_backend == HPSearchBackend.SIGOPT:
            logger.info(f"SigOpt Assignments: {trial.assignments}")
1212
1213
        if self.hp_search_backend == HPSearchBackend.WANDB:
            logger.info(f"W&B Sweep parameters: {trial}")
1214
1215
1216
        if self.is_deepspeed_enabled:
            if self.args.deepspeed is None:
                raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
1217
            # Rebuild the deepspeed config to reflect the updated training parameters
1218
1219
            from accelerate.utils import DeepSpeedPlugin

1220
            from transformers.deepspeed import HfTrainerDeepSpeedConfig
1221

1222
1223
            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
            self.args.hf_deepspeed_config.trainer_config_process(self.args)
1224
1225
            self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
        self.create_accelerator_and_postprocess()
1226

1227
    def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
1228
1229
        if self.hp_search_backend is None or trial is None:
            return
1230
        self.objective = self.compute_objective(metrics.copy())
1231
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
1232
1233
            import optuna

1234
            trial.report(self.objective, step)
1235
            if trial.should_prune():
1236
                self.callback_handler.on_train_end(self.args, self.state, self.control)
1237
1238
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
1239
1240
            from ray import tune

1241
            if self.control.should_save:
1242
                self._tune_save_checkpoint()
1243
1244
            tune.report(objective=self.objective, **metrics)

1245
    def _tune_save_checkpoint(self):
1246
1247
        from ray import tune

1248
1249
        if not self.use_tune_checkpoints:
            return
1250
        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
1251
            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
Sylvain Gugger's avatar
Sylvain Gugger committed
1252
            self.save_model(output_dir, _internal_call=True)
1253
            if self.args.should_save:
1254
1255
1256
                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
1257

1258
    def call_model_init(self, trial=None):
        """
        Instantiate a model by calling `self.model_init`.

        Args:
            trial (*optional*):
                Passed to `self.model_init` when that callable takes exactly one argument; ignored when it takes
                none.

        Returns:
            The model returned by `self.model_init`.

        Raises:
            RuntimeError: If `self.model_init` takes more than one argument or returns `None`.
        """
        model_init_argcount = number_of_arguments(self.model_init)
        if model_init_argcount == 0:
            model = self.model_init()
        elif model_init_argcount == 1:
            model = self.model_init(trial)
        else:
            raise RuntimeError("model_init should have 0 or 1 argument.")

        if model is None:
            raise RuntimeError("model_init should not return None.")

        return model

1272
1273
1274
1275
1276
1277
    def torch_jit_model_eval(self, model, dataloader, training=False):
        """
        Try to replace `model` with a traced and frozen TorchScript version for evaluation.

        Nothing is done when `training` is true. Otherwise one batch is taken from `dataloader` and used as the
        example input for `torch.jit.trace`. If tracing fails with one of the caught exception types, a warning is
        logged and the original model is returned.

        Args:
            model:
                The model to trace.
            dataloader:
                The dataloader providing the example batch. If `None`, a warning is logged and `model` is returned
                unchanged.
            training (`bool`, *optional*, defaults to `False`):
                Whether the model is being wrapped for training, in which case no tracing is attempted.

        Returns:
            The traced and frozen model on success, otherwise the model that was passed in.
        """
        if not training:
            if dataloader is None:
                logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
                return model
            example_batch = next(iter(dataloader))
            example_batch = self._prepare_inputs(example_batch)
            try:
                jit_model = model.eval()
                with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]):
                    if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"):
                        # Newer versions can trace with keyword example inputs.
                        if isinstance(example_batch, dict):
                            jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
                        else:
                            # Not a plain dict: rebuild one from the batch's keys.
                            jit_model = torch.jit.trace(
                                jit_model,
                                example_kwarg_inputs={key: example_batch[key] for key in example_batch},
                                strict=False,
                            )
                    else:
                        # Older versions: trace with positional tensors shaped like the batch entries.
                        jit_inputs = tuple(torch.ones_like(example_batch[key]) for key in example_batch)
                        jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
                jit_model = torch.jit.freeze(jit_model)
                # Run the frozen model twice on the example batch before handing it back.
                with torch.no_grad():
                    jit_model(**example_batch)
                    jit_model(**example_batch)
                model = jit_model
                # NOTE(review): presumably autocast is no longer wanted once the model is traced - confirm.
                self.use_cpu_amp = False
                self.use_cuda_amp = False
            except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
                logger.warning(f"failed to use PyTorch jit mode due to: {e}.")

        return model

1310
1311
1312
    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
        """
        Optimize `model` with Intel Extension for PyTorch (IPEX).

        Args:
            model:
                The model to optimize.
            training (`bool`, *optional*, defaults to `False`):
                If `False`, the model is put in eval mode and optimized on its own. If `True`, the model is put in
                train mode (if it is not already) and optimized in place together with `self.optimizer`, which is
                replaced by the optimizer IPEX returns.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                The dtype passed to `ipex.optimize`. For evaluation outside of training it is overridden with
                `torch.bfloat16` when `self.args.bf16_full_eval` is set.

        Returns:
            The model returned by `ipex.optimize`.

        Raises:
            ImportError: If `is_ipex_available()` returns false.
        """
        if not is_ipex_available():
            raise ImportError(
                "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
                " to https://github.com/intel/intel-extension-for-pytorch."
            )

        import intel_extension_for_pytorch as ipex

        if not training:
            model.eval()
            dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype
            # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings
            model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
        else:
            if not model.training:
                model.train()
            model, self.optimizer = ipex.optimize(
                model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
            )

        return model

1333
    def _wrap_model(self, model, training=True, dataloader=None):
        """
        Wrap `model` according to the configured backends (IPEX, SageMaker MP/DP, apex, DataParallel, jit, Fairscale
        sharded DDP, XLA FSDP, distributed).

        Args:
            model:
                The model to wrap.
            training (`bool`, *optional*, defaults to `True`):
                Whether the model is wrapped for training. When `False`, the method returns before any of the
                distributed-training wrappers are applied.
            dataloader (*optional*):
                Forwarded to `self.torch_jit_model_eval` when `self.args.jit_mode_eval` is set.

        Returns:
            The wrapped model, or `model` itself when no wrapper applies or when it is already wrapped.

        Raises:
            ImportError: If XLA FSDP is requested but the `torch_xla` FSDP modules cannot be imported.
            Exception: If a class listed in `fsdp_transformer_layer_cls_to_wrap` cannot be found in the model.
        """
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP
        if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
            model = nn.DataParallel(model)

        if self.args.jit_mode_eval:
            start_time = time.time()
            model = self.torch_jit_model_eval(model, dataloader, training)
            self.jit_compilation_time = round(time.time() - start_time, 4)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)
        # Distributed training using PyTorch FSDP
        elif self.fsdp is not None and self.args.fsdp_config["xla"]:
            try:
                from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                from torch_xla.distributed.fsdp import checkpoint_module
                from torch_xla.distributed.fsdp.wrap import (
                    size_based_auto_wrap_policy,
                    transformer_auto_wrap_policy,
                )
            except ImportError as err:
                # Keep the original import failure attached as the explicit cause.
                raise ImportError(
                    "Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0."
                ) from err
            auto_wrap_policy = None
            auto_wrapper_callable = None
            if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                auto_wrap_policy = functools.partial(
                    size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                )
            elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                transformer_cls_to_wrap = set()
                for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                    transformer_cls = get_module_class_from_name(model, layer_class)
                    if transformer_cls is None:
                        raise Exception("Could not find the transformer layer class to wrap in the model.")
                    else:
                        transformer_cls_to_wrap.add(transformer_cls)
                auto_wrap_policy = functools.partial(
                    transformer_auto_wrap_policy,
                    # Transformer layer class to wrap
                    transformer_layer_cls=transformer_cls_to_wrap,
                )
            fsdp_kwargs = self.args.xla_fsdp_config
            if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                # Apply gradient checkpointing to auto-wrapped sub-modules if specified
                def auto_wrapper_callable(m, *args, **kwargs):
                    return FSDP(checkpoint_module(m), *args, **kwargs)

            # Wrap the base model with an outer FSDP wrapper
            self.model = model = FSDP(
                model,
                auto_wrap_policy=auto_wrap_policy,
                auto_wrapper_callable=auto_wrapper_callable,
                **fsdp_kwargs,
            )

            # Patch `xm.optimizer_step` should not reduce gradients in this case,
            # as FSDP does not need gradient reduction over sharded parameters.
            def patched_optimizer_step(optimizer, barrier=False, optimizer_args=None):
                # `None` stands in for "no extra arguments" to avoid a shared mutable default.
                loss = optimizer.step(**(optimizer_args or {}))
                if barrier:
                    xm.mark_step()
                return loss

            xm.optimizer_step = patched_optimizer_step
        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            if is_torch_neuroncore_available():
                return model
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb

            # The model itself is not wrapped here; the DDP options are only recorded on the accelerator.
            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

        return model

1461
1462
    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Main training entry point.

        Args:
            resume_from_checkpoint (`str` or `bool`, *optional*):
                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
            ignore_keys_for_eval (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions for evaluation during the training.
            kwargs:
                Additional keyword arguments used to hide deprecated arguments

        Returns:
            Whatever the inner training loop returns.

        Raises:
            TypeError: If unexpected keyword arguments are passed.
            ValueError: If `resume_from_checkpoint` is `True` but no checkpoint is found in `args.output_dir`.
        """
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        args = self.args

        self.is_in_train = True

        # do_train is not a reliable argument, as it might not be set and .train() still called, so
        # the following is a workaround:
        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
            self._move_model_to_device(self.model, args.device)

        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
                "instead.",
                FutureWarning,
            )
        if len(kwargs) > 0:
            raise TypeError(f"train() received unexpected keyword arguments: {', '.join(kwargs)}.")

        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)
        self._train_batch_size = self.args.train_batch_size

        # Model re-init
        model_reloaded = False
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            if self.args.full_determinism:
                enable_full_determinism(self.args.seed)
            else:
                set_seed(self.args.seed)
            self.model = self.call_model_init(trial)
            model_reloaded = True
            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Load potential model checkpoint
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

        # With SageMaker MP or DeepSpeed enabled, the checkpoint is not loaded at this point.
        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:
            self._load_from_checkpoint(resume_from_checkpoint)

        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
            if self.place_model_on_device:
                self._move_model_to_device(self.model, args.device)
            self.model_wrapped = self.model

        inner_training_loop = find_executable_batch_size(
            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
        )
        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
1550
        self.accelerator.free_memory()
1551
        self._train_batch_size = batch_size
1552
        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
1553
        # Data loader and number of training steps
Julien Chaumond's avatar
Julien Chaumond committed
1554
        train_dataloader = self.get_train_dataloader()
1555
1556
1557
1558
1559

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
1560
        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
1561
1562
1563
1564
1565

        len_dataloader = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
1566
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
1567
            num_examples = self.num_examples(train_dataloader)
1568
1569
1570
1571
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
1572
                )
1573
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
1574
1575
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
1576
            else:
1577
1578
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
1579
1580
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
1581
            max_steps = args.max_steps
1582
1583
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
1584
            num_update_steps_per_epoch = max_steps
1585
            num_examples = total_train_batch_size * args.max_steps
1586
            num_train_samples = args.max_steps * total_train_batch_size
1587
1588
        else:
            raise ValueError(
Sylvain Gugger's avatar
Sylvain Gugger committed
1589
1590
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
1591
            )
Julien Chaumond's avatar
Julien Chaumond committed
1592

1593
1594
1595
1596
1597
1598
1599
1600
        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps and args.logging_steps < 1:
            args.logging_steps = math.ceil(max_steps * args.logging_steps)
        if args.eval_steps and args.eval_steps < 1:
            args.eval_steps = math.ceil(max_steps * args.eval_steps)
        if args.save_steps and args.save_steps < 1:
            args.save_steps = math.ceil(max_steps * args.save_steps)

1601
        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
1602
1603
1604
1605
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
Sylvain Gugger's avatar
Sylvain Gugger committed
1606
1607
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torch.distributed.launch)."
1608
1609
1610
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
1611

1612
        delay_optimizer_creation = (
1613
1614
1615
1616
            self.sharded_ddp is not None
            and self.sharded_ddp != ShardedDDPOption.SIMPLE
            or is_sagemaker_mp_enabled()
            or self.fsdp is not None
1617
        )
1618
1619
1620
1621
1622

        if self.is_deepspeed_enabled:
            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

        if not delay_optimizer_creation:
1623
1624
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

1625
        self.state = TrainerState()
1626
        self.state.is_hyper_param_search = trial is not None
Julien Chaumond's avatar
Julien Chaumond committed
1627

1628
1629
1630
1631
        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

1632
        model = self._wrap_model(self.model_wrapped)
Julien Chaumond's avatar
Julien Chaumond committed
1633

1634
1635
1636
        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint, model)

1637
1638
1639
1640
        # as the model is wrapped, don't use `accelerator.prepare`
        # this is for unhandled cases such as
        # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
        use_accelerator_prepare = True if model is self.model else False
1641

1642
1643
1644
        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

1645
        # prepare using `accelerator` prepare
1646
        if use_accelerator_prepare:
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
            if hasattr(self.lr_scheduler, "step"):
                if self.use_apex:
                    model = self.accelerator.prepare(self.model)
                else:
                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
            else:
                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.lr_scheduler
                )
1657

1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
        if self.is_fsdp_enabled:
            self.model = model

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

        # backward compatibility
        if self.is_deepspeed_enabled:
            self.deepspeed = self.model_wrapped

        # deepspeed ckpt loading
        if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
            deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)

1673
1674
1675
        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

1676
1677
        # important: at this point:
        # self.model         is the Transformers Model
1678
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
1679

Julien Chaumond's avatar
Julien Chaumond committed
1680
1681
        # Train!
        logger.info("***** Running training *****")
1682
1683
        logger.info(f"  Num examples = {num_examples:,}")
        logger.info(f"  Num Epochs = {num_train_epochs:,}")
1684
        logger.info(f"  Instantaneous batch size per device = {self._train_batch_size:,}")
1685
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
1686
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1687
1688
        logger.info(f"  Total optimization steps = {max_steps:,}")
        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
Julien Chaumond's avatar
Julien Chaumond committed
1689

1690
        self.state.epoch = 0
1691
        start_time = time.time()
Julien Chaumond's avatar
Julien Chaumond committed
1692
1693
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
1694
        steps_trained_progress_bar = None
1695

Julien Chaumond's avatar
Julien Chaumond committed
1696
        # Check if continuing training from a checkpoint
1697
        if resume_from_checkpoint is not None and os.path.isfile(
1698
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
1699
        ):
1700
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
1701
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
1702
            if not args.ignore_data_skip:
1703
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
1704
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
1705
1706
            else:
                steps_trained_in_current_epoch = 0
1707
1708

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
1709
1710
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
1711
            if not args.ignore_data_skip:
1712
1713
1714
1715
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first"
                    f" {steps_trained_in_current_epoch} batches in the first epoch."
                )
1716

Sylvain Gugger's avatar
Sylvain Gugger committed
1717
1718
1719
1720
1721
        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
1722
1723
1724
1725
        if self.hp_name is not None and self._trial is not None:
            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
            # parameter to Train when using DDP.
            self.state.trial_name = self.hp_name(self._trial)
1726
1727
1728
1729
1730
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
1731
1732
1733
1734
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
Sylvain Gugger's avatar
Sylvain Gugger committed
1735
1736
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()
Julien Chaumond's avatar
Julien Chaumond committed
1737

1738
        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
1739
        tr_loss = torch.tensor(0.0).to(args.device)
1740
1741
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
1742
        self._globalstep_last_logged = self.state.global_step
1743
        model.zero_grad()
Sylvain Gugger's avatar
Sylvain Gugger committed
1744

1745
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1746

1747
        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
1748
        if not args.ignore_data_skip:
1749
            for epoch in range(epochs_trained):
1750
1751
                for _ in train_dataloader:
                    break
1752

1753
        total_batched_samples = 0
1754
        for epoch in range(epochs_trained, num_train_epochs):
1755
            epoch_iterator = train_dataloader
1756

1757
            # Reset the past mems state at the beginning of each epoch if necessary.
1758
            if args.past_index >= 0:
1759
1760
                self._past = None

1761
            steps_in_epoch = (
1762
1763
1764
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
1765
            )
1766
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1767

1768
1769
1770
            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

1771
            rng_to_sync = False
1772
            steps_skipped = 0
1773
            if steps_trained_in_current_epoch > 0:
1774
                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
1775
                steps_skipped = steps_trained_in_current_epoch
1776
1777
1778
                steps_trained_in_current_epoch = 0
                rng_to_sync = True

1779
            step = -1
Julien Chaumond's avatar
Julien Chaumond committed
1780
            for step, inputs in enumerate(epoch_iterator):
1781
                total_batched_samples += 1
1782
1783
1784
                if rng_to_sync:
                    self._load_rng_state(resume_from_checkpoint)
                    rng_to_sync = False
Julien Chaumond's avatar
Julien Chaumond committed
1785
1786
1787
1788

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
1789
1790
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
1791
1792
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
Julien Chaumond's avatar
Julien Chaumond committed
1793
                    continue
1794
1795
1796
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None
Julien Chaumond's avatar
Julien Chaumond committed
1797

1798
1799
                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1800

1801
                with self.accelerator.accumulate(model):
1802
1803
                    tr_loss_step = self.training_step(model, inputs)

1804
1805
1806
1807
1808
1809
1810
                if (
                    args.logging_nan_inf_filter
                    and not is_torch_tpu_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
1811
1812
1813
                else:
                    tr_loss += tr_loss_step

1814
                self.current_flos += float(self.floating_point_ops(inputs))
Julien Chaumond's avatar
Julien Chaumond committed
1815

1816
1817
1818
                # should this be under the accumulate context manager?
                # the `or` condition of `steps_in_epoch <= args.gradient_accumulation_steps` is not covered
                # in accelerate
1819
                if total_batched_samples % args.gradient_accumulation_steps == 0 or (
Julien Chaumond's avatar
Julien Chaumond committed
1820
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
1821
                    steps_in_epoch <= args.gradient_accumulation_steps
1822
                    and (step + 1) == steps_in_epoch
Julien Chaumond's avatar
Julien Chaumond committed
1823
                ):
1824
                    # Gradient clipping
1825
                    if args.max_grad_norm is not None and args.max_grad_norm > 0:
1826
1827
                        # deepspeed does its own clipping

1828
                        if self.do_grad_scaling:
1829
1830
1831
1832
                            # Reduce gradients first for XLA
                            if is_torch_tpu_available():
                                gradients = xm._fetch_gradients(self.optimizer)
                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
1833
1834
1835
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)

1836
1837
1838
                        if is_sagemaker_mp_enabled() and args.fp16:
                            self.optimizer.clip_master_grads(args.max_grad_norm)
                        elif hasattr(self.optimizer, "clip_grad_norm"):
1839
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
1840
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
1841
1842
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
1843
                            model.clip_grad_norm_(args.max_grad_norm)
1844
                        elif self.use_apex:
1845
                            # Revert to normal clipping otherwise, handling Apex or full precision
1846
                            nn.utils.clip_grad_norm_(
1847
1848
1849
1850
1851
1852
                                amp.master_params(self.optimizer),
                                args.max_grad_norm,
                            )
                        else:
                            self.accelerator.clip_grad_norm_(
                                model.parameters(),
1853
                                args.max_grad_norm,
1854
1855
1856
                            )

                    # Optimizer step
1857
                    optimizer_was_run = True
1858
                    if is_torch_tpu_available():
1859
1860
1861
1862
1863
                        if self.do_grad_scaling:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            xm.optimizer_step(self.optimizer)
1864
                    elif self.do_grad_scaling:
1865
                        scale_before = self.scaler.get_scale()
1866
                        self.scaler.step(self.optimizer)
1867
                        self.scaler.update()
1868
1869
                        scale_after = self.scaler.get_scale()
                        optimizer_was_run = scale_before <= scale_after
Lysandre Debut's avatar
Lysandre Debut committed
1870
                    else:
1871
                        self.optimizer.step()
1872
                        optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
Lysandre Debut's avatar
Lysandre Debut committed
1873

1874
                    if optimizer_was_run:
1875
1876
1877
                        # Delay optimizer scheduling until metrics are generated
                        if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                            self.lr_scheduler.step()
1878

1879
                    model.zero_grad()
1880
                    self.state.global_step += 1
1881
                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
1882
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1883

1884
                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
wulu473's avatar
wulu473 committed
1885
1886
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
Julien Chaumond's avatar
Julien Chaumond committed
1887

Sylvain Gugger's avatar
Sylvain Gugger committed
1888
                if self.control.should_epoch_stop or self.control.should_training_stop:
Julien Chaumond's avatar
Julien Chaumond committed
1889
                    break
1890
1891
            if step < 0:
                logger.warning(
Sylvain Gugger's avatar
Sylvain Gugger committed
1892
                    "There seems to be not a single sample in your epoch_iterator, stopping training at step"
1893
1894
1895
1896
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True
1897

1898
            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
1899
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
1900

1901
            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
1902
1903
1904
1905
1906
1907
1908
1909
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
Sylvain Gugger's avatar
Sylvain Gugger committed
1910
            if self.control.should_training_stop:
1911
                break
Julien Chaumond's avatar
Julien Chaumond committed
1912

1913
        if args.past_index and hasattr(self, "_past"):
1914
1915
            # Clean the state at the end of training
            delattr(self, "_past")
Julien Chaumond's avatar
Julien Chaumond committed
1916
1917

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
1918
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
1919
1920
1921
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
1922
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
1923
                dist.barrier()
1924
1925
            elif is_sagemaker_mp_enabled():
                smp.barrier()
1926

1927
            self._load_best_model()
1928

1929
1930
1931
1932
        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        train_loss = self._total_loss_scalar / self.state.global_step

1933
        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
1934
1935
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
1936
        metrics["train_loss"] = train_loss
Sylvain Gugger's avatar
Sylvain Gugger committed
1937

1938
        self.is_in_train = False
1939

1940
1941
        self._memory_tracker.stop_and_update_metrics(metrics)

1942
1943
        self.log(metrics)

raghavanone's avatar
raghavanone committed
1944
1945
1946
        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

1947
1948
        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
raghavanone's avatar
raghavanone committed
1949
1950
1951
1952
1953
            for checkpoint in checkpoints_sorted:
                if checkpoint != self.state.best_model_checkpoint:
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint)

1954
1955
1956
        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        return TrainOutput(self.state.global_step, train_loss, metrics)
1957

raghavanone's avatar
raghavanone committed
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
    def _get_output_dir(self, trial):
        if self.hp_search_backend is not None and trial is not None:
            if self.hp_search_backend == HPSearchBackend.OPTUNA:
                run_id = trial.number
            elif self.hp_search_backend == HPSearchBackend.RAY:
                from ray import tune

                run_id = tune.get_trial_id()
            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
                run_id = trial.id
            elif self.hp_search_backend == HPSearchBackend.WANDB:
                import wandb

                run_id = wandb.run.id
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            run_dir = os.path.join(self.args.output_dir, run_name)
        else:
            run_dir = self.args.output_dir
        return run_dir

1978
1979
1980
1981
    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        """
        Load the model weights stored in the checkpoint folder `resume_from_checkpoint` into `model`.

        The folder must contain a single weights file (PyTorch or safetensors) or a sharded-checkpoint index. A
        warning is logged when the checkpoint config records a Transformers version different from the running one.
        Single-file checkpoints are loaded through SageMaker Model Parallel, FSDP or a plain `load_state_dict`,
        depending on the setup; anything else is loaded as a sharded checkpoint.

        Args:
            resume_from_checkpoint (`str`): Path of the checkpoint folder.
            model: The model to load the weights into. Defaults to `self.model`.

        Raises:
            ValueError: If the folder contains none of the supported weights or weights-index files.
        """
        if model is None:
            model = self.model

        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)

        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
        safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
        safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)

        if not any(
            os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file]
        ):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

        logger.info(f"Loading model from {resume_from_checkpoint}.")

        if os.path.isfile(config_file):
            config = PretrainedConfig.from_json_file(config_file)
            checkpoint_version = config.transformers_version
            if checkpoint_version is not None and checkpoint_version != __version__:
                logger.warning(
                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                    f"Transformers but your current version is {__version__}. This is not recommended and could "
                    "yield to errors or unwanted behaviors."
                )

        if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
            # If the model is on the GPU, it still works!
            if is_sagemaker_mp_enabled():
                if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
                    # If the 'user_content.pt' file exists, load with the new smp api.
                    # Checkpoint must have been saved with the new smp api.
                    smp.resume_from_checkpoint(
                        path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
                    )
                else:
                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                    # Checkpoint must have been saved with the old smp api.
                    if hasattr(self.args, "fp16") and self.args.fp16 is True:
                        logger.warning(
                            "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
                        )
                    # NOTE(review): `torch.load` unpickles the file; only resume from checkpoints you trust.
                    state_dict = torch.load(weights_file, map_location="cpu")
                    # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                    state_dict["_smp_is_partial"] = False
                    model.load_state_dict(state_dict, strict=True)
                    # release memory
                    del state_dict
            elif self.is_fsdp_enabled:
                load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint)
            else:
                # We load the model state dict on the CPU to avoid an OOM error.
                # Safetensors is used when requested, and also when it is the only weights file in the folder:
                # otherwise `torch.load` would be pointed at a file that does not exist.
                load_safetensors = os.path.isfile(safe_weights_file) and (
                    self.args.save_safetensors or not os.path.isfile(weights_file)
                )
                if load_safetensors:
                    state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
                else:
                    # NOTE(review): `torch.load` unpickles the file; only resume from checkpoints you trust.
                    state_dict = torch.load(weights_file, map_location="cpu")

                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                # which takes *args instead of **kwargs
                load_result = model.load_state_dict(state_dict, False)
                # release memory
                del state_dict
                self._issue_warnings_after_load(load_result)
        else:
            # We load the sharded checkpoint
            load_result = load_sharded_checkpoint(
                model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
            )
            if not is_sagemaker_mp_enabled():
                self._issue_warnings_after_load(load_result)
2050
2051
2052
2053

    def _load_best_model(self):
        """
        Load the weights of `self.state.best_model_checkpoint` into the model being trained.

        Depending on the setup, the weights are restored through DeepSpeed, SageMaker Model Parallel, FSDP, a PEFT
        adapter, a single weights file or a sharded checkpoint. Key mismatches reported by the load are passed to
        `_issue_warnings_after_load`, except under SageMaker Model Parallel or when the loader reports nothing. If
        no supported file is found in the checkpoint folder, only a warning is logged.
        """
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
        best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
        best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
        best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)

        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if (
            os.path.exists(best_model_path)
            or os.path.exists(best_safe_model_path)
            or os.path.exists(best_adapter_model_path)
            or os.path.exists(best_safe_adapter_model_path)
        ):
            if self.is_deepspeed_enabled:
                deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
            else:
                # Stays `None` whenever a branch below has no load result to report (failed PEFT load, new smp
                # api, or a loader that returns nothing), so the final warning step can be skipped safely.
                load_result = None
                # Safetensors is used when requested, and also when it is the only weights file in the folder:
                # otherwise `torch.load` would be pointed at a file that does not exist.
                load_safetensors = os.path.isfile(best_safe_model_path) and (
                    self.args.save_safetensors or not os.path.isfile(best_model_path)
                )
                if is_sagemaker_mp_enabled():
                    if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                        # If the 'user_content.pt' file exists, load with the new smp api.
                        # Checkpoint must have been saved with the new smp api.
                        smp.resume_from_checkpoint(
                            path=self.state.best_model_checkpoint,
                            tag=WEIGHTS_NAME,
                            partial=False,
                            load_optimizer=False,
                        )
                    else:
                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                        # Checkpoint must have been saved with the old smp api.
                        if load_safetensors:
                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                        else:
                            # NOTE(review): `torch.load` unpickles the file; only load checkpoints you trust.
                            state_dict = torch.load(best_model_path, map_location="cpu")

                        state_dict["_smp_is_partial"] = False
                        load_result = model.load_state_dict(state_dict, strict=True)
                elif self.is_fsdp_enabled:
                    # Keep whatever the FSDP loader returns: previously nothing was assigned here and the warning
                    # step below failed on an unbound `load_result`.
                    load_result = load_fsdp_model(
                        self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
                    )
                else:
                    if is_peft_available() and isinstance(model, PeftModel):
                        # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
                        if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                            if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
                                model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
                                # Load_adapter has no return value present, modify it when appropriate.
                                from torch.nn.modules.module import _IncompatibleKeys

                                load_result = _IncompatibleKeys([], [])
                            else:
                                logger.warning(
                                    "The intermediate checkpoints of PEFT may not be saved correctly, "
                                    f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, "
                                    "here are some examples https://github.com/huggingface/peft/issues/96"
                                )
                        else:
                            logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
                    else:
                        # We load the model state dict on the CPU to avoid an OOM error.
                        if load_safetensors:
                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                        else:
                            # NOTE(review): `torch.load` unpickles the file; only load checkpoints you trust.
                            state_dict = torch.load(best_model_path, map_location="cpu")

                        # If the model is on the GPU, it still works!
                        # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                        # which takes *args instead of **kwargs
                        load_result = model.load_state_dict(state_dict, False)
                if not is_sagemaker_mp_enabled() and load_result is not None:
                    self._issue_warnings_after_load(load_result)
        elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)) or os.path.exists(
            os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
        ):
            # Sharded checkpoint: either index file (PyTorch or safetensors) marks the folder as loadable.
            load_result = load_sharded_checkpoint(
                model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
            )
            if not is_sagemaker_mp_enabled():
                self._issue_warnings_after_load(load_result)
        else:
            logger.warning(
                f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
                "on multiple nodes, you should activate `--save_on_each_node`."
            )

2138
    def _issue_warnings_after_load(self, load_result):
2139
        if len(load_result.missing_keys) != 0:
2140
2141
2142
            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                self.model._keys_to_ignore_on_save
            ):
2143
2144
                self.model.tie_weights()
            else:
2145
                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
2146
        if len(load_result.unexpected_keys) != 0:
2147
2148
2149
            logger.warning(
                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
            )
2150

2151
    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
Sylvain Gugger's avatar
Sylvain Gugger committed
2152
        if self.control.should_log:
2153
2154
2155
            if is_torch_tpu_available():
                xm.mark_step()

Sylvain Gugger's avatar
Sylvain Gugger committed
2156
            logs: Dict[str, float] = {}
2157
2158
2159
2160

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

2161
2162
2163
            # reset tr_loss to zero
            tr_loss -= tr_loss

2164
            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
2165
            logs["learning_rate"] = self._get_learning_rate()
2166

2167
            self._total_loss_scalar += tr_loss_scalar
2168
            self._globalstep_last_logged = self.state.global_step
Teven's avatar
Teven committed
2169
            self.store_flos()
Sylvain Gugger's avatar
Sylvain Gugger committed
2170
2171
2172
2173
2174

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
2175
            if isinstance(self.eval_dataset, dict):
2176
                metrics = {}
2177
                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
2178
                    dataset_metrics = self.evaluate(
2179
2180
2181
2182
                        eval_dataset=eval_dataset,
                        ignore_keys=ignore_keys_for_eval,
                        metric_key_prefix=f"eval_{eval_dataset_name}",
                    )
2183
                    metrics.update(dataset_metrics)
2184
2185
            else:
                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
2186
            self._report_to_hp_search(trial, self.state.global_step, metrics)
2187

2188
2189
            # Run delayed LR scheduler now that metrics are populated
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
2190
2191
2192
2193
                metric_to_check = self.args.metric_for_best_model
                if not metric_to_check.startswith("eval_"):
                    metric_to_check = f"eval_{metric_to_check}"
                self.lr_scheduler.step(metrics[metric_to_check])
2194

Sylvain Gugger's avatar
Sylvain Gugger committed
2195
2196
2197
2198
        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

2199
2200
2201
2202
2203
    def _load_rng_state(self, checkpoint):
        # Load RNG states from `checkpoint`
        if checkpoint is None:
            return

2204
2205
2206
        if self.args.world_size > 1:
            process_index = self.args.process_index
            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
2207
            if not os.path.isfile(rng_file):
2208
                logger.info(
2209
                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
2210
2211
2212
2213
2214
                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
                )
                return
        else:
            rng_file = os.path.join(checkpoint, "rng_state.pth")
2215
            if not os.path.isfile(rng_file):
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
                logger.info(
                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
                    "fashion, reproducibility is not guaranteed."
                )
                return

        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        if torch.cuda.is_available():
2227
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
2228
                torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
2229
            else:
2230
                try:
2231
                    torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
2232
                except Exception as e:
2233
                    logger.info(
2234
2235
2236
                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
2237
2238
2239
        if is_torch_tpu_available():
            xm.set_rng_state(checkpoint_rng_state["xla"])

Sylvain Gugger's avatar
Sylvain Gugger committed
2240
    def _save_checkpoint(self, model, trial, metrics=None):
2241
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
2242
        # want to save except FullyShardedDDP.
2243
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
2244

2245
        # Save model checkpoint
2246
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
2247

raghavanone's avatar
raghavanone committed
2248
        if self.hp_search_backend is None and trial is None:
2249
            self.store_flos()
2250

raghavanone's avatar
raghavanone committed
2251
        run_dir = self._get_output_dir(trial=trial)
2252
        output_dir = os.path.join(run_dir, checkpoint_folder)
Sylvain Gugger's avatar
Sylvain Gugger committed
2253
        self.save_model(output_dir, _internal_call=True)
2254
        if self.is_deepspeed_enabled:
2255
            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
2256
            # config `stage3_gather_16bit_weights_on_model_save` is True
2257
            self.model_wrapped.save_checkpoint(output_dir)
2258
2259

        # Save optimizer and scheduler
2260
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
2261
            self.optimizer.consolidate_state_dict()
2262

2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
        if self.fsdp or self.is_fsdp_enabled:
            if self.is_fsdp_enabled:
                save_fsdp_optimizer(
                    self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
                )
            else:
                # FSDP has a different interface for saving optimizer states.
                # Needs to be called on all ranks to gather all states.
                # full_optim_state_dict will be deprecated after Pytorch 2.2!
                full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
Qingyang Wu's avatar
Qingyang Wu committed
2273

2274
2275
        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
2276
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
2277
            with warnings.catch_warnings(record=True) as caught_warnings:
2278
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
2279
                reissue_pt_warnings(caught_warnings)
Sylvain Gugger's avatar
Sylvain Gugger committed
2280
        elif is_sagemaker_mp_enabled():
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
            if self.args.should_save:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling:
                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
2296
        elif self.args.should_save and not self.is_deepspeed_enabled:
2297
            # deepspeed.save_checkpoint above saves model/optim/sched
Qingyang Wu's avatar
Qingyang Wu committed
2298
2299
2300
2301
2302
            if self.fsdp:
                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
            else:
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

2303
            with warnings.catch_warnings(record=True) as caught_warnings:
2304
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
2305
            reissue_pt_warnings(caught_warnings)
2306
            if self.do_grad_scaling:
2307
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
2308
2309

        # Determine the new best metric / best model checkpoint
Sylvain Gugger's avatar
Sylvain Gugger committed
2310
        if metrics is not None and self.args.metric_for_best_model is not None:
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
2326
        if self.args.should_save:
2327
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
2328

2329
2330
2331
2332
2333
2334
2335
        # Save RNG state in non-distributed training
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
2336
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
2337
2338
2339
2340
2341
2342
2343
2344
                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                rng_states["cuda"] = torch.cuda.random.get_rng_state()

        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()

2345
2346
2347
        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)
2348

2349
        if self.args.world_size <= 1:
2350
2351
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
2352
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
2353

2354
2355
2356
        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

2357
        # Maybe delete some older checkpoints.
2358
        if self.args.should_save:
2359
2360
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

2361
    def _load_optimizer_and_scheduler(self, checkpoint):
        """
        Restore the optimizer, LR scheduler and (when present) grad scaler states stored in `checkpoint`.

        Nothing is loaded when `checkpoint` is `None`, when DeepSpeed is enabled, or when the optimizer and scheduler
        files are not both found in the checkpoint folder.
        """
        # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
        if checkpoint is None or self.is_deepspeed_enabled:
            return

        optimizer_path = os.path.join(checkpoint, OPTIMIZER_NAME)
        scheduler_path = os.path.join(checkpoint, SCHEDULER_NAME)

        # SageMaker MP writes the optimizer state as several suffixed files, hence the glob.
        if is_sagemaker_mp_enabled():
            has_optimizer_state = glob.glob(optimizer_path + "_*")
        else:
            has_optimizer_state = os.path.isfile(optimizer_path)
        if not (has_optimizer_state and os.path.isfile(scheduler_path)):
            return

        if is_torch_tpu_available():
            # On TPU we have to take some extra precautions to properly load the states on the right device:
            # load everything on CPU first, then send it over.
            optimizer_state = torch.load(optimizer_path, map_location="cpu")
            with warnings.catch_warnings(record=True) as caught_warnings:
                lr_scheduler_state = torch.load(scheduler_path, map_location="cpu")
            reissue_pt_warnings(caught_warnings)

            xm.send_cpu_data_to_device(optimizer_state, self.args.device)
            xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

            self.optimizer.load_state_dict(optimizer_state)
            self.lr_scheduler.load_state_dict(lr_scheduler_state)
            return

        if is_sagemaker_mp_enabled():
            # 'user_content.pt' marks an optimizer checkpoint that was saved with smp >= 1.10
            saved_with_smp_1_10 = os.path.isfile(os.path.join(checkpoint, "user_content.pt"))

            def opt_load_hook(mod, opt):
                load_kwargs = {"partial": True}
                # Older checkpoints need the back-compat loader when running on smp >= 1.10
                if not saved_with_smp_1_10 and IS_SAGEMAKER_MP_POST_1_10:
                    load_kwargs["back_compat"] = True
                opt.load_state_dict(smp.load(optimizer_path, **load_kwargs))

            self.model_wrapped.register_post_step_hook(opt_load_hook)
        else:
            # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
            # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
            # likely to get OOM on CPU (since we load num_gpu times the optimizer state).
            map_location = self.args.device if self.args.world_size > 1 else "cpu"
            if self.is_fsdp_enabled:
                load_fsdp_optimizer(
                    self.accelerator.state.fsdp_plugin,
                    self.accelerator,
                    self.optimizer,
                    self.model,
                    checkpoint,
                )
            elif self.fsdp:
                # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
                full_osd = torch.load(optimizer_path) if self.args.process_index == 0 else None
                # call scatter_full_optim_state_dict on all ranks
                sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)
                self.optimizer.load_state_dict(sharded_osd)
            else:
                self.optimizer.load_state_dict(torch.load(optimizer_path, map_location=map_location))

        with warnings.catch_warnings(record=True) as caught_warnings:
            self.lr_scheduler.load_state_dict(torch.load(scheduler_path))
        reissue_pt_warnings(caught_warnings)

        scaler_path = os.path.join(checkpoint, SCALER_NAME)
        if self.do_grad_scaling and os.path.isfile(scaler_path):
            self.scaler.load_state_dict(torch.load(scaler_path))
Sylvain Gugger's avatar
Sylvain Gugger committed
2438

2439
2440
2441
2442
2443
2444
2445
    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
        **kwargs,
    ) -> BestRun:
        """
        Launch an hyperparameter search using `optuna` or `Ray Tune` or `SigOpt` or `wandb`. The optimized quantity is
        determined by `compute_objective`, which defaults to a function returning the evaluation loss when no metric
        is provided, the sum of all metrics otherwise.

        <Tip warning={true}>

        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
        optimizer/scheduler.

        </Tip>

        Args:
            hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
                A function that defines the hyperparameter search space. Will default to
                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
                method. Will default to [`~trainer_utils.default_compute_objective`].
            n_trials (`int`, *optional*, defaults to 20):
                The number of trial runs to test.
            direction (`str`, *optional*, defaults to `"minimize"`):
                Whether to optimize greater or lower objects. Can be `"minimize"` or `"maximize"`, you should pick
                `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
            backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
                on which one is installed. If all are installed, will default to optuna.
            hp_name (`Callable[["optuna.Trial"], str]]`, *optional*):
                A function that defines the trial/run name. Will default to None.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more
                information see:

                - the documentation of
                  [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
                - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
                - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)

        Returns:
            [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in
            `run_summary` attribute for Ray backend.

        Raises:
            `RuntimeError`: If no backend is installed, if the requested backend is not installed, or if the
            [`Trainer`] was created without a `model_init`.
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray should be installed. "
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`. "
                    "To install sigopt run `pip install sigopt`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_tune_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
            raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
        if backend == HPSearchBackend.WANDB and not is_wandb_available():
            raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")

        # Validate everything *before* touching the trainer state, so a failed call does not leave
        # `self.hp_search_backend` set on a trainer that is not actually searching.
        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.hp_name = hp_name
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        backend_dict = {
            HPSearchBackend.OPTUNA: run_hp_search_optuna,
            HPSearchBackend.RAY: run_hp_search_ray,
            HPSearchBackend.SIGOPT: run_hp_search_sigopt,
            HPSearchBackend.WANDB: run_hp_search_wandb,
        }

        self.hp_search_backend = backend
        try:
            best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
        finally:
            # Always leave "search mode", even when a trial raised, so later calls to `train` behave normally.
            self.hp_search_backend = None
        return best_run

Sylvain Gugger's avatar
Sylvain Gugger committed
2535
    def log(self, logs: Dict[str, float]) -> None:
        """
        Record `logs` in the training state history and forward them to the callbacks.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log. The current epoch (rounded to two decimals) is added to this dictionary in place
                when it is known.
        """
        state = self.state
        if state.epoch is not None:
            logs["epoch"] = round(state.epoch, 2)

        # The history entry is a copy that additionally carries the global step.
        entry = dict(logs)
        entry["step"] = state.global_step
        state.log_history.append(entry)

        self.control = self.callback_handler.on_log(self.args, state, self.control, logs)
Julien Chaumond's avatar
Julien Chaumond committed
2551

2552
2553
    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.

        Mappings, lists and tuples (including namedtuples) are rebuilt with the same container type and their
        elements prepared recursively. Tensors are moved to `self.args.device`. Any other object is returned
        unchanged.

        Args:
            data (`torch.Tensor` or `Any`):
                The tensor or (possibly nested) container to prepare.

        Returns:
            `torch.Tensor` or `Any`: The prepared counterpart of `data`.
        """
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            if isinstance(data, tuple) and hasattr(data, "_fields"):
                # namedtuple constructors take one positional argument per field, not a single iterable
                return type(data)(*(self._prepare_input(v) for v in data))
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            kwargs = {"device": self.args.device}
            if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
                # NLP models inputs are int/uint and those get adjusted to the right dtype of the
                # embedding. Other models such as wav2vec2's inputs are already float and thus
                # may need special handling to match the dtypes of the model
                kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
            return data.to(**kwargs)
        return data

sgugger's avatar
Fix CI  
sgugger committed
2570
    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Run `inputs` through `_prepare_input` and attach the cached past state when one is in use.

        Raises a `ValueError` when the prepared batch is empty.
        """
        prepared = self._prepare_input(inputs)
        if len(prepared) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )

        # Feed the state kept from the previous step back to the model under the `mems` key.
        reuse_past = self.args.past_index >= 0 and self._past is not None
        if reuse_past:
            prepared["mems"] = self._past

        return prepared

2586
2587
2588
2589
    def compute_loss_context_manager(self):
        """
        Return the context manager under which the loss is computed.

        It currently delegates to `autocast_smart_context_manager`.
        """
        loss_context = self.autocast_smart_context_manager()
        return loss_context
2591

2592
    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
        """
        Build the `autocast` context manager matching the current mixed precision setup.

        When neither CUDA nor CPU AMP is in use, a no-op context manager is returned instead.
        """
        if not (self.use_cuda_amp or self.use_cpu_amp):
            return contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()

        if not is_torch_greater_or_equal_than_1_10:
            # Older torch: plain CUDA autocast, called without arguments.
            return torch.cuda.amp.autocast()

        if self.use_cpu_amp:
            return torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
        return torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)

2611
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Run the forward and backward pass for one batch.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        batch = self._prepare_inputs(inputs)
        accumulation_steps = self.args.gradient_accumulation_steps

        if is_sagemaker_mp_enabled():
            # SageMaker MP runs forward and backward in one call and returns the per-microbatch losses.
            microbatch_loss = smp_forward_backward(model, batch, accumulation_steps)
            return microbatch_loss.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, batch)

        if self.args.n_gpu > 1:
            # mean() to average on multi-gpu parallel training
            loss = loss.mean()

        # Backward pass through whichever mixed precision machinery is active.
        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)

        return loss.detach() / self.args.gradient_accumulation_steps
Julien Chaumond's avatar
Julien Chaumond committed
2651

2652
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Compute the training loss for `inputs`. Unless label smoothing is active, the loss is read from the model
        outputs (the `"loss"` key of a dict, or the first element otherwise).

        Subclass and override for custom behavior.
        """
        # With label smoothing, the labels are taken out of the inputs and the loss is computed here instead.
        smoothing_applies = self.label_smoother is not None and "labels" in inputs
        labels = inputs.pop("labels") if smoothing_applies else None

        outputs = model(**inputs)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is None:
            outputs_are_dict = isinstance(outputs, dict)
            if outputs_are_dict and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if outputs_are_dict else outputs[0]
        elif unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            loss = self.label_smoother(outputs, labels, shift_labels=True)
        else:
            loss = self.label_smoother(outputs, labels)

        if return_outputs:
            return (loss, outputs)
        return loss
Sylvain Gugger's avatar
Sylvain Gugger committed
2683

2684
2685
    def is_local_process_zero(self) -> bool:
        """
        Tell whether this process is the main process of its own machine (relevant when training in a distributed
        fashion on several machines).
        """
        local_index = self.args.local_process_index
        return local_index == 0
Lysandre Debut's avatar
Lysandre Debut committed
2690

2691
2692
    def is_world_process_zero(self) -> bool:
        """
        Tell whether this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be `True` for one process).
        """
        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
        # process index.
        global_index = smp.rank() if is_sagemaker_mp_enabled() else self.args.process_index
        return global_index == 0
Julien Chaumond's avatar
Julien Chaumond committed
2702

Sylvain Gugger's avatar
Sylvain Gugger committed
2703
    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Save the model to `output_dir` (defaults to `self.args.output_dir`), so you can reload it using
        `from_pretrained()`.

        Will only save from the main process.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        should_save = self.args.should_save

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif is_sagemaker_mp_enabled():
            # Calling the state_dict needs to be done on the wrapped model and on all processes.
            os.makedirs(output_dir, exist_ok=True)
            wrapped_state = self.model_wrapped.state_dict()
            if should_save:
                self._save(output_dir, state_dict=wrapped_state)
            if IS_SAGEMAKER_MP_POST_1_10:
                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
                Path(output_dir, "user_content.pt").touch()
        elif (
            any(option in self.args.sharded_ddp for option in (ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3))
            or self.fsdp is not None
            or self.is_fsdp_enabled
        ):
            if self.is_fsdp_enabled:
                os.makedirs(output_dir, exist_ok=True)
                save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
            else:
                # The state dict is built on every process, but only written where saving is allowed.
                full_state = self.model.state_dict()
                if should_save:
                    self._save(output_dir, state_dict=full_state)
        elif self.is_deepspeed_enabled:
            # this takes care of everything as long as we aren't under zero3
            if should_save:
                self._save(output_dir)

            if is_deepspeed_zero3_enabled():
                # It's too complicated to try to override different places where the weights dump gets
                # saved, so since under zero3 the file is bogus, simply delete it. The user should
                # either use the deepspeed checkpoint to resume or, to recover full weights, use
                # zero_to_fp32.py stored in the checkpoint.
                if should_save:
                    bogus_weights = os.path.join(output_dir, WEIGHTS_NAME)
                    if os.path.isfile(bogus_weights):
                        os.remove(bogus_weights)

                # now save the real model if stage3_gather_16bit_weights_on_model_save=True
                # if false it will not be saved.
                # This must be called on all ranks
                saved_16bit = self.model_wrapped.save_16bit_model(output_dir, WEIGHTS_NAME)
                if not saved_16bit:
                    logger.warning(
                        "deepspeed.save_16bit_model didn't save the model, since"
                        " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                        " zero_to_fp32.py to recover weights"
                    )
                    self.model_wrapped.save_checkpoint(output_dir)
        elif should_save:
            self._save(output_dir)

        # Push to the Hub when `save_model` is called by the user.
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")

2772
2773
    def _save_tpu(self, output_dir: Optional[str] = None):
        """
        TPU flavour of model saving: writes the training arguments, the model and the tokenizer (if any) to
        `output_dir`, which defaults to `self.args.output_dir`.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        logger.info(f"Saving model checkpoint to {output_dir}")

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if isinstance(self.model, PreTrainedModel):
            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
        elif isinstance(unwrap_model(self.model), PreTrainedModel):
            # The model is wrapped: save through the inner model, with the wrapper's state dict.
            unwrap_model(self.model).save_pretrained(
                output_dir,
                is_main_process=self.args.should_save,
                state_dict=self.model.state_dict(),
                save_function=xm.save,
            )
        else:
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            weights = self.model.state_dict()
            xm.save(weights, os.path.join(output_dir, WEIGHTS_NAME))

        if self.tokenizer is not None and self.args.should_save:
            self.tokenizer.save_pretrained(output_dir)
2799

2800
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """
        Write the model weights, the tokenizer (if any) and the training arguments to `output_dir`.

        This method does not check which process it runs on: it is meant to be called only from the process that
        is responsible for saving.

        Args:
            output_dir (`str`, *optional*):
                Destination folder, created if needed. Falls back to `self.args.output_dir` when not given.
            state_dict (*optional*):
                State dict to save. When `self.model` is not an instance of one of the supported classes and no
                state dict is given, `self.model.state_dict()` is used.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        # `PeftModel` is only referenced when PEFT is available.
        if is_peft_available():
            supported_classes = (PreTrainedModel, PeftModel)
        else:
            supported_classes = (PreTrainedModel,)

        # Whenever possible the model is written with `save_pretrained()` so it can later be reloaded with
        # `from_pretrained()`.
        if isinstance(self.model, supported_classes):
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )
        else:
            if state_dict is None:
                state_dict = self.model.state_dict()

            unwrapped_model = unwrap_model(self.model)
            if isinstance(unwrapped_model, supported_classes):
                unwrapped_model.save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: keep the training arguments next to the trained model.
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
2833

2834
    def store_flos(self):
        """
        Add the floating-point operations counted in `self.current_flos` to `self.state.total_flos`, then reset
        `self.current_flos` to 0.

        In distributed mode the value added is the sum of what `distributed_broadcast_scalars` returns for this
        process' counter.
        """
        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            flos_to_add = distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
        else:
            flos_to_add = self.current_flos

        self.state.total_flos += flos_to_add
        self.current_flos = 0
Julien Chaumond's avatar
Julien Chaumond committed
2844

2845
2846
2847
    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        """
        List the checkpoint folders found in `output_dir`, ordered from oldest to newest.

        Args:
            output_dir (`str`, *optional*):
                Folder to scan for sub-folders named `{checkpoint_prefix}-*`. Falls back to `self.args.output_dir`
                when not given.
            checkpoint_prefix (`str`, *optional*, defaults to `PREFIX_CHECKPOINT_DIR`):
                Prefix of the checkpoint folder names.
            use_mtime (`bool`, *optional*, defaults to `False`):
                Order by modification time instead of by the number that follows the prefix. When `False`, folders
                whose name carries no number after the prefix are left out.

        Returns:
            `List[str]`: The checkpoint paths. If `self.state.best_model_checkpoint` is one of them, it is moved to
            the second-to-last position so that deleting entries from the front of the list does not remove it.
        """
        if output_dir is None:
            # Same fallback as the saving methods; `Path(None)` would raise a `TypeError`.
            output_dir = self.args.output_dir

        # The prefix is literal text, so it must not be interpreted as a regular expression.
        step_pattern = re.compile(f".*{re.escape(checkpoint_prefix)}-([0-9]+)")

        ordering_and_checkpoint_path = []
        for candidate in Path(output_dir).glob(f"{checkpoint_prefix}-*"):
            if not os.path.isdir(candidate):
                continue
            path = str(candidate)
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = step_pattern.match(path)
                if regex_match is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.group(1)), path))

        checkpoints_sorted = [path for _, path in sorted(ordering_and_checkpoint_path)]

        # Make sure we don't delete the best model.
        best_model_checkpoint = self.state.best_model_checkpoint
        if best_model_checkpoint is not None:
            # Globbed paths went through `Path`, so normalize the same way before comparing.
            best_model_checkpoint = str(Path(best_model_checkpoint))
            # The best checkpoint may be absent from this folder (removed, or saved elsewhere); nothing to protect
            # in that case.
            if best_model_checkpoint in checkpoints_sorted:
                best_model_index = checkpoints_sorted.index(best_model_checkpoint)
                for i in range(best_model_index, len(checkpoints_sorted) - 2):
                    checkpoints_sorted[i], checkpoints_sorted[i + 1] = (
                        checkpoints_sorted[i + 1],
                        checkpoints_sorted[i],
                    )
        return checkpoints_sorted

2869
    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
2870
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
Julien Chaumond's avatar
Julien Chaumond committed
2871
2872
2873
            return

        # Check if we should delete older checkpoint(s)
2874
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
Julien Chaumond's avatar
Julien Chaumond committed
2875
2876
2877
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

2878
        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
2879
2880
2881
2882
2883
2884
2885
2886
2887
2888
        # we don't do to allow resuming.
        save_total_limit = self.args.save_total_limit
        if (
            self.state.best_model_checkpoint is not None
            and self.args.save_total_limit == 1
            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
        ):
            save_total_limit = 2

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
Julien Chaumond's avatar
Julien Chaumond committed
2889
2890
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
2891
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
2892
            shutil.rmtree(checkpoint, ignore_errors=True)
Julien Chaumond's avatar
Julien Chaumond committed
2893

2894
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and return the resulting metrics.

        Metrics are task-dependent, so the calling script has to provide the function that computes them (through
        the `compute_metrics` argument of the init).

        Subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Dataset to use instead of `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not
                accepted by the `model.forward()` method are automatically removed. It must implement the
                `__len__` method.
            ignore_keys (`List[str]`, *optional*):
                Keys of the model output (when it is a dictionary) that should be ignored when gathering
                predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                Prefix of the metric keys. For example the metric "bleu" is reported as "eval_bleu" with the
                default prefix.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        start_time = time.time()

        if self.args.use_legacy_prediction_loop:
            eval_loop = self.prediction_loop
        else:
            eval_loop = self.evaluation_loop

        # Without a metric function there is no point gathering the predictions; otherwise `None` lets the loop
        # defer to `self.args.prediction_loss_only`.
        loss_only = True if self.compute_metrics is None else None
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            prediction_loss_only=loss_only,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        jit_time_key = f"{metric_key_prefix}_jit_compilation_time"
        if jit_time_key in output.metrics:
            # Shift the start so the reported speed does not include the JIT compilation time.
            start_time += output.metrics[jit_time_key]

        speed = speed_metrics(
            metric_key_prefix,
            start_time,
            num_samples=output.num_samples,
            num_steps=math.ceil(output.num_samples / total_batch_size),
        )
        output.metrics.update(speed)

        self.log(output.metrics)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics

2965
    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Run prediction and return the predictions together with the potential metrics.

        Your test dataset may or may not contain labels, depending on the dataset and your use case. When it does,
        this method also returns metrics, like `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the
                `model.forward()` method are automatically removed. It has to implement the method `__len__`.
            ignore_keys (`List[str]`, *optional*):
                Keys of the model output (when it is a dictionary) that should be ignored when gathering
                predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                Prefix of the metric keys. For example the metric "bleu" is reported as "test_bleu" with the
                default prefix.

        <Tip>

        If your predictions or labels have different sequence length (for instance because you're doing dynamic padding
        in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
        one array. The padding index is -100.

        </Tip>

        Returns: *NamedTuple* A namedtuple with the following keys:

            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        test_dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()

        if self.args.use_legacy_prediction_loop:
            eval_loop = self.prediction_loop
        else:
            eval_loop = self.evaluation_loop
        output = eval_loop(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        jit_time_key = f"{metric_key_prefix}_jit_compilation_time"
        if jit_time_key in output.metrics:
            # Shift the start so the reported speed does not include the JIT compilation time.
            start_time += output.metrics[jit_time_key]

        speed = speed_metrics(
            metric_key_prefix,
            start_time,
            num_samples=output.num_samples,
            num_steps=math.ceil(output.num_samples / total_batch_size),
        )
        output.metrics.update(speed)

        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
Julien Chaumond's avatar
Julien Chaumond committed
3026

3027
    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.

        Args:
            dataloader (`DataLoader`):
                The batches to run the model on.
            description (`str`):
                Name of the run, only used in the log messages.
            prediction_loss_only (`bool`, *optional*):
                Whether to only compute the loss. Falls back to `self.args.prediction_loss_only` when `None`.
            ignore_keys (`List[str]`, *optional*):
                Keys of the model output (when it is a dictionary) that should be ignored when gathering
                predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                Prefix applied to every metric key that does not already start with it.

        Returns:
            `EvalLoopOutput`: The gathered predictions and labels (each possibly `None`), the metrics dictionary
            and the number of samples.
        """
        args = self.args

        if prediction_loss_only is None:
            prediction_loss_only = args.prediction_loss_only

        # if eval is called w/o train, handle model prep here
        if self.is_deepspeed_enabled and self.deepspeed is None:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            if self.is_fsdp_enabled:
                self.model = model

            # for the rest of this function `model` is the outside model, whether it was wrapped or not
            if model is not self.model:
                self.model_wrapped = model

            # backward compatibility
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)

        if args.past_index >= 0:
            self._past = None

        # Key of the batch entry handed to `compute_metrics` when `args.include_inputs_for_metrics` is set. Models
        # may declare their own main input name; "input_ids" remains the default.
        main_input_name = getattr(self.model, "main_input_name", "input_ids")

        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        inputs_host = None

        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        all_inputs = None

        def _move_host_containers_to_cpu():
            """Append what is accumulated on the host to the CPU containers, then empty the host containers."""
            nonlocal losses_host, preds_host, labels_host, inputs_host
            nonlocal all_losses, all_preds, all_labels, all_inputs

            if losses_host is not None:
                losses = nested_numpify(losses_host)
                all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
            if preds_host is not None:
                logits = nested_numpify(preds_host)
                all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
            if inputs_host is not None:
                inputs_decode = nested_numpify(inputs_host)
                all_inputs = (
                    inputs_decode
                    if all_inputs is None
                    else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                )
            if labels_host is not None:
                labels = nested_numpify(labels_host)
                all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

            # Set back to None to begin a new accumulation
            losses_host, preds_host, inputs_host, labels_host = None, None, None, None

        # Will be useful when we have an iterable dataset so don't know its length.
        observed_num_examples = 0

        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None

            if is_torch_tpu_available():
                xm.mark_step()

            # Update containers on host
            if loss is not None:
                losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))
                losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
            if labels is not None:
                labels = self.accelerator.pad_across_processes(labels)
            if inputs_decode is not None:
                inputs_decode = self.accelerator.pad_across_processes(inputs_decode)
                inputs_decode = self.accelerator.gather_for_metrics((inputs_decode))
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            if logits is not None:
                logits = self.accelerator.pad_across_processes(logits)
                # The user hook sees the padded logits and labels before they are gathered.
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                logits = self.accelerator.gather_for_metrics((logits))
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)

            if labels is not None:
                labels = self.accelerator.gather_for_metrics((labels))
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)

            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                _move_host_containers_to_cpu()

        # Same condition as the one used to create `_past` above, so that `past_index == 0` is cleaned up too.
        if args.past_index >= 0 and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU. They are appended to what earlier accumulation
        # windows already moved there; assigning instead would drop every window but the last one.
        _move_host_containers_to_cpu()

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples
        if num_samples == 0 and observed_num_examples > 0:
            num_samples = observed_num_examples

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
        if hasattr(self, "jit_compilation_time"):
            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
3231

3232
    def _nested_gather(self, tensors, name=None):
3233
3234
3235
3236
3237
3238
3239
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
3240
3241
            if name is None:
                name = "nested_gather"
3242
            tensors = nested_xla_mesh_reduce(tensors, name)
Sylvain Gugger's avatar
Sylvain Gugger committed
3243
3244
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
Zachary Mueller's avatar
Zachary Mueller committed
3245
3246
3247
        elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
            self.args.distributed_state is None and self.local_rank != -1
        ):
3248
            tensors = distributed_concat(tensors)
3249
        return tensors
3250

3251
3252
3253
3254
3255
3256
3257
3258
3259
3260
3261
3262
3263
3264
3265
3266
3267
3268
3269
3270
3271
3272
    # Copied from Accelerate.
    def _pad_across_processes(self, tensor, pad_index=-100):
        """
        Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
        they can safely be gathered.
        """
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )

        if len(tensor.shape) < 2:
            return tensor
        # Gather all sizes
        size = torch.tensor(tensor.shape, device=tensor.device)[None]
        sizes = self._nested_gather(size).cpu()

        max_size = max(s[1] for s in sizes)
3273
3274
3275
        # When extracting XLA graphs for compilation, max_size is 0,
        # so use inequality to avoid errors.
        if tensor.shape[1] >= max_size:
3276
3277
3278
3279
3280
3281
3282
3283
3284
            return tensor

        # Then pad to the maximum size
        old_size = tensor.shape
        new_size = list(old_size)
        new_size[1] = max_size
        new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
        new_tensor[:, : old_size[1]] = tensor
        return new_tensor
3285

3286
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Run a single evaluation step of `model` on `inputs`, without tracking gradients.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                Keys of the model output (when it is a dictionary) that should be ignored when gathering
                predictions. When `None`, `keys_to_ignore_at_inference` of the model config is used if the model
                has a config, and an empty list otherwise.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        no_label_names = len(self.label_names) == 0
        has_labels = (not no_label_names) and all(inputs.get(k) is not None for k in self.label_names)

        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = bool(no_label_names and return_loss)

        inputs = self._prepare_inputs(inputs)

        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        expects_loss = has_labels or loss_without_labels

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        labels = None
        if expects_loss:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if expects_loss:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        # Non-dict outputs carry the loss first and the logits after it.
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]
                    loss = loss_mb.reduce_mean().detach().cpu()
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                logits = smp_nested_concat(logits_mb)
            elif expects_loss:
                with self.compute_loss_context_manager():
                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                loss = loss.mean().detach()

                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                else:
                    logits = outputs[1:]
            else:
                loss = None
                with self.compute_loss_context_manager():
                    outputs = model(**inputs)

                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                else:
                    logits = outputs
                # TODO: this needs to be fixed and made cleaner later.
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        # A single output is returned as is rather than as a one-element container.
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)
3390
3391
3392

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        Estimate the number of floating point operations needed for one forward + backward pass on `inputs`.

        The computation is delegated to `self.model.floating_point_ops` when the model defines such a method (models
        inheriting from [`PreTrainedModel`] do). For any other model, either implement that method on the model or
        subclass the trainer and override this one.

        Args:
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            `int`: The number of floating-point operations, or `0` when the model cannot report it.
        """
        # Models without a `floating_point_ops` method cannot be measured: report zero instead of failing.
        if not hasattr(self.model, "floating_point_ops"):
            return 0
        return self.model.floating_point_ops(inputs)
3408

3409
    def init_git_repo(self, at_init: bool = False):
        """
        Clones the Hub repository the model is pushed to into `self.args.output_dir` and stores it in `self.repo`.

        The repository name is `self.args.hub_model_id`, or the name of the output directory when no id was given;
        a name without a namespace is expanded with `get_full_repo_name`. Only the main process (as reported by
        `self.is_world_process_zero()`) does any work, the other processes return immediately.

        Args:
            at_init (`bool`, *optional*, defaults to `False`):
                Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
                `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
                out.
        """
        if not self.is_world_process_zero():
            return

        local_dir = self.args.output_dir
        hub_token = self.args.hub_token

        repo_name = self.args.hub_model_id
        if repo_name is None:
            # No explicit id: fall back to the name of the output directory.
            repo_name = Path(local_dir).absolute().name
        if "/" not in repo_name:
            repo_name = get_full_repo_name(repo_name, token=hub_token)

        # Make sure the repo exists.
        create_repo(repo_name, token=hub_token, private=self.args.hub_private_repo, exist_ok=True)
        try:
            self.repo = Repository(local_dir, clone_from=repo_name, token=hub_token)
        except EnvironmentError:
            if not (self.args.overwrite_output_dir and at_init):
                raise
            # Try again after wiping output_dir
            shutil.rmtree(local_dir)
            self.repo = Repository(local_dir, clone_from=repo_name, token=hub_token)

        self.repo.git_pull()

        # By default, ignore the checkpoint folders
        gitignore_path = os.path.join(local_dir, ".gitignore")
        if not os.path.exists(gitignore_path) and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS:
            with open(gitignore_path, "w", encoding="utf-8") as writer:
                writer.write("checkpoint-*/")

        # Add "*.sagemaker" to .gitignore if using SageMaker
        if os.environ.get("SM_TRAINING_ENV"):
            self._add_sm_patterns_to_gitignore()

        # No asynchronous push has been started yet.
        self.push_in_progress = None

Sylvain Gugger's avatar
Sylvain Gugger committed
3456
3457
3458
3459
    def create_model_card(
        self,
        language: Optional[str] = None,
        license: Optional[str] = None,
        tags: Union[str, List[str], None] = None,
        model_name: Optional[str] = None,
        finetuned_from: Optional[str] = None,
        tasks: Union[str, List[str], None] = None,
        dataset_tags: Union[str, List[str], None] = None,
        dataset: Union[str, List[str], None] = None,
        dataset_args: Union[str, List[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        The card is written to `README.md` inside `self.args.output_dir`. Nothing is done on processes for which
        `self.is_world_process_zero()` is `False`.

        Args:
            language (`str`, *optional*):
                The language of the model (if applicable)
            license (`str`, *optional*):
                The license of the model. Will default to the license of the pretrained model used, if the original
                model given to the `Trainer` comes from a repo on the Hub.
            tags (`str` or `List[str]`, *optional*):
                Some tags to be included in the metadata of the model card.
            model_name (`str`, *optional*):
                The name of the model.
            finetuned_from (`str`, *optional*):
                The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
                of the original model given to the `Trainer` (if it comes from the Hub).
            tasks (`str` or `List[str]`, *optional*):
                One or several task identifiers, to be included in the metadata of the model card.
            dataset_tags (`str` or `List[str]`, *optional*):
                One or several dataset tags, to be included in the metadata of the model card.
            dataset (`str` or `List[str]`, *optional*):
                One or several dataset identifiers, to be included in the metadata of the model card.
            dataset_args (`str` or `List[str]`, *optional*):
                One or several dataset arguments, to be included in the metadata of the model card.
        """
        if not self.is_world_process_zero():
            return

        training_summary = TrainingSummary.from_trainer(
            self,
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
        )
        model_card = training_summary.to_model_card()
        # Write the card as UTF-8 explicitly (as done for `.gitignore` in `init_git_repo`): relying on the platform
        # default encoding can fail or corrupt the file as soon as the card contains non-ASCII characters.
        with open(os.path.join(self.args.output_dir, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card)

3512
3513
3514
3515
3516
3517
3518
3519
3520
3521
    def _push_from_checkpoint(self, checkpoint_folder):
        """
        Starts a non-blocking push of the checkpoint that was just saved in `checkpoint_folder`.

        The model files found in `checkpoint_folder` are copied to `self.args.output_dir`, the tokenizer (if any) and
        the training arguments are saved there again, and `self.repo.push_to_hub` is called with `blocking=False`.
        With `HubStrategy.CHECKPOINT`, the checkpoint folder is moved to `last-checkpoint` inside the output
        directory for the duration of the call and moved back afterwards.

        Nothing is done on processes other than the main one, when the hub strategy is `HubStrategy.END`, or while a
        previous asynchronous push is still running. Errors raised during the push are logged, not propagated.

        Args:
            checkpoint_folder:
                Path of the checkpoint folder to push from.
        """
        # Only push from one node.
        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
            return
        # If we haven't finished the last push, we don't do this one.
        if self.push_in_progress is not None and not self.push_in_progress.is_done:
            return

        output_dir = self.args.output_dir
        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
        modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
        for modeling_file in modeling_files:
            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Same for the training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
        # Only restore the checkpoint in `finally` if it was really moved: if removing the old "last-checkpoint" or
        # the move itself fails, moving `tmp_checkpoint` back would either raise again or put a stale folder inside
        # the checkpoint that was just saved.
        checkpoint_moved = False
        try:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Temporarily move the checkpoint just saved for the push
                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
                # subfolder.
                if os.path.isdir(tmp_checkpoint):
                    shutil.rmtree(tmp_checkpoint)
                shutil.move(checkpoint_folder, tmp_checkpoint)
                checkpoint_moved = True

            if self.args.save_strategy == IntervalStrategy.STEPS:
                commit_message = f"Training in progress, step {self.state.global_step}"
            else:
                commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
            push_work = self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
            # Return type of `Repository.push_to_hub` is either None or a tuple.
            if push_work is not None:
                self.push_in_progress = push_work[1]
        except Exception as e:
            # Best effort: a failed intermediate push must not interrupt training.
            logger.error(f"Error when pushing to hub: {e}")
        finally:
            if checkpoint_moved:
                # Move back the checkpoint to its place
                shutil.move(tmp_checkpoint, checkpoint_folder)

    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.

        Parameters:
            commit_message (`str`, *optional*, defaults to `"End of training"`):
                Message to commit while pushing.
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the `git push` has finished.
            kwargs:
                Additional keyword arguments passed along to [`~Trainer.create_model_card`]. A `model_name` entry,
                when present, is used as the name of the model in the card.

        Returns:
            The url of the commit of your model in the given repository if `blocking=True`, a tuple with the url of
            the commit and an object to track the progress of the commit if `blocking=False`. `None` is returned on
            processes for which `self.is_world_process_zero()` is `False`.
        """
        # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
        # it might fail.
        if not hasattr(self, "repo"):
            self.init_git_repo()

        model_name = kwargs.pop("model_name", None)
        if model_name is None and self.args.should_save:
            if self.args.hub_model_id is None:
                model_name = Path(self.args.output_dir).name
            else:
                model_name = self.args.hub_model_id.split("/")[-1]

        # Needs to be executed on all processes for TPU training, but will only save on the process determined by
        # self.args.should_save.
        self.save_model(_internal_call=True)

        # Only push from one node.
        if not self.is_world_process_zero():
            return

        # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
            self.push_in_progress._process.kill()
            self.push_in_progress = None

        git_head_commit_url = self.repo.push_to_hub(
            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
        )
        # push separately the model card to be independent from the rest of the model
        if self.args.should_save:
            self.create_model_card(model_name=model_name, **kwargs)
            try:
                self.repo.push_to_hub(
                    commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
                )
            except EnvironmentError as exc:
                # A failure to push the card is only logged: the model commit above has already been pushed.
                logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")

        return git_head_commit_url
Sylvain Gugger's avatar
Sylvain Gugger committed
3612

3613
3614
3615
3616
3617
3618
3619
3620
3621
3622
3623
    #
    # Deprecated code
    #

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.

        Args:
            dataloader (`DataLoader`):
                The dataloader to iterate over. It must implement a working `__len__`.
            description (`str`):
                Name of the run, only used in the log messages.
            prediction_loss_only (`bool`, *optional*):
                Whether to only compute the loss. Defaults to `self.args.prediction_loss_only` when not set.
            ignore_keys (`List[str]`, *optional*):
                Passed along to `self.prediction_step`.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                Prefix added (followed by an underscore) to every metric key that does not already start with it.

        Returns:
            `EvalLoopOutput`: The gathered predictions and label ids (both `None` when only the loss is computed),
            the metrics and the number of examples.

        Raises:
            `ValueError`: If `dataloader` does not implement a working `__len__`.
        """
        args = self.args

        if not has_length(dataloader):
            raise ValueError("dataloader must implement a working __len__")

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train, handle model prep here
        if self.is_deepspeed_enabled and self.deepspeed is None:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            if self.is_fsdp_enabled:
                self.model = model

            # for the rest of this function `model` is the outside model, whether it was wrapped or not
            if model is not self.model:
                self.model_wrapped = model

            # backward compatibility
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info(f"***** Running {description} *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Batch size = {batch_size}")
        # Accumulators for the results of the current accumulation window; they are flushed to the gatherers below.
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
        inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = max(1, args.world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
            # a batch size to the sampler)
            make_multiple_of = None
            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
                make_multiple_of = dataloader.sampler.batch_size
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

        model.eval()

        if args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            if loss is not None:
                # One copy of the batch loss per example of the batch.
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                if not prediction_loss_only:
                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
                    inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host, inputs_host = None, None, None, None

        # `past_index` is an index for which 0 is a valid value, so it is compared to 0 exactly like for the reset
        # done before the loop: a plain truthiness test would skip the cleanup when `past_index == 0`.
        if args.past_index >= 0 and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
            inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
        inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
3767
3768
3769
3770
3771
3772
3773
3774
3775
3776
3777
3778

    def _gather_and_numpify(self, tensors, name):
        """
        Gather the value of `tensors` (tensor or list/tuple of nested tensors) with the helper matching the current
        setup (TPU, SageMaker model parallelism or distributed training) and convert the result to numpy.

        `name` is only passed along to the TPU reduction. Returns `None` when `tensors` is `None`.
        """
        if tensors is None:
            return None

        gathered = tensors
        if is_torch_tpu_available():
            gathered = nested_xla_mesh_reduce(gathered, name)
        elif is_sagemaker_mp_enabled():
            gathered = smp_gather(gathered)
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            gathered = distributed_concat(gathered)
        # In any other setup the tensors are converted as they are.

        return nested_numpify(gathered)
3783
3784
3785
3786
3787
3788
3789
3790
3791
3792
3793
3794
3795
3796
3797
3798
3799
3800
3801
3802
3803
3804
3805
3806
3807
3808
3809
3810
3811
3812
3813
3814
3815
3816
3817
3818
3819
3820
3821

    def _add_sm_patterns_to_gitignore(self) -> None:
        """
        Add SageMaker Checkpointing patterns to .gitignore file.

        The patterns missing from the `.gitignore` of `self.repo.local_dir` are appended to it (the file is created
        if needed), then the file is added to git and, if the repo is not clean, committed and pushed.
        """
        # Make sure we only do this on the main process
        if not self.is_world_process_zero():
            return

        patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]
        gitignore_path = os.path.join(self.repo.local_dir, ".gitignore")

        # Get current .gitignore content. The file is read and written as UTF-8, like the `.gitignore` created in
        # `init_git_repo`, instead of depending on the platform default encoding.
        if os.path.exists(gitignore_path):
            with open(gitignore_path, "r", encoding="utf-8") as f:
                current_content = f.read()
        else:
            current_content = ""

        # Add the patterns to .gitignore
        content = current_content
        for pattern in patterns:
            if pattern not in content:
                # No separator is needed at the start of an empty file or right after a newline.
                if not content or content.endswith("\n"):
                    content += pattern
                else:
                    content += f"\n{pattern}"

        # Write the .gitignore file if it has changed
        if content != current_content:
            with open(gitignore_path, "w", encoding="utf-8") as f:
                logger.debug(f"Writing .gitignore file. Content: {content}")
                f.write(content)

        self.repo.git_add(".gitignore")

        # avoid race condition with git status
        time.sleep(0.5)

        if not self.repo.is_repo_clean():
            self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
            self.repo.git_push()
3822
3823
3824
3825
3826
3827
3828
3829
3830
3831
3832
3833
3834
3835
3836
3837
3838
3839
3840
3841
3842
3843
3844
3845
3846
3847
3848

    def create_accelerator_and_postprocess(self):
        """
        Creates `self.accelerator` from the training arguments, then sets `self.is_deepspeed_enabled` and
        `self.is_fsdp_enabled` from the plugins found on the accelerator state and finishes configuring them.
        """
        training_args = self.args

        # create accelerator object
        self.accelerator = Accelerator(
            deepspeed_plugin=training_args.deepspeed_plugin,
            gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        )
        accelerator_state = self.accelerator.state

        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
        self.is_deepspeed_enabled = getattr(accelerator_state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(accelerator_state, "fsdp_plugin", None) is not None

        # post accelerator creation setup
        if self.is_fsdp_enabled:
            # Both options are read from the trainer's FSDP config and default to `False`.
            fsdp_plugin = accelerator_state.fsdp_plugin
            for option in ("limit_all_gathers", "use_orig_params"):
                setattr(fsdp_plugin, option, training_args.fsdp_config.get(option, False))

        if self.is_deepspeed_enabled:
            if getattr(training_args, "hf_deepspeed_config", None) is None:
                from transformers.deepspeed import HfTrainerDeepSpeedConfig

                # Only done when the arguments carry no `hf_deepspeed_config`: wrap the plugin's config and let it
                # process the trainer arguments.
                ds_plugin = accelerator_state.deepspeed_plugin

                hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
                ds_plugin.hf_ds_config = hf_ds_config
                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
                ds_plugin.hf_ds_config.trainer_config_process(training_args)