# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
"""

import contextlib
import functools
import glob
import inspect
import math
import os
import random
import re
import shutil
import sys
import time
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union


# Integrations must be imported before ML frameworks:
# isort: off
from .integrations import (
    get_reporting_integration_callbacks,
    hp_params,
    is_fairscale_available,
)

# isort: on

import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import Repository, create_repo
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler

from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
from .optimization import Adafactor, get_scheduler
from .pytorch_utils import ALL_LAYERNORM_LAYERS
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
    DistributedTensorGatherer,
    IterableDatasetShard,
    LabelSmoother,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    find_batch_size,
    get_model_param_count,
    get_module_class_from_name,
    get_parameter_names,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalLoopOutput,
    EvalPrediction,
    FSDPOption,
    HPSearchBackend,
    HubStrategy,
    IntervalStrategy,
    PredictionOutput,
    RemoveColumnsCollator,
    ShardedDDPOption,
    TrainerMemoryTracker,
    TrainOutput,
    default_compute_objective,
    denumpify_detensorize,
    enable_full_determinism,
    find_executable_batch_size,
    get_last_checkpoint,
    has_length,
    number_of_arguments,
    seed_worker,
    set_seed,
    speed_metrics,
)
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import (
    ADAPTER_SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    can_return_loss,
    find_labels,
    get_full_repo_name,
    is_accelerate_available,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
    is_ipex_available,
    is_peft_available,
    is_safetensors_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_torch_compile_available,
    is_torch_neuroncore_available,
    is_torch_tpu_available,
    logging,
    strtobool,
)
from .utils.generic import ContextManagers


DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

if is_apex_available():
    from apex import amp

if is_datasets_available():
    import datasets

if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

if is_fairscale_available():
    dep_version_check("fairscale")
    import fairscale
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.nn.wrap import auto_wrap
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler


if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
    IS_SAGEMAKER_MP_POST_1_10 = False


if is_safetensors_available():
    import safetensors.torch


if is_peft_available():
    from peft import PeftModel


if is_accelerate_available():
    from accelerate import Accelerator, skip_first_batches
    from accelerate import __version__ as accelerate_version
    from accelerate.utils import DistributedDataParallelKwargs

    if version.parse(accelerate_version) > version.parse("0.20.3"):
        from accelerate.utils import (
            load_fsdp_model,
            load_fsdp_optimizer,
            save_fsdp_model,
            save_fsdp_optimizer,
        )


if TYPE_CHECKING:
    import optuna

logger = logging.get_logger(__name__)


# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.

    Args:
        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.

            <Tip>

            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
            models.

            </Tip>

        args ([`TrainingArguments`], *optional*):
            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
        data_collator (`DataCollator`, *optional*):
            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
            [`DataCollatorWithPadding`] otherwise.
        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
            The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
            `model.forward()` method are automatically removed.

            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
            distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
            sets the seed of the RNGs used.
        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*):
            The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
            `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
            dataset, prepending the dictionary key to the metric name.
        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
            interrupted training or reuse the fine-tuned model.
        model_init (`Callable[[], PreTrainedModel]`, *optional*):
            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
            from a new instance of the model as given by this function.

            The function may have zero arguments, or a single one containing the optuna/Ray Tune/SigOpt trial object,
            to be able to choose different architectures according to hyperparameters (such as layer count, sizes of
            inner layers, dropout probabilities, etc.).
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
            The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`] and
            return a dictionary mapping metric names to metric values.
        callbacks (List of [`TrainerCallback`], *optional*):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed in [here](callback).

            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications
            made by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.

    Important attributes:

        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
          subclass.
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
          to `False` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False`.
        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
          in `train`)
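
    Example (an illustrative sketch, not part of the original docstring; the checkpoint name and the tiny in-memory
    dataset below are placeholders):

    ```python
    from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

    # Each item is a dict whose keys match arguments of `model.forward()` (plus `labels`).
    texts, labels = ["great movie", "terrible movie"], [1, 0]
    encodings = tokenizer(texts, truncation=True)
    train_dataset = [
        {"input_ids": encodings["input_ids"][i], "attention_mask": encodings["attention_mask"][i], "labels": labels[i]}
        for i in range(len(texts))
    ]

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="tmp_trainer", num_train_epochs=1, per_device_train_batch_size=2),
        train_dataset=train_dataset,  # any sized, indexable container of dicts works
        tokenizer=tokenizer,  # with a tokenizer, batches are padded via `DataCollatorWithPadding`
    )
    trainer.train()
    ```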
    """

    # Those are used as methods of the Trainer in examples.
    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        if args is None:
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
        self.args = args
        # Seed must be set before instantiating the model when using model_init.
        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
        self.hp_name = None
        self.deepspeed = None
        self.is_in_train = False

        self.create_accelerator_and_postprocess()

        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

        # set the correct log level depending on the node
        log_level = args.get_process_log_level()
        logging.set_verbosity(log_level)

        # force device and distributed setup init explicitly
        args._setup_devices

        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
                    " overwrite your model when calling the `train` method. This will become a fatal error in the next"
                    " release.",
                    FutureWarning,
                )
            self.model_init = model_init

        if model.__class__.__name__ in MODEL_MAPPING_NAMES:
            raise ValueError(
                f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
                "computes hidden states and does not accept any labels. You should choose a model with a head "
                "suitable for your task like any of the `AutoModelForXxx` listed at "
                "https://huggingface.co/docs/transformers/model_doc/auto."
            )

        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = False

        if getattr(model, "hf_device_map", None) is not None:
            devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
            if len(devices) > 1:
                self.is_model_parallel = True
            else:
                self.is_model_parallel = self.args.device != torch.device(devices[0])

            # warn users
            logger.info(
                "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
                " to `True` to avoid any unexpected behavior such as device placement mismatching."
            )

        # At this stage the model is already loaded
        if getattr(model, "is_quantized", False):
            if getattr(model, "_is_quantized_training_enabled", False):
                logger.info(
                    "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
                    " inside the model such as adapters using `peft` library and freeze the model weights. Please"
                    " check the examples in https://github.com/huggingface/peft for more details."
                )
            else:
                raise ValueError(
                    "The model you want to train is loaded in 8-bit precision. If you want to fine-tune an 8-bit"
                    " model, please make sure that you have installed `bitsandbytes>=0.37.0`."
                )

        # Setup Sharded DDP training
        self.sharded_ddp = None
        if len(args.sharded_ddp) > 0:
            if self.is_deepspeed_enabled:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if len(args.fsdp) > 0:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
                )
            if args.parallel_mode != ParallelMode.DISTRIBUTED:
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
                raise ImportError(
                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
                )
            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.SIMPLE
            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

        self.fsdp = None
        if len(args.fsdp) > 0:
            if self.is_deepspeed_enabled:
                raise ValueError(
                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
                raise ValueError("Using fsdp only works in distributed training.")

            # dep_version_check("torch>=1.12.0")
            # Would have to update setup.py with torch>=1.12.0
            # which isn't ideal given that it will force people not using FSDP to also use torch>=1.12.0
            # below is the current alternative.
            if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
                raise ValueError("FSDP requires PyTorch >= 1.12.0")

            from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy

            if FSDPOption.FULL_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.FULL_SHARD
            elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
                self.fsdp = ShardingStrategy.SHARD_GRAD_OP
            elif FSDPOption.NO_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.NO_SHARD

            self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
            if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get(
                "backward_prefetch", []
            ):
                self.backward_prefetch = BackwardPrefetch.BACKWARD_POST

            self.forward_prefetch = False
            if self.args.fsdp_config.get("forward_prefetch", False):
                self.forward_prefetch = True

            self.limit_all_gathers = False
            if self.args.fsdp_config.get("limit_all_gathers", False):
                self.limit_all_gathers = True

        # one place to sort out whether to place the model on device or not
        # postpone switching model to cuda when:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        #    and we only use deepspeed for training at the moment
        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
        # 4. Sharded DDP - same as MP
        # 5. FSDP - same as MP
        self.place_model_on_device = args.place_model_on_device
        if (
            self.is_model_parallel
            or self.is_deepspeed_enabled
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
            or (self.fsdp is not None)
            or self.is_fsdp_enabled
        ):
            self.place_model_on_device = False

        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer

        if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False):
            self._move_model_to_device(model, args.device)

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
        if self.is_model_parallel:
            self.args._n_gpu = 1

        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
        self.model_wrapped = model
        self.model = model

        self.compute_metrics = compute_metrics
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        if is_torch_tpu_available() and self.optimizer is not None:
            for param in self.model.parameters():
                model_device = param.device
                break
            for param_group in self.optimizer.param_groups:
                if len(param_group["params"]) > 0:
                    optimizer_device = param_group["params"][0].device
                    break
            if model_device != optimizer_device:
                raise ValueError(
                    "The model and the optimizer parameters are not on the same device, which probably means you"
                    " created an optimizer around your model **before** putting on the device and passing it to the"
                    " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                    " `model.to(xm.xla_device())` are performed before the optimizer creation in your script."
                )
        if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
            self.optimizer is not None or self.lr_scheduler is not None
        ):
            raise RuntimeError(
                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create clone of distant repo and output directory if needed
        if self.args.push_to_hub:
            self.init_git_repo(at_init=True)
            # In case of pull, we need to make sure every process has the latest.
            if is_torch_tpu_available():
                xm.rendezvous("init git repo")
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
                dist.barrier()

        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)

        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")

        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
            raise ValueError(
                "The train_dataset does not implement __len__, max_steps has to be specified. "
                "The number of steps needs to be known in advance for the learning rate scheduler."
            )

        if (
            train_dataset is not None
            and isinstance(train_dataset, torch.utils.data.IterableDataset)
            and args.group_by_length
        ):
            raise ValueError("The `--group_by_length` option is only available for `Dataset`, not `IterableDataset`.")

        self._signature_columns = None

        # Mixed precision setup
        self.use_apex = False
        self.use_cuda_amp = False
        self.use_cpu_amp = False

        # Mixed precision setup for SageMaker Model Parallel
        if is_sagemaker_mp_enabled():
            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
            if args.bf16:
                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead.")

            if IS_SAGEMAKER_MP_POST_1_10:
                # When there's mismatch between SMP config and trainer argument, use SMP config as truth
                if args.fp16 != smp.state.cfg.fp16:
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
                        f"but FP16 provided in trainer argument is {args.fp16},"
                        f"setting to {smp.state.cfg.fp16}"
                    )
                    args.fp16 = smp.state.cfg.fp16
            else:
                # smp < 1.10 does not support fp16 in trainer.
                if hasattr(smp.state.cfg, "fp16"):
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
                    )

        if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
            if args.half_precision_backend == "auto":
                if args.device == torch.device("cpu"):
                    if args.fp16:
                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                    else:
                        args.half_precision_backend = "cpu_amp"
                else:
                    args.half_precision_backend = "cuda_amp"

            logger.info(f"Using {args.half_precision_backend} half precision backend")

        self.do_grad_scaling = False
        if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
            # deepspeed and SageMaker Model Parallel manage their own half precision
            if self.sharded_ddp is not None:
                if args.half_precision_backend == "cuda_amp":
                    self.use_cuda_amp = True
                    self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
                    #  bf16 does not need grad scaling
                    self.do_grad_scaling = self.amp_dtype == torch.float16
                    if self.do_grad_scaling:
                        if self.sharded_ddp is not None:
                            self.scaler = ShardedGradScaler()
                        elif self.fsdp is not None:
                            from torch.distributed.fsdp.sharded_grad_scaler import (
                                ShardedGradScaler as FSDPShardedGradScaler,
                            )

                            self.scaler = FSDPShardedGradScaler()
                        elif is_torch_tpu_available():
                            from torch_xla.amp import GradScaler

                            self.scaler = GradScaler()
                        else:
                            self.scaler = torch.cuda.amp.GradScaler()
                elif args.half_precision_backend == "cpu_amp":
                    self.use_cpu_amp = True
                    self.amp_dtype = torch.bfloat16
            elif args.half_precision_backend == "apex":
                if not is_apex_available():
                    raise ImportError(
                        "Using FP16 with APEX but APEX is not installed, please refer to"
                        " https://www.github.com/nvidia/apex."
                    )
                self.use_apex = True

        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
        if (
            is_sagemaker_mp_enabled()
            and self.use_cuda_amp
            and args.max_grad_norm is not None
            and args.max_grad_norm > 0
        ):
            raise ValueError(
                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
                "along 'max_grad_norm': 0 in your hyperparameters."
            )

        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

        self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
        )

        self.control = TrainerControl()
        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
        default_label_names = find_labels(self.model.__class__)
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        self.can_return_loss = can_return_loss(self.model.__class__)
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

        # Internal variables to keep track of the original batch size
        self._train_batch_size = args.train_batch_size

        # very last
        self._memory_tracker.stop_and_update_metrics()

        # torch.compile
        if args.torch_compile and not is_torch_compile_available():
            raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")

    def add_callback(self, callback):
        """
        Add a callback to the current list of [`~transformers.TrainerCallback`].

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
                first case, will instantiate a member of that class.
        """
        self.callback_handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it.

        If the callback is not found, returns `None` (and no error is raised).

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
                first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            [`~transformers.TrainerCallback`]: The callback removed, if found.
        """
        return self.callback_handler.pop_callback(callback)

    def remove_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`].

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
                first case, will remove the first member of that class found in the list of callbacks.
        """
        self.callback_handler.remove_callback(callback)
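
    # Illustrative sketch (not part of the original source): callbacks can be passed either as a class or as an
    # instance, e.g.
    #
    #     trainer.add_callback(PrinterCallback)               # a class: Trainer instantiates it internally
    #     printer_cb = trainer.pop_callback(PrinterCallback)  # removes it and returns the instance (or None)
    #     trainer.add_callback(printer_cb)                    # an instance is registered as-is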

    def _move_model_to_device(self, model, device):
        model = model.to(device)
        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
            model.tie_weights()

    def _set_signature_columns_if_needed(self):
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
            # Labels may be named label or label_ids, the default data collator handles that.
            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return dataset
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        if len(ignored_columns) > 0:
            dset_description = "" if description is None else f"in the {description} set"
            logger.info(
                f"The following columns {dset_description} don't have a corresponding argument in "
                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
                f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
                "you can safely ignore this message."
            )

        columns = [k for k in signature_columns if k in dataset.column_names]

        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            dataset.set_format(
                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
            )
            return dataset
        else:
            return dataset.remove_columns(ignored_columns)
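
    # Illustrative sketch (not part of the original source): for a `datasets.Dataset` with columns
    # ["input_ids", "attention_mask", "labels", "text"] and a model whose `forward()` accepts only the first three,
    # `_remove_unused_columns` drops "text" (and logs it), since it has no matching argument in the forward signature.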

    def _get_collator_with_removed_columns(
        self, data_collator: Callable, description: Optional[str] = None
    ) -> Callable:
        """Wrap the data collator in a callable removing unused columns."""
        if not self.args.remove_unused_columns:
            return data_collator
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        remove_columns_collator = RemoveColumnsCollator(
            data_collator=data_collator,
            signature_columns=signature_columns,
            logger=logger,
            description=description,
            model_name=self.model.__class__.__name__,
        )
        return remove_columns_collator

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        # Build the sampler.
        if self.args.group_by_length:
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            return LengthGroupedSampler(
                self.args.train_batch_size * self.args.gradient_accumulation_steps,
                dataset=self.train_dataset,
                lengths=lengths,
                model_input_name=model_input_name,
            )

        else:
            return RandomSampler(self.train_dataset)

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker

        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
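
    # Illustrative sketch (not part of the original source): a subclass can customize how training batches are
    # sampled, e.g. by overriding `_get_train_sampler` (used by `get_train_dataloader` above); `example_weights`
    # is a hypothetical per-example weight attribute:
    #
    #     class WeightedTrainer(Trainer):
    #         def _get_train_sampler(self):
    #             return torch.utils.data.WeightedRandomSampler(
    #                 weights=self.example_weights, num_samples=len(self.train_dataset)
    #             )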

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
        # Deprecated code
        if self.args.use_legacy_prediction_loop:
            if is_torch_tpu_available():
                return SequentialDistributedSampler(
                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
                )
            elif is_sagemaker_mp_enabled():
                return SequentialDistributedSampler(
                    eval_dataset,
                    num_replicas=smp.dp_size(),
                    rank=smp.dp_rank(),
                    batch_size=self.args.per_device_eval_batch_size,
                )
            else:
                return SequentialSampler(eval_dataset)

        if self.args.world_size <= 1:
            return SequentialSampler(eval_dataset)
        else:
            return None

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }

        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last

        return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
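
    # Illustrative sketch (not part of the original source): `evaluate`/`predict` build their loaders through this
    # method, and an eval dataset passed at call time overrides `self.eval_dataset`, e.g.
    #
    #     metrics = trainer.evaluate(eval_dataset=validation_ds)  # `validation_ds` is a placeholder dataset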

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`, *optional*):
                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description="test")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }

        if not isinstance(test_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last

        # We use the same batch_size as for eval.
        return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
        `create_scheduler`) in a subclass.
        """
        self.create_optimizer()
        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
            # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
            optimizer = self.optimizer.optimizer
        else:
            optimizer = self.optimizer
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    skipped = 0
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                            logger.info(f"skipped {module}: {skipped/2**20}M params")
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                    logger.info(f"skipped: {skipped/2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer
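
    # Illustrative sketch (not part of the original source): instead of relying on `create_optimizer`, a custom
    # optimizer/scheduler pair can be passed directly at init (the weight-decay grouping above is then up to the
    # caller):
    #
    #     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
    #     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
    #     trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,
    #                       optimizers=(optimizer, scheduler))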

    @staticmethod
    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
        """
        Returns the optimizer class and optimizer parameters based on the training arguments.

        Args:
            args (`transformers.training_args.TrainingArguments`):
                The training arguments for the training session.

        """

        # parse args.optim_args
        optim_args = {}
        if args.optim_args:
            for mapping in args.optim_args.replace(" ", "").split(","):
                key, value = mapping.split("=")
                optim_args[key] = value
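
        # Illustrative sketch (not part of the original source): with `--optim_args "momentum_dtype=bfloat16,use_kahan_summation=True"`,
        # the loop above yields `optim_args == {"momentum_dtype": "bfloat16", "use_kahan_summation": "True"}` --
        # values stay strings and are converted by the specific optimizer branches below (e.g. via `strtobool`/`getattr`).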

        optimizer_kwargs = {"lr": args.learning_rate}

        adam_kwargs = {
            "betas": (args.adam_beta1, args.adam_beta2),
            "eps": args.adam_epsilon,
        }
        if args.optim == OptimizerNames.ADAFACTOR:
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif args.optim == OptimizerNames.ADAMW_HF:
            from .optimization import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
            from torch.optim import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
            if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
                optimizer_kwargs.update({"fused": True})
        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
            try:
                from torch_xla.amp.syncfree import AdamW

                optimizer_cls = AdamW
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
            try:
                from apex.optimizers import FusedAdam

                optimizer_cls = FusedAdam
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
        elif args.optim in [
            OptimizerNames.ADAMW_BNB,
            OptimizerNames.ADAMW_8BIT,
            OptimizerNames.PAGED_ADAMW,
            OptimizerNames.PAGED_ADAMW_8BIT,
            OptimizerNames.LION,
            OptimizerNames.LION_8BIT,
            OptimizerNames.PAGED_LION,
            OptimizerNames.PAGED_LION_8BIT,
        ]:
            try:
                from bitsandbytes.optim import AdamW, Lion

                is_paged = False
                optim_bits = 32
                optimizer_cls = None
                additional_optim_kwargs = adam_kwargs
                if "paged" in args.optim:
                    is_paged = True
                if "8bit" in args.optim:
                    optim_bits = 8
                if "adam" in args.optim:
                    optimizer_cls = AdamW
                elif "lion" in args.optim:
                    optimizer_cls = Lion
                    additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}

                bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
                optimizer_kwargs.update(additional_optim_kwargs)
                optimizer_kwargs.update(bnb_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!")
        elif args.optim == OptimizerNames.ADAMW_BNB:
            try:
                from bitsandbytes.optim import Adam8bit

                optimizer_cls = Adam8bit
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
        elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
            try:
                from torchdistx.optimizers import AnyPrecisionAdamW

                optimizer_cls = AnyPrecisionAdamW
                optimizer_kwargs.update(adam_kwargs)

                # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
                optimizer_kwargs.update(
                    {
                        "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
                        "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
                        "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
                        "compensation_buffer_dtype": getattr(
                            torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
                        ),
                    }
                )
            except ImportError:
                raise ValueError("Please install https://github.com/pytorch/torchdistx")
        elif args.optim == OptimizerNames.SGD:
            optimizer_cls = torch.optim.SGD
        elif args.optim == OptimizerNames.ADAGRAD:
            optimizer_cls = torch.optim.Adagrad
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called
        or passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
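
        Example (a minimal sketch; assumes `trainer` is an already-constructed [`Trainer`]):

        ```python
        trainer.create_optimizer()
        lr_scheduler = trainer.create_scheduler(num_training_steps=1000)
        ```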
        """
        if self.lr_scheduler is None:
            self.lr_scheduler = get_scheduler(
                self.args.lr_scheduler_type,
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )
        return self.lr_scheduler

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get the number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
        dataloader.dataset does not exist or has no length, estimates as best it can.
        """
        try:
            dataset = dataloader.dataset
            # Special case for IterableDatasetShard, we need to dig deeper
            if isinstance(dataset, IterableDatasetShard):
                return len(dataloader.dataset.dataset)
            return len(dataloader.dataset)
        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
            return len(dataloader) * self.args.per_device_train_batch_size

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """HP search setup code"""
        self._trial = trial

        if self.hp_search_backend is None or trial is None:
            return
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            params = self.hp_space(trial)
        elif self.hp_search_backend == HPSearchBackend.RAY:
            params = trial
            params.pop("wandb", None)
        elif self.hp_search_backend == HPSearchBackend.SIGOPT:
            params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
        elif self.hp_search_backend == HPSearchBackend.WANDB:
            params = trial
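        # At this point `params` is a plain dict mapping `TrainingArguments` field names to sampled values,
        # e.g. {"learning_rate": 3e-5, "num_train_epochs": 4} (illustrative values only).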

        for key, value in params.items():
            if not hasattr(self.args, key):
                logger.warning(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
                    " `TrainingArguments`."
                )
                continue
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info(f"Trial: {trial.params}")
        if self.hp_search_backend == HPSearchBackend.SIGOPT:
            logger.info(f"SigOpt Assignments: {trial.assignments}")
        if self.hp_search_backend == HPSearchBackend.WANDB:
            logger.info(f"W&B Sweep parameters: {trial}")
        if self.is_deepspeed_enabled:
            if self.args.deepspeed is None:
                raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
            # Rebuild the deepspeed config to reflect the updated training parameters
            from accelerate.utils import DeepSpeedPlugin

            from transformers.deepspeed import HfTrainerDeepSpeedConfig

            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
            self.args.hf_deepspeed_config.trainer_config_process(self.args)
            self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
        self.create_accelerator_and_postprocess()

    def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
        if self.hp_search_backend is None or trial is None:
            return
        self.objective = self.compute_objective(metrics.copy())
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            import optuna

            trial.report(self.objective, step)
            if trial.should_prune():
                self.callback_handler.on_train_end(self.args, self.state, self.control)
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            from ray import tune

            if self.control.should_save:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

    def _tune_save_checkpoint(self):
        from ray import tune

        if not self.use_tune_checkpoints:
            return
        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
            self.save_model(output_dir, _internal_call=True)
            if self.args.should_save:
                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

    def call_model_init(self, trial=None):
        model_init_argcount = number_of_arguments(self.model_init)
        if model_init_argcount == 0:
            model = self.model_init()
        elif model_init_argcount == 1:
            model = self.model_init(trial)
        else:
            raise RuntimeError("model_init should have 0 or 1 argument.")

        if model is None:
            raise RuntimeError("model_init should not return None.")

        return model

    def torch_jit_model_eval(self, model, dataloader, training=False):
        if not training:
            if dataloader is None:
                logger.warning("failed to use PyTorch jit mode because the current dataloader is None.")
                return model
            example_batch = next(iter(dataloader))
            example_batch = self._prepare_inputs(example_batch)
            try:
                jit_model = model.eval()
                with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]):
                    if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"):
                        if isinstance(example_batch, dict):
                            jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
                        else:
                            jit_model = torch.jit.trace(
                                jit_model,
                                example_kwarg_inputs={key: example_batch[key] for key in example_batch},
                                strict=False,
                            )
                    else:
                        jit_inputs = []
                        for key in example_batch:
                            example_tensor = torch.ones_like(example_batch[key])
                            jit_inputs.append(example_tensor)
                        jit_inputs = tuple(jit_inputs)
                        jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
                jit_model = torch.jit.freeze(jit_model)
                with torch.no_grad():
                    jit_model(**example_batch)
                    jit_model(**example_batch)
                model = jit_model
                self.use_cpu_amp = False
                self.use_cuda_amp = False
            except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
                logger.warning(f"failed to use PyTorch jit mode due to: {e}.")

        return model

    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
        if not is_ipex_available():
            raise ImportError(
                "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
                " to https://github.com/intel/intel-extension-for-pytorch."
            )

        import intel_extension_for_pytorch as ipex

        if not training:
            model.eval()
            dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype
            # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings
            model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
        else:
            if not model.training:
                model.train()
            model, self.optimizer = ipex.optimize(
                model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
            )

        return model

    def _wrap_model(self, model, training=True, dataloader=None):
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP
        if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
            model = nn.DataParallel(model)

        if self.args.jit_mode_eval:
            start_time = time.time()
            model = self.torch_jit_model_eval(model, dataloader, training)
            self.jit_compilation_time = round(time.time() - start_time, 4)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)
        # Distributed training using PyTorch FSDP
        elif self.fsdp is not None and self.args.fsdp_config["xla"]:
            try:
                from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                from torch_xla.distributed.fsdp import checkpoint_module
                from torch_xla.distributed.fsdp.wrap import (
                    size_based_auto_wrap_policy,
                    transformer_auto_wrap_policy,
                )
            except ImportError:
                raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
            auto_wrap_policy = None
            auto_wrapper_callable = None
            if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                auto_wrap_policy = functools.partial(
                    size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                )
            elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                transformer_cls_to_wrap = set()
                for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                    transformer_cls = get_module_class_from_name(model, layer_class)
                    if transformer_cls is None:
                        raise Exception("Could not find the transformer layer class to wrap in the model.")
                    else:
                        transformer_cls_to_wrap.add(transformer_cls)
                auto_wrap_policy = functools.partial(
                    transformer_auto_wrap_policy,
                    # Transformer layer class to wrap
                    transformer_layer_cls=transformer_cls_to_wrap,
                )
            fsdp_kwargs = self.args.xla_fsdp_config
            if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                # Apply gradient checkpointing to auto-wrapped sub-modules if specified
                def auto_wrapper_callable(m, *args, **kwargs):
                    return FSDP(checkpoint_module(m), *args, **kwargs)

            # Wrap the base model with an outer FSDP wrapper
            self.model = model = FSDP(
                model,
                auto_wrap_policy=auto_wrap_policy,
                auto_wrapper_callable=auto_wrapper_callable,
                **fsdp_kwargs,
            )

            # Patch `xm.optimizer_step` so that it does not reduce gradients in this case,
            # as FSDP does not need gradient reduction over sharded parameters.
            def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
                loss = optimizer.step(**optimizer_args)
                if barrier:
                    xm.mark_step()
                return loss

            xm.optimizer_step = patched_optimizer_step
        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            if is_torch_neuroncore_available():
                return model
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb

            if self.args.ddp_broadcast_buffers is not None:
                kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers

            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

        return model

    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Main training entry point.

        Args:
            resume_from_checkpoint (`str` or `bool`, *optional*):
                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
            ignore_keys_for_eval (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions for evaluation during the training.
            kwargs:
                Additional keyword arguments used to hide deprecated arguments
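
        Example (a minimal sketch; `model`, `training_args` and `train_dataset` are assumed to already exist):

        ```python
        trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
        trainer.train(resume_from_checkpoint=True)  # resume from the last checkpoint in `args.output_dir`
        ```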
        """
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        args = self.args

        self.is_in_train = True

        # do_train is not a reliable argument, as it might not be set and .train() still called, so
        # the following is a workaround:
        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
            self._move_model_to_device(self.model, args.device)

        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
                "instead.",
                FutureWarning,
            )
        if len(kwargs) > 0:
            raise TypeError(f"train() received unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)
        self._train_batch_size = self.args.train_batch_size

        # Model re-init
        model_reloaded = False
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
            self.model = self.call_model_init(trial)
            model_reloaded = True
            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Load potential model checkpoint
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:
            self._load_from_checkpoint(resume_from_checkpoint)

        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
            if self.place_model_on_device:
                self._move_model_to_device(self.model, args.device)
            self.model_wrapped = self.model

        inner_training_loop = find_executable_batch_size(
            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
        )
        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        self.accelerator.free_memory()
        self._train_batch_size = batch_size
        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size

        len_dataloader = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            num_examples = self.num_examples(train_dataloader)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
                )
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
            else:
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            max_steps = args.max_steps
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
            )

        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps and args.logging_steps < 1:
            args.logging_steps = math.ceil(max_steps * args.logging_steps)
        if args.eval_steps and args.eval_steps < 1:
            args.eval_steps = math.ceil(max_steps * args.eval_steps)
        if args.save_steps and args.save_steps < 1:
            args.save_steps = math.ceil(max_steps * args.save_steps)
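        # e.g. with max_steps=1000 and logging_steps=0.1 passed as a ratio, logging_steps becomes 100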

        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torch.distributed.launch)."
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

        delay_optimizer_creation = (
            self.sharded_ddp is not None
            and self.sharded_ddp != ShardedDDPOption.SIMPLE
            or is_sagemaker_mp_enabled()
            or self.fsdp is not None
        )

        if self.is_deepspeed_enabled:
            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

        if not delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None

        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        model = self._wrap_model(self.model_wrapped)

        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint, model)

        # as the model is wrapped, don't use `accelerator.prepare`
        # this is for unhandled cases such as
        # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
        use_accelerator_prepare = True if model is self.model else False

        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        # prepare using `accelerator` prepare
        if use_accelerator_prepare:
            if hasattr(self.lr_scheduler, "step"):
                if self.use_apex:
                    model = self.accelerator.prepare(self.model)
                else:
                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
            else:
                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.lr_scheduler
                )

        if self.is_fsdp_enabled:
            self.model = model

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

        # backward compatibility
        if self.is_deepspeed_enabled:
            self.deepspeed = self.model_wrapped

        # deepspeed ckpt loading
        if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
            deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

        # important: at this point:
        # self.model         is the Transformers Model
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.

        # Train!
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples:,}")
        logger.info(f"  Num Epochs = {num_train_epochs:,}")
        logger.info(f"  Instantaneous batch size per device = {self._train_batch_size:,}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps:,}")
        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        steps_trained_progress_bar = None

        # Check if continuing training from a checkpoint
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first"
                    f" {steps_trained_in_current_epoch} batches in the first epoch."
                )

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        if self.hp_name is not None and self._trial is not None:
            # use self._trial because the SigOpt/Optuna hpo only calls `_hp_search_setup(trial)` instead of passing
            # the trial parameter to `train` when using DDP.
            self.state.trial_name = self.hp_name(self._trial)
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()

        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not args.ignore_data_skip:
            for epoch in range(epochs_trained):
                for _ in train_dataloader:
                    break

        total_batched_samples = 0
        for epoch in range(epochs_trained, num_train_epochs):
            epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if args.past_index >= 0:
                self._past = None

            steps_in_epoch = (
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

            rng_to_sync = False
            steps_skipped = 0
            if steps_trained_in_current_epoch > 0:
                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
                steps_skipped = steps_trained_in_current_epoch
                steps_trained_in_current_epoch = 0
                rng_to_sync = True

            step = -1
            for step, inputs in enumerate(epoch_iterator):
                total_batched_samples += 1
                if rng_to_sync:
                    self._load_rng_state(resume_from_checkpoint)
                    rng_to_sync = False

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
                    continue
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None

                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                with self.accelerator.accumulate(model):
                    tr_loss_step = self.training_step(model, inputs)

                if (
                    args.logging_nan_inf_filter
                    and not is_torch_tpu_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                else:
                    tr_loss += tr_loss_step

                self.current_flos += float(self.floating_point_ops(inputs))

                # should this be under the accumulate context manager?
                # the `or` condition of `steps_in_epoch <= args.gradient_accumulation_steps` is not covered
                # in accelerate
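                # i.e. step the optimizer every `gradient_accumulation_steps` batches, or on the last batch of an
                # epoch that has fewer batches than `gradient_accumulation_steps`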
                if total_batched_samples % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    steps_in_epoch <= args.gradient_accumulation_steps
                    and (step + 1) == steps_in_epoch
                ):
                    # Gradient clipping
                    if args.max_grad_norm is not None and args.max_grad_norm > 0:
                        # deepspeed does its own clipping

                        if self.do_grad_scaling:
                            # Reduce gradients first for XLA
                            if is_torch_tpu_available():
                                gradients = xm._fetch_gradients(self.optimizer)
                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)

                        if is_sagemaker_mp_enabled() and args.fp16:
                            self.optimizer.clip_master_grads(args.max_grad_norm)
                        elif hasattr(self.optimizer, "clip_grad_norm"):
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
                            model.clip_grad_norm_(args.max_grad_norm)
                        elif self.use_apex:
                            # Revert to normal clipping otherwise, handling Apex or full precision
                            nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer),
                                args.max_grad_norm,
                            )
                        else:
                            self.accelerator.clip_grad_norm_(
                                model.parameters(),
                                args.max_grad_norm,
                            )

                    # Optimizer step
                    optimizer_was_run = True
                    if is_torch_tpu_available():
                        if self.do_grad_scaling:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            # tpu-comment: accelerate wrapped optimizers call xm.optimizer_step
                            self.optimizer.step()
                    elif self.do_grad_scaling:
                        scale_before = self.scaler.get_scale()
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        scale_after = self.scaler.get_scale()
                        optimizer_was_run = scale_before <= scale_after
                    else:
                        self.optimizer.step()
                        optimizer_was_run = not self.accelerator.optimizer_step_was_skipped

                    if optimizer_was_run:
                        # Delay optimizer scheduling until metrics are generated
                        if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                            self.lr_scheduler.step()

                    model.zero_grad()
                    self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break
            if step < 0:
                logger.warning(
                    "There seems to be not a single sample in your epoch_iterator, stopping training at step"
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True

            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
                dist.barrier()
            elif is_sagemaker_mp_enabled():
                smp.barrier()

            self._load_best_model()

        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        train_loss = self._total_loss_scalar / self.state.global_step

        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
        metrics["train_loss"] = train_loss

        self.is_in_train = False

        self._memory_tracker.stop_and_update_metrics(metrics)

        self.log(metrics)

        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
            for checkpoint in checkpoints_sorted:
                if checkpoint != self.state.best_model_checkpoint:
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint)

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        return TrainOutput(self.state.global_step, train_loss, metrics)

    def _get_output_dir(self, trial):
        if self.hp_search_backend is not None and trial is not None:
            if self.hp_search_backend == HPSearchBackend.OPTUNA:
                run_id = trial.number
            elif self.hp_search_backend == HPSearchBackend.RAY:
                from ray import tune

                run_id = tune.get_trial_id()
            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
                run_id = trial.id
            elif self.hp_search_backend == HPSearchBackend.WANDB:
                import wandb

                run_id = wandb.run.id
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            run_dir = os.path.join(self.args.output_dir, run_name)
        else:
            run_dir = self.args.output_dir
        return run_dir

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        if model is None:
            model = self.model

        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
        adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
        adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
        safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
        safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)

        if not any(
            os.path.isfile(f)
            for f in [
                weights_file,
                safe_weights_file,
                weights_index_file,
                safe_weights_index_file,
                adapter_weights_file,
                adapter_safe_weights_file,
            ]
        ):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

        logger.info(f"Loading model from {resume_from_checkpoint}.")

        if os.path.isfile(config_file):
            config = PretrainedConfig.from_json_file(config_file)
            checkpoint_version = config.transformers_version
            if checkpoint_version is not None and checkpoint_version != __version__:
                logger.warning(
                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                    f"Transformers but your current version is {__version__}. This is not recommended and could "
                    "yield to errors or unwanted behaviors."
                )

        if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
            # If the model is on the GPU, it still works!
            if is_sagemaker_mp_enabled():
                if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
                    # If the 'user_content.pt' file exists, load with the new smp api.
                    # Checkpoint must have been saved with the new smp api.
                    smp.resume_from_checkpoint(
                        path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
                    )
                else:
                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                    # Checkpoint must have been saved with the old smp api.
                    if hasattr(self.args, "fp16") and self.args.fp16 is True:
                        logger.warning(
                            "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
                        )
                    state_dict = torch.load(weights_file, map_location="cpu")
                    # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                    state_dict["_smp_is_partial"] = False
                    load_result = model.load_state_dict(state_dict, strict=True)
                    # release memory
                    del state_dict
            elif self.is_fsdp_enabled:
                load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint)
            else:
                # We load the model state dict on the CPU to avoid an OOM error.
                if self.args.save_safetensors and os.path.isfile(safe_weights_file):
                    state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
                else:
                    state_dict = torch.load(weights_file, map_location="cpu")

                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                # which takes *args instead of **kwargs
                load_result = model.load_state_dict(state_dict, False)
                # release memory
                del state_dict
                self._issue_warnings_after_load(load_result)

        # Load adapters following PR # 24096
        elif is_peft_available() and isinstance(model, PeftModel):
            # If training a model with PEFT & LoRA, assume the adapter has been saved properly.
            if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                if os.path.exists(resume_from_checkpoint):
                    model.load_adapter(resume_from_checkpoint, model.active_adapter)
                else:
                    logger.warning(
                        "The intermediate checkpoints of PEFT may not be saved correctly, "
                        f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
                        "Check some examples here: https://github.com/huggingface/peft/issues/96"
                    )
            else:
                logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
        else:
            # We load the sharded checkpoint
            load_result = load_sharded_checkpoint(
                model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
            )
            if not is_sagemaker_mp_enabled():
                self._issue_warnings_after_load(load_result)

    def _load_best_model(self):
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
        best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
        best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
        best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)

        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if (
            os.path.exists(best_model_path)
            or os.path.exists(best_safe_model_path)
            or os.path.exists(best_adapter_model_path)
            or os.path.exists(best_safe_adapter_model_path)
        ):
            if self.is_deepspeed_enabled:
                deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
            else:
                has_been_loaded = True
                if is_sagemaker_mp_enabled():
                    if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                        # If the 'user_content.pt' file exists, load with the new smp api.
                        # Checkpoint must have been saved with the new smp api.
                        smp.resume_from_checkpoint(
                            path=self.state.best_model_checkpoint,
                            tag=WEIGHTS_NAME,
                            partial=False,
                            load_optimizer=False,
                        )
                    else:
                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                        # Checkpoint must have been saved with the old smp api.
                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                        else:
                            state_dict = torch.load(best_model_path, map_location="cpu")

                        state_dict["_smp_is_partial"] = False
                        load_result = model.load_state_dict(state_dict, strict=True)
                elif self.is_fsdp_enabled:
                    load_fsdp_model(
                        self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
                    )
                else:
                    if is_peft_available() and isinstance(model, PeftModel):
                        # If training a model with PEFT & LoRA, assume the adapter has been saved properly.
                        if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                            if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
                                model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
                                # load_adapter has no return value yet, so build an empty load result; update this when it does.
                                from torch.nn.modules.module import _IncompatibleKeys

                                load_result = _IncompatibleKeys([], [])
                            else:
                                logger.warning(
                                    "The intermediate checkpoints of PEFT may not be saved correctly, "
                                    f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
                                    "Check some examples here: https://github.com/huggingface/peft/issues/96"
                                )
                                has_been_loaded = False
                        else:
                            logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
                            has_been_loaded = False
                    else:
                        # We load the model state dict on the CPU to avoid an OOM error.
                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                        else:
                            state_dict = torch.load(best_model_path, map_location="cpu")

                        # If the model is on the GPU, it still works!
                        # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                        # which takes *args instead of **kwargs
                        load_result = model.load_state_dict(state_dict, False)
                if not is_sagemaker_mp_enabled() and has_been_loaded:
                    self._issue_warnings_after_load(load_result)
        elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
            load_result = load_sharded_checkpoint(
                model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
            )
            if not is_sagemaker_mp_enabled():
                self._issue_warnings_after_load(load_result)
        else:
            logger.warning(
                f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
                "on multiple nodes, you should activate `--save_on_each_node`."
            )

    def _issue_warnings_after_load(self, load_result):
        if len(load_result.missing_keys) != 0:
            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                self.model._keys_to_ignore_on_save
            ):
                self.model.tie_weights()
            else:
                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
        if len(load_result.unexpected_keys) != 0:
            logger.warning(
                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
            )

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
        if self.control.should_log:
            if is_torch_tpu_available():
                xm.mark_step()

            logs: Dict[str, float] = {}

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            logs["learning_rate"] = self._get_learning_rate()

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            if isinstance(self.eval_dataset, dict):
                metrics = {}
                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
                    dataset_metrics = self.evaluate(
                        eval_dataset=eval_dataset,
                        ignore_keys=ignore_keys_for_eval,
                        metric_key_prefix=f"eval_{eval_dataset_name}",
                    )
                    metrics.update(dataset_metrics)
            else:
                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
            self._report_to_hp_search(trial, self.state.global_step, metrics)

            # Run delayed LR scheduler now that metrics are populated
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                metric_to_check = self.args.metric_for_best_model
                if not metric_to_check.startswith("eval_"):
                    metric_to_check = f"eval_{metric_to_check}"
                self.lr_scheduler.step(metrics[metric_to_check])

        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def _load_rng_state(self, checkpoint):
        # Load RNG states from `checkpoint`
        if checkpoint is None:
            return

        if self.args.world_size > 1:
            process_index = self.args.process_index
            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
            if not os.path.isfile(rng_file):
                logger.info(
                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
                )
                return
        else:
            rng_file = os.path.join(checkpoint, "rng_state.pth")
            if not os.path.isfile(rng_file):
                logger.info(
                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
                    "fashion, reproducibility is not guaranteed."
                )
                return

        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        if torch.cuda.is_available():
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
            else:
                try:
                    torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
                except Exception as e:
                    logger.info(
                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
        if is_torch_tpu_available():
            xm.set_rng_state(checkpoint_rng_state["xla"])
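
    # Illustration only: the `rng_state.pth` / `rng_state_{process_index}.pth` file restored above
    # is the dict written by `_save_checkpoint` below, roughly
    #     {"python": ..., "numpy": ..., "cpu": ..., "cuda": ..., "xla": ...}
    # where the "cuda" and "xla" entries are only present when the matching device is available.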

    def _save_checkpoint(self, model, trial, metrics=None):
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir, _internal_call=True)
        if self.is_deepspeed_enabled:
            # Under zero3 the model file itself doesn't get saved since it's bogus, unless the deepspeed
            # config `stage3_gather_16bit_weights_on_model_save` is True
            self.model_wrapped.save_checkpoint(output_dir)

        # Save optimizer and scheduler
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            self.optimizer.consolidate_state_dict()

        if self.fsdp or self.is_fsdp_enabled:
            if self.is_fsdp_enabled:
                save_fsdp_optimizer(
                    self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
                )
            else:
                # FSDP has a different interface for saving optimizer states.
                # Needs to be called on all ranks to gather all states.
                # full_optim_state_dict will be deprecated after Pytorch 2.2!
                full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)

        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
            if self.args.should_save:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling:
                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
        elif self.args.should_save and not self.is_deepspeed_enabled:
            # deepspeed.save_checkpoint above saves model/optim/sched
            if self.fsdp:
                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
            else:
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.args.should_save:
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        # Save the RNG states (one file per process in distributed training)
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                # In distributed training, save the RNG state of every GPU visible to this process.
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                rng_states["cuda"] = torch.cuda.random.get_rng_state()

        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()

        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)

        if self.args.world_size <= 1:
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
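
    # Illustration only: after this method runs, `run_dir` contains folders named
    # f"{PREFIX_CHECKPOINT_DIR}-{global_step}" (e.g. "checkpoint-500"), each holding the weights
    # written by `save_model`, the optimizer/scheduler states (OPTIMIZER_NAME / SCHEDULER_NAME),
    # the RNG states saved above and the serialized trainer state (TRAINER_STATE_NAME).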

    def _load_optimizer_and_scheduler(self, checkpoint):
        """If optimizer and scheduler states exist, load them."""
        if checkpoint is None:
            return

        if self.is_deepspeed_enabled:
            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
            return

        checkpoint_file_exists = (
            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
            if is_sagemaker_mp_enabled()
            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
        )
        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
            # Load in optimizer and scheduler states
            if is_torch_tpu_available():
                # On TPU we have to take some extra precautions to properly load the states on the right device.
                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                with warnings.catch_warnings(record=True) as caught_warnings:
                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
                reissue_pt_warnings(caught_warnings)

                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

                self.optimizer.load_state_dict(optimizer_state)
                self.lr_scheduler.load_state_dict(lr_scheduler_state)
            else:
                if is_sagemaker_mp_enabled():
                    if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
                        # Optimizer checkpoint was saved with smp >= 1.10
                        def opt_load_hook(mod, opt):
                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))

                    else:
                        # Optimizer checkpoint was saved with smp < 1.10
                        def opt_load_hook(mod, opt):
                            if IS_SAGEMAKER_MP_POST_1_10:
                                opt.load_state_dict(
                                    smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
                                )
                            else:
                                opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))

                    self.model_wrapped.register_post_step_hook(opt_load_hook)
                else:
                    # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
                    # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
                    # likely to get OOM on CPU (since we load num_gpu times the optimizer state).
                    map_location = self.args.device if self.args.world_size > 1 else "cpu"
                    if self.fsdp or self.is_fsdp_enabled:
                        if self.is_fsdp_enabled:
                            load_fsdp_optimizer(
                                self.accelerator.state.fsdp_plugin,
                                self.accelerator,
                                self.optimizer,
                                self.model,
                                checkpoint,
                            )
                        else:
                            full_osd = None
                            # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
                            if self.args.process_index == 0:
                                full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))
                            # call scatter_full_optim_state_dict on all ranks
                            sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)
                            self.optimizer.load_state_dict(sharded_osd)
                    else:
                        self.optimizer.load_state_dict(
                            torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
                        )
                with warnings.catch_warnings(record=True) as caught_warnings:
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
                    self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))

    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
        **kwargs,
    ) -> BestRun:
        """
        Launch a hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined
        by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
        the sum of all metrics otherwise.

        <Tip warning={true}>

        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
        optimizer/scheduler.

        </Tip>

        Args:
            hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
                A function that defines the hyperparameter search space. Will default to
                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
                method. Will default to [`~trainer_utils.default_compute_objective`].
            n_trials (`int`, *optional*, defaults to 20):
                The number of trial runs to test.
            direction (`str`, *optional*, defaults to `"minimize"`):
                Whether to optimize for a greater or lower objective. Can be `"minimize"` or `"maximize"`, you should pick
                `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
            backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
                on which one is installed. If all are installed, will default to optuna.
            hp_name (`Callable[["optuna.Trial"], str]]`, *optional*):
                A function that defines the trial/run name. Will default to None.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more
                information see:

                - the documentation of
                  [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
                - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
                - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)

        Returns:
            [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in
            `run_summary` attribute for Ray backend.
        """
        if backend is None:
            backend = default_hp_search_backend()
        backend = HPSearchBackend(backend)
        backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]()
        backend_obj.ensure_available()
        self.hp_search_backend = backend
        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space
        self.hp_name = hp_name
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        best_run = backend_obj.run(self, n_trials, direction, **kwargs)

        self.hp_search_backend = None
        return best_run
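
    # A minimal usage sketch (illustrative, not part of the class): with the optuna backend
    # installed and a Trainer built with `model_init=...`, a search could look like
    #
    #     def my_hp_space(trial):  # `my_hp_space` is a made-up name
    #         return {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}
    #
    #     best_run = trainer.hyperparameter_search(hp_space=my_hp_space, n_trials=10, direction="minimize")
    #     print(best_run.hyperparameters)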

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)

        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
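
    # Subclassing sketch (illustrative): extra fields can be injected before delegating, e.g.
    #
    #     class MyTrainer(Trainer):
    #         def log(self, logs):
    #             logs["my_custom_field"] = 0.0  # made-up key, for illustration only
    #             super().log(logs)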

    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
        """
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            kwargs = {"device": self.args.device}
            if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
                # NLP models' inputs are int/uint and those get adjusted to the right dtype of the
                # embedding. Other models such as wav2vec2's inputs are already float and thus
                # may need special handling to match the dtypes of the model
                kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
            return data.to(**kwargs)
        return data

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

    def compute_loss_context_manager(self):
        """
        A helper wrapper to group together context managers.
        """
        return self.autocast_smart_context_manager()

    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
        """
        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
        arguments, depending on the situation.
        """
        if self.use_cuda_amp or self.use_cpu_amp:
            ctx_manager = (
                torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
                if self.use_cpu_amp
                else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
            )
        else:
            ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()

        return ctx_manager
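
    # Illustration: `training_step` below wraps the forward pass as
    #
    #     with self.compute_loss_context_manager():
    #         loss = self.compute_loss(model, inputs)
    #
    # so autocast (CPU or CUDA) is only applied when mixed precision was configured.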

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)

        return loss.detach() / self.args.gradient_accumulation_steps
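
    # Note (illustrative): the value returned above is the detached loss already divided by
    # `gradient_accumulation_steps`, so the training loop can simply sum what it gets back from
    # each step into `tr_loss` before logging.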

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss
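
    # Subclassing sketch (illustrative; the custom cross-entropy below is an example, not the
    # library's default behaviour):
    #
    #     class CustomLossTrainer(Trainer):
    #         def compute_loss(self, model, inputs, return_outputs=False):
    #             labels = inputs.pop("labels")
    #             outputs = model(**inputs)
    #             logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
    #             loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
    #             return (loss, outputs) if return_outputs else loss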

    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
        machines) main process.
        """
        return self.args.local_process_index == 0

    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be `True` for one process).
        """
        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
        # process index.
        if is_sagemaker_mp_enabled():
            return smp.rank() == 0
        else:
            return self.args.process_index == 0

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Will save the model, so you can reload it using `from_pretrained()`.

        Will only save from the main process.
        """

        if output_dir is None:
            output_dir = self.args.output_dir

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif is_sagemaker_mp_enabled():
            # Calling the state_dict needs to be done on the wrapped model and on all processes.
            os.makedirs(output_dir, exist_ok=True)
            state_dict = self.model_wrapped.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
            if IS_SAGEMAKER_MP_POST_1_10:
                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
                Path(os.path.join(output_dir, "user_content.pt")).touch()
        elif (
            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
            or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
            or self.fsdp is not None
            or self.is_fsdp_enabled
        ):
            if self.is_fsdp_enabled:
                os.makedirs(output_dir, exist_ok=True)
                save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
            else:
                state_dict = self.model.state_dict()

                if self.args.should_save:
                    self._save(output_dir, state_dict=state_dict)
        elif self.is_deepspeed_enabled:
            # this takes care of everything as long as we aren't under zero3
            if self.args.should_save:
                self._save(output_dir)

            if is_deepspeed_zero3_enabled():
                # It's too complicated to try to override different places where the weights dump gets
                # saved, so since under zero3 the file is bogus, simply delete it. The user should
                # either use the deepspeed checkpoint to resume or, to recover the full weights, use
                # zero_to_fp32.py stored in the checkpoint.
                if self.args.should_save:
                    file = os.path.join(output_dir, WEIGHTS_NAME)
                    if os.path.isfile(file):
                        # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
                        os.remove(file)

                # now save the real model if stage3_gather_16bit_weights_on_model_save=True
                # if false it will not be saved.
                # This must be called on all ranks
                if not self.model_wrapped.save_16bit_model(output_dir, WEIGHTS_NAME):
                    logger.warning(
                        "deepspeed.save_16bit_model didn't save the model, since"
                        " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                        " zero_to_fp32.py to recover weights"
                    )
                    self.model_wrapped.save_checkpoint(output_dir)

        elif self.args.should_save:
            self._save(output_dir)

        # Push to the Hub when `save_model` is called by the user.
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info(f"Saving model checkpoint to {output_dir}")

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if not isinstance(self.model, PreTrainedModel):
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                unwrap_model(self.model).save_pretrained(
                    output_dir,
                    is_main_process=self.args.should_save,
                    state_dict=self.model.state_dict(),
                    save_function=xm.save,
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                state_dict = self.model.state_dict()
                xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
        if self.tokenizer is not None and self.args.should_save:
            self.tokenizer.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            if isinstance(unwrap_model(self.model), supported_classes):
                unwrap_model(self.model).save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def store_flos(self):
        # Storing the number of floating-point operations that went into the model
        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            self.state.total_flos += (
                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
            )
            self.current_flos = 0
        else:
            self.state.total_flos += self.current_flos
            self.current_flos = 0

    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match is not None and regex_match.groups() is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        # Make sure we don't delete the best model.
        if self.state.best_model_checkpoint is not None:
            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
            for i in range(best_model_index, len(checkpoints_sorted) - 2):
                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
        # we don't do to allow resuming.
        save_total_limit = self.args.save_total_limit
        if (
            self.state.best_model_checkpoint is not None
            and self.args.save_total_limit == 1
            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
        ):
            save_total_limit = 2

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
            shutil.rmtree(checkpoint, ignore_errors=True)
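
    # Illustration: with `save_total_limit=2` and "checkpoint-500", "checkpoint-1000" and
    # "checkpoint-1500" on disk, the oldest folder is deleted here, unless `_sorted_checkpoints`
    # moved it towards the end of the list because it is the best model checkpoint.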

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default)

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        start_time = time.time()

        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.log(output.metrics)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics
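
    # Usage sketch (illustrative): `metrics = trainer.evaluate()` returns keys carrying the
    # `metric_key_prefix`, e.g. "eval_loss" and "eval_runtime", plus whatever a user-supplied
    # `compute_metrics` adds; a second validation set could be scored with
    #     trainer.evaluate(eval_dataset=other_eval_dataset, metric_key_prefix="eval_other")
    # where `other_eval_dataset` is a hypothetical extra dataset.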

    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Run prediction and returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the
                `model.forward()` method are automatically removed. Has to implement the method `__len__`
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "test_bleu" if the prefix is "test" (default)
        <Tip>

        If your predictions or labels have different sequence lengths (for instance because you're doing dynamic padding
        in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
        one array. The padding index is -100.
        </Tip>
        Returns: *NamedTuple* A namedtuple with the following keys:
            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        test_dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )
        total_batch_size = self.args.eval_batch_size * self.args.world_size
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )
        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
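
    # Illustrative usage sketch (added comment, not from the original source): running
    # `predict` on a test split, assuming `trainer` and `test_ds` already exist. The
    # returned `predictions` are the gathered (and padded) logits; `label_ids` and
    # `metrics` are only filled in when the dataset actually carries the label columns:
    #
    #     result = trainer.predict(test_ds)
    #     class_ids = result.predictions.argmax(-1)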
    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """
        args = self.args

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
        # if eval is called w/o train, handle model prep here
        if self.is_deepspeed_enabled and self.deepspeed is None:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            if self.is_fsdp_enabled:
                self.model = model

            # for the rest of this function `model` is the outside model, whether it was wrapped or not
            if model is not self.model:
                self.model_wrapped = model

            # backward compatibility
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)
        batch_size = self.args.eval_batch_size
        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")
        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)
        if args.past_index >= 0:
            self._past = None
        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        inputs_host = None

        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        all_inputs = None
        # Will be useful when we have an iterable dataset so don't know its length.

        observed_num_examples = 0
        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
            if is_torch_tpu_available():
                xm.mark_step()

            # Update containers on host
            if loss is not None:
                losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))
                losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
            if labels is not None:
                labels = self.accelerator.pad_across_processes(labels)
            if inputs_decode is not None:
                inputs_decode = self.accelerator.pad_across_processes(inputs_decode)
                inputs_decode = self.accelerator.gather_for_metrics((inputs_decode))
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            if logits is not None:
                logits = self.accelerator.pad_across_processes(logits)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                logits = self.accelerator.gather_for_metrics((logits))
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels = self.accelerator.gather_for_metrics((labels))
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                if preds_host is not None:
                    logits = nested_numpify(preds_host)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
                if inputs_host is not None:
                    inputs_decode = nested_numpify(inputs_host)
                    all_inputs = (
                        inputs_decode
                        if all_inputs is None
                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                    )
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                    )

                # Set back to None to begin a new accumulation
                losses_host, preds_host, inputs_host, labels_host = None, None, None, None
        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")
        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            all_losses = nested_numpify(losses_host)
        if preds_host is not None:
            all_preds = nested_numpify(preds_host)
        if inputs_host is not None:
            all_inputs = nested_numpify(inputs_host)
        if labels_host is not None:
            all_labels = nested_numpify(labels_host)

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples
        if num_samples == 0 and observed_num_examples > 0:
            num_samples = observed_num_examples

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}
        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
        if hasattr(self, "jit_compilation_time"):
            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
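
    # Minimal sketch (added for illustration) of a `compute_metrics` callable that is
    # compatible with the `EvalPrediction` built by the loop above; `predictions` holds
    # the gathered logits and `label_ids` the gathered labels, both as numpy arrays:
    #
    #     def accuracy_metrics(eval_pred):
    #         preds = eval_pred.predictions.argmax(axis=-1)
    #         return {"accuracy": float((preds == eval_pred.label_ids).mean())}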
    def _nested_gather(self, tensors, name=None):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            if name is None:
                name = "nested_gather"
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
            self.args.distributed_state is None and self.local_rank != -1
        ):
            tensors = distributed_concat(tensors)
        return tensors
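
    # Behaviour sketch (added comment): with two data-parallel processes each holding a
    # loss tensor of shape (4,), e.g. `self._nested_gather(loss.repeat(4))`, every process
    # gets back a tensor of shape (8,); nested tuples/dicts of tensors keep their
    # structure, with each leaf gathered along the first dimension.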
    # Copied from Accelerate.
    def _pad_across_processes(self, tensor, pad_index=-100):
        """
        Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
        they can safely be gathered.
        """
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )

        if len(tensor.shape) < 2:
            return tensor
        # Gather all sizes
        size = torch.tensor(tensor.shape, device=tensor.device)[None]
        sizes = self._nested_gather(size).cpu()

        max_size = max(s[1] for s in sizes)
        # When extracting XLA graphs for compilation, max_size is 0,
        # so use inequality to avoid errors.
        if tensor.shape[1] >= max_size:
            return tensor

        # Then pad to the maximum size
        old_size = tensor.shape
        new_size = list(old_size)
        new_size[1] = max_size
        new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
        new_tensor[:, : old_size[1]] = tensor
        return new_tensor
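
    # Worked example (added comment): if this process holds logits of shape (8, 5) while
    # another process holds (8, 7), `self._pad_across_processes(logits)` returns a tensor
    # of shape (8, 7) with the two extra columns filled with `pad_index` (-100), so the
    # subsequent cross-process gather can concatenate safely.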
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []
        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels or loss_without_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels or loss_without_labels:
                    with self.compute_loss_context_manager():
                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                    loss = loss.mean().detach()
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        logits = outputs[1:]
                else:
                    loss = None
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)
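
    # Illustrative subclassing sketch (added comment; `FirstLogitTrainer` is a
    # hypothetical name): keep only the first logits tensor during evaluation while
    # reusing everything else from the default implementation.
    #
    #     class FirstLogitTrainer(Trainer):
    #         def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
    #             loss, logits, labels = super().prediction_step(
    #                 model, inputs, prediction_loss_only, ignore_keys=ignore_keys
    #             )
    #             if isinstance(logits, tuple):
    #                 logits = logits[0]
    #             return loss, logits, labels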

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point
        operations for every backward + forward pass. If using another model, either implement such a method in the
        model or subclass and override this method.

        Args:
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            `int`: The number of floating-point operations.
        """
        if hasattr(self.model, "floating_point_ops"):
            return self.model.floating_point_ops(inputs)
        else:
            return 0
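
    # Sketch (added comment, assumptions flagged): a custom model that is not a
    # `PreTrainedModel` can still report FLOPs by exposing the same hook, for example
    # with the rough "6 * tokens * parameters" estimate often used for transformer
    # forward + backward passes:
    #
    #     def floating_point_ops(self, inputs):
    #         tokens = inputs["input_ids"].numel()
    #         params = sum(p.numel() for p in self.parameters())
    #         return 6 * tokens * params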
    def init_git_repo(self, at_init: bool = False):
        """
        Initializes a git repo in `self.args.hub_model_id`.

        Args:
            at_init (`bool`, *optional*, defaults to `False`):
                Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
                `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
                out.
        """
        if not self.is_world_process_zero():
            return
        if self.args.hub_model_id is None:
            repo_name = Path(self.args.output_dir).absolute().name
        else:
            repo_name = self.args.hub_model_id
        if "/" not in repo_name:
            repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)
        # Make sure the repo exists.
        create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)
        try:
            self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)
        except EnvironmentError:
            if self.args.overwrite_output_dir and at_init:
                # Try again after wiping output_dir
                shutil.rmtree(self.args.output_dir)
                self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)
            else:
                raise

        self.repo.git_pull()

        # By default, ignore the checkpoint folders
        if (
            not os.path.exists(os.path.join(self.args.output_dir, ".gitignore"))
            and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS
        ):
            with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
                writer.writelines(["checkpoint-*/"])

        # Add "*.sagemaker" to .gitignore if using SageMaker
        if os.environ.get("SM_TRAINING_ENV"):
            self._add_sm_patterns_to_gitignore()

        self.push_in_progress = None

    def create_model_card(
        self,
        language: Optional[str] = None,
        license: Optional[str] = None,
        tags: Union[str, List[str], None] = None,
        model_name: Optional[str] = None,
        finetuned_from: Optional[str] = None,
        tasks: Union[str, List[str], None] = None,
        dataset_tags: Union[str, List[str], None] = None,
        dataset: Union[str, List[str], None] = None,
        dataset_args: Union[str, List[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            language (`str`, *optional*):
                The language of the model (if applicable)
            license (`str`, *optional*):
                The license of the model. Will default to the license of the pretrained model used, if the original
                model given to the `Trainer` comes from a repo on the Hub.
            tags (`str` or `List[str]`, *optional*):
                Some tags to be included in the metadata of the model card.
            model_name (`str`, *optional*):
                The name of the model.
            finetuned_from (`str`, *optional*):
                The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
                of the original model given to the `Trainer` (if it comes from the Hub).
            tasks (`str` or `List[str]`, *optional*):
                One or several task identifiers, to be included in the metadata of the model card.
            dataset_tags (`str` or `List[str]`, *optional*):
                One or several dataset tags, to be included in the metadata of the model card.
            dataset (`str` or `List[str]`, *optional*):
                One or several dataset identifiers, to be included in the metadata of the model card.
            dataset_args (`str` or `List[str]`, *optional*):
                One or several dataset arguments, to be included in the metadata of the model card.
        """
        if not self.is_world_process_zero():
            return

        training_summary = TrainingSummary.from_trainer(
            self,
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
        )
        model_card = training_summary.to_model_card()
        with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
            f.write(model_card)
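
    # Illustrative usage sketch (added comment), assuming training finished on a
    # hypothetical "imdb" text-classification run:
    #
    #     trainer.create_model_card(
    #         language="en",
    #         tags=["text-classification"],
    #         finetuned_from="bert-base-uncased",
    #         dataset="imdb",
    #     )
    #     # writes <output_dir>/README.md with the metadata block and training summary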

    def _push_from_checkpoint(self, checkpoint_folder):
        # Only push from one node.
        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
            return
        # If we haven't finished the last push, we don't do this one.
        if self.push_in_progress is not None and not self.push_in_progress.is_done:
            return

        output_dir = self.args.output_dir
        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
        modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
        for modeling_file in modeling_files:
            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Same for the training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        try:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Temporarily move the checkpoint just saved for the push
                tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
                # subfolder.
                if os.path.isdir(tmp_checkpoint):
                    shutil.rmtree(tmp_checkpoint)
                shutil.move(checkpoint_folder, tmp_checkpoint)

            if self.args.save_strategy == IntervalStrategy.STEPS:
                commit_message = f"Training in progress, step {self.state.global_step}"
            else:
                commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
            push_work = self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
            # Return type of `Repository.push_to_hub` is either None or a tuple.
            if push_work is not None:
                self.push_in_progress = push_work[1]
        except Exception as e:
            logger.error(f"Error when pushing to hub: {e}")
        finally:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Move back the checkpoint to its place
                shutil.move(tmp_checkpoint, checkpoint_folder)

    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.

        Parameters:
            commit_message (`str`, *optional*, defaults to `"End of training"`):
                Message to commit while pushing.
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the `git push` has finished.
            kwargs:
                Additional keyword arguments passed along to [`~Trainer.create_model_card`].

        Returns:
            The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of
            the commit and an object to track the progress of the commit if `blocking=True`
        """
        # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
        # it might fail.
        if not hasattr(self, "repo"):
            self.init_git_repo()
        model_name = kwargs.pop("model_name", None)
        if model_name is None and self.args.should_save:
            if self.args.hub_model_id is None:
                model_name = Path(self.args.output_dir).name
            else:
                model_name = self.args.hub_model_id.split("/")[-1]
        # Needs to be executed on all processes for TPU training, but will only save on the process determined by
        # self.args.should_save.
        self.save_model(_internal_call=True)

        # Only push from one node.
        if not self.is_world_process_zero():
            return

        # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
            self.push_in_progress._process.kill()
            self.push_in_progress = None

        git_head_commit_url = self.repo.push_to_hub(
            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
        )
        # push separately the model card to be independent from the rest of the model
        if self.args.should_save:
            self.create_model_card(model_name=model_name, **kwargs)
            try:
                self.repo.push_to_hub(
                    commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
                )
            except EnvironmentError as exc:
                logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")

        return git_head_commit_url
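
    # Illustrative usage sketch (added comment): a final blocking push once training is
    # done, assuming the Trainer was built with `push_to_hub=True` in its TrainingArguments:
    #
    #     url = trainer.push_to_hub(commit_message="End of training", blocking=True)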
    #
    # Deprecated code
    #

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """
        args = self.args

        if not has_length(dataloader):
            raise ValueError("dataloader must implement a working __len__")

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
        # if eval is called w/o train, handle model prep here
        if self.is_deepspeed_enabled and self.deepspeed is None:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            if self.is_fsdp_enabled:
                self.model = model

            # for the rest of this function `model` is the outside model, whether it was wrapped or not
            if model is not self.model:
                self.model_wrapped = model

            # backward compatibility
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info(f"***** Running {description} *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Batch size = {batch_size}")
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
        inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None
        world_size = max(1, args.world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
            # a batch size to the sampler)
            make_multiple_of = None
            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
                make_multiple_of = dataloader.sampler.batch_size
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

        model.eval()

        if args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
            if loss is not None:
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                if not prediction_loss_only:
                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
                    inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host, inputs_host = None, None, None, None
        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
            inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
        inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

3757
        return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
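
    # Note (added comment): this legacy loop is only selected when the training arguments
    # set `use_legacy_prediction_loop=True`, e.g. (illustrative):
    #
    #     args = TrainingArguments(output_dir="out", use_legacy_prediction_loop=True)
    #     # Trainer.evaluate() / Trainer.predict() then route through `prediction_loop`.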

    def _gather_and_numpify(self, tensors, name):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)

    def _add_sm_patterns_to_gitignore(self) -> None:
        """Add SageMaker Checkpointing patterns to .gitignore file."""
        # Make sure we only do this on the main process
        if not self.is_world_process_zero():
            return

        patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]

        # Get current .gitignore content
        if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
            with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f:
                current_content = f.read()
        else:
            current_content = ""

        # Add the patterns to .gitignore
        content = current_content
        for pattern in patterns:
            if pattern not in content:
                if content.endswith("\n"):
                    content += pattern
                else:
                    content += f"\n{pattern}"

        # Write the .gitignore file if it has changed
        if content != current_content:
            with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
                logger.debug(f"Writing .gitignore file. Content: {content}")
                f.write(content)

        self.repo.git_add(".gitignore")

        # avoid race condition with git status
        time.sleep(0.5)

        if not self.repo.is_repo_clean():
            self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
            self.repo.git_push()

    def create_accelerator_and_postprocess(self):
        # create accelerator object
        self.accelerator = Accelerator(
            deepspeed_plugin=self.args.deepspeed_plugin,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        )

        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

        # post accelerator creation setup
        if self.is_fsdp_enabled:
            fsdp_plugin = self.accelerator.state.fsdp_plugin
            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)

        if self.is_deepspeed_enabled:
            if getattr(self.args, "hf_deepspeed_config", None) is None:
                from transformers.deepspeed import HfTrainerDeepSpeedConfig

                ds_plugin = self.accelerator.state.deepspeed_plugin

                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
                ds_plugin.hf_ds_config.trainer_config_process(self.args)