# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""

import contextlib
import functools
import glob
import inspect
import math
import os
import random
import re
import shutil
import sys
import time
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union


from tqdm.auto import tqdm


# Integrations must be imported before ML frameworks:
# isort: off
from .integrations import (
    default_hp_search_backend,
    get_reporting_integration_callbacks,
    hp_params,
    is_fairscale_available,
    is_optuna_available,
    is_ray_tune_available,
    is_sigopt_available,
    is_wandb_available,
    run_hp_search_optuna,
    run_hp_search_ray,
    run_hp_search_sigopt,
    run_hp_search_wandb,
)

# isort: on

import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import Repository, create_repo
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
from .optimization import Adafactor, get_scheduler
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
    DistributedLengthGroupedSampler,
    DistributedSamplerWithLoop,
    DistributedTensorGatherer,
    IterableDatasetShard,
    LabelSmoother,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    ShardSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    find_batch_size,
    get_model_param_count,
    get_module_class_from_name,
    get_parameter_names,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_truncate,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalLoopOutput,
    EvalPrediction,
    FSDPOption,
    HPSearchBackend,
    HubStrategy,
    IntervalStrategy,
    PredictionOutput,
    RemoveColumnsCollator,
    ShardedDDPOption,
    TrainerMemoryTracker,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    denumpify_detensorize,
    enable_full_determinism,
    find_executable_batch_size,
    get_last_checkpoint,
    has_length,
    number_of_arguments,
    seed_worker,
    set_seed,
    speed_metrics,
)
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import (
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    can_return_loss,
    find_labels,
    get_full_repo_name,
    is_accelerate_available,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
    is_ipex_available,
    is_safetensors_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_torch_compile_available,
    is_torch_neuroncore_available,
    is_torch_tpu_available,
    logging,
    strtobool,
)
from .utils.generic import ContextManagers


_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10

DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

if is_apex_available():
    from apex import amp

if is_datasets_available():
    import datasets

if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_fairscale_available():
    dep_version_check("fairscale")
    import fairscale
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.nn.wrap import auto_wrap
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler


if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
    IS_SAGEMAKER_MP_POST_1_10 = False


if is_safetensors_available():
    import safetensors.torch


skip_first_batches = None
if is_accelerate_available():
    from accelerate import __version__ as accelerate_version

    if version.parse(accelerate_version) >= version.parse("0.16"):
        from accelerate import skip_first_batches

    from accelerate import Accelerator
    from accelerate.utils import DistributedDataParallelKwargs


if TYPE_CHECKING:
    import optuna

logger = logging.get_logger(__name__)


# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.

    Args:
        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.

            <Tip>

            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
            models.

            </Tip>

        args ([`TrainingArguments`], *optional*):
            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
        data_collator (`DataCollator`, *optional*):
            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
            [`DataCollatorWithPadding`] otherwise.
        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
            The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
            `model.forward()` method are automatically removed.

            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
            distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
            sets the seed of the RNGs used.
        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]], *optional*):
            The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
            `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
            dataset prepending the dictionary key to the metric name.
        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
            interrupted training or reuse the fine-tuned model.
        model_init (`Callable[[], PreTrainedModel]`, *optional*):
            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
            from a new instance of the model as given by this function.

            The function may have zero arguments, or a single one containing the optuna/Ray Tune/SigOpt trial object,
            to be able to choose different architectures according to hyperparameters (such as layer count, sizes of
            inner layers, dropout probabilities, etc.).
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
            The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`] and return
            a dictionary mapping metric names to metric values.
        callbacks (List of [`TrainerCallback`], *optional*):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed [here](callback).

            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.

    Important attributes:

        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
          subclass.
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
          to `False` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False`.
        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
          in `train`)

    """

    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        if args is None:
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
        self.args = args
        # Seed must be set before instantiating the model when using `model_init`.
        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
        self.hp_name = None
        self.is_in_train = False

        # create accelerator object
        self.accelerator = Accelerator(
            deepspeed_plugin=self.args.deepspeed_plugin,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        )

        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

        # post accelerator creation setup
        if self.is_fsdp_enabled:
            fsdp_plugin = self.accelerator.state.fsdp_plugin
            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)

        if self.is_deepspeed_enabled:
            if getattr(self.args, "hf_deepspeed_config", None) is None:
                from transformers.deepspeed import HfTrainerDeepSpeedConfig

                ds_plugin = self.accelerator.state.deepspeed_plugin

                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
                ds_plugin.hf_ds_config.trainer_config_process(self.args)

        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

        # set the correct log level depending on the node
        log_level = args.get_process_log_level()
        logging.set_verbosity(log_level)

        # force device and distributed setup init explicitly
        args._setup_devices

        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
                    " overwrite your model when calling the `train` method. This will become a fatal error in the next"
                    " release.",
                    FutureWarning,
                )
            self.model_init = model_init

        if model.__class__.__name__ in MODEL_MAPPING_NAMES:
            raise ValueError(
                f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
                "computes hidden states and does not accept any labels. You should choose a model with a head "
                "suitable for your task like any of the `AutoModelForXxx` listed at "
                "https://huggingface.co/docs/transformers/model_doc/auto."
            )

        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = False

        if getattr(model, "hf_device_map", None) is not None:
            devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
            if len(devices) > 1:
                self.is_model_parallel = True
            else:
                self.is_model_parallel = self.args.device != torch.device(devices[0])

            # warn users
            logger.info(
                "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
                " to `True` to avoid any unexpected behavior such as device placement mismatching."
            )

        # At this stage the model is already loaded
        if getattr(model, "is_quantized", False):
            if getattr(model, "_is_quantized_training_enabled", False):
                logger.info(
                    "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
                    " inside the model such as adapters using the `peft` library and freeze the model weights. Please"
                    " check the examples in https://github.com/huggingface/peft for more details."
                )
            else:
                raise ValueError(
                    "The model you want to train is loaded in 8-bit precision. If you want to fine-tune an 8-bit"
                    " model, please make sure that you have installed `bitsandbytes>=0.37.0`."
                )

        # Setup Sharded DDP training
        self.sharded_ddp = None
        if len(args.sharded_ddp) > 0:
            if self.is_deepspeed_enabled:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if len(args.fsdp) > 0:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
                )
            if args.parallel_mode != ParallelMode.DISTRIBUTED:
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
                raise ImportError(
                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
                )
            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.SIMPLE
            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

        self.fsdp = None
        if len(args.fsdp) > 0:
            if self.is_deepspeed_enabled:
                raise ValueError(
                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
                raise ValueError("Using fsdp only works in distributed training.")

            # dep_version_check("torch>=1.12.0")
            # Would have to update setup.py with torch>=1.12.0
            # which isn't ideal, given that it would force people not using FSDP to also use torch>=1.12.0
            # below is the current alternative.
            if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
                raise ValueError("FSDP requires PyTorch >= 1.12.0")

            from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy

            if FSDPOption.FULL_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.FULL_SHARD
            elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
                self.fsdp = ShardingStrategy.SHARD_GRAD_OP
            elif FSDPOption.NO_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.NO_SHARD

            self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
            if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get(
                "backward_prefetch", []
            ):
                self.backward_prefetch = BackwardPrefetch.BACKWARD_POST

            self.forward_prefetch = False
            if self.args.fsdp_config.get("forward_prefetch", False):
                self.forward_prefetch = True

            self.limit_all_gathers = False
            if self.args.fsdp_config.get("limit_all_gathers", False):
                self.limit_all_gathers = True

        # one place to sort out whether to place the model on device or not
        # postpone switching model to cuda when:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        #    and we only use deepspeed for training at the moment
        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
        # 4. Sharded DDP - same as MP
        # 5. FSDP - same as MP
        self.place_model_on_device = args.place_model_on_device
        if (
            self.is_model_parallel
            or self.is_deepspeed_enabled
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
            or (self.fsdp is not None)
            or self.is_fsdp_enabled
        ):
            self.place_model_on_device = False

        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer

        if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False):
            self._move_model_to_device(model, args.device)

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
        if self.is_model_parallel:
            self.args._n_gpu = 1

        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
        self.model_wrapped = model
        self.model = model

        self.compute_metrics = compute_metrics
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        if is_torch_tpu_available() and self.optimizer is not None:
            for param in self.model.parameters():
                model_device = param.device
                break
            for param_group in self.optimizer.param_groups:
                if len(param_group["params"]) > 0:
                    optimizer_device = param_group["params"][0].device
                    break
            if model_device != optimizer_device:
                raise ValueError(
                    "The model and the optimizer parameters are not on the same device, which probably means you"
                    " created an optimizer around your model **before** putting it on the device and passing it to the"
                    " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                    " `model.to(xm.xla_device())` are performed before the optimizer creation in your script."
                )
        if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
            self.optimizer is not None or self.lr_scheduler is not None
        ):
            raise RuntimeError(
                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create clone of distant repo and output directory if needed
        if self.args.push_to_hub:
            self.init_git_repo(at_init=True)
            # In case of pull, we need to make sure every process has the latest.
            if is_torch_tpu_available():
                xm.rendezvous("init git repo")
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
                dist.barrier()

        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)

        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")

        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
            raise ValueError(
                "The train_dataset does not implement __len__, max_steps has to be specified. "
                "The number of steps needs to be known in advance for the learning rate scheduler."
            )

        if (
            train_dataset is not None
            and isinstance(train_dataset, torch.utils.data.IterableDataset)
            and args.group_by_length
        ):
            raise ValueError("The `--group_by_length` option is only available for `Dataset`, not `IterableDataset`.")

        self._signature_columns = None

        # Mixed precision setup
        self.use_apex = False
        self.use_cuda_amp = False
        self.use_cpu_amp = False

        # Mixed precision setup for SageMaker Model Parallel
        if is_sagemaker_mp_enabled():
            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
            if args.bf16:
                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")

            if IS_SAGEMAKER_MP_POST_1_10:
                # When there's a mismatch between SMP config and trainer argument, use SMP config as truth
                if args.fp16 != smp.state.cfg.fp16:
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                        f"but FP16 provided in trainer argument is {args.fp16}, "
                        f"setting to {smp.state.cfg.fp16}"
                    )
                    args.fp16 = smp.state.cfg.fp16
            else:
                # smp < 1.10 does not support fp16 in trainer.
                if hasattr(smp.state.cfg, "fp16"):
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
                    )

        if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
            if args.half_precision_backend == "auto":
                if args.device == torch.device("cpu"):
                    if args.fp16:
                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                    elif _is_native_cpu_amp_available:
                        args.half_precision_backend = "cpu_amp"
                    else:
                        raise ValueError("Tried to use cpu amp but native cpu amp is not available")
                else:
                    args.half_precision_backend = "cuda_amp"

            logger.info(f"Using {args.half_precision_backend} half precision backend")

        self.do_grad_scaling = False
        if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
            # deepspeed and SageMaker Model Parallel manage their own half precision
            if self.sharded_ddp is not None:
                if args.half_precision_backend == "cuda_amp":
                    self.use_cuda_amp = True
                    self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
                    #  bf16 does not need grad scaling
                    self.do_grad_scaling = self.amp_dtype == torch.float16
                    if self.do_grad_scaling:
                        if self.sharded_ddp is not None:
                            self.scaler = ShardedGradScaler()
                        elif self.fsdp is not None:
                            from torch.distributed.fsdp.sharded_grad_scaler import (
                                ShardedGradScaler as FSDPShardedGradScaler,
                            )

                            self.scaler = FSDPShardedGradScaler()
                        elif is_torch_tpu_available():
                            from torch_xla.amp import GradScaler

                            self.scaler = GradScaler()
                        else:
                            self.scaler = torch.cuda.amp.GradScaler()
                elif args.half_precision_backend == "cpu_amp":
                    self.use_cpu_amp = True
                    self.amp_dtype = torch.bfloat16
            elif args.half_precision_backend == "apex":
                if not is_apex_available():
                    raise ImportError(
                        "Using FP16 with APEX but APEX is not installed, please refer to"
                        " https://www.github.com/nvidia/apex."
                    )
                self.use_apex = True

        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
        if (
            is_sagemaker_mp_enabled()
            and self.use_cuda_amp
            and args.max_grad_norm is not None
            and args.max_grad_norm > 0
        ):
            raise ValueError(
                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
                "along 'max_grad_norm': 0 in your hyperparameters."
            )

        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

        self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
        )

        self.control = TrainerControl()
        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
        default_label_names = find_labels(self.model.__class__)
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        self.can_return_loss = can_return_loss(self.model.__class__)
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

        # Internal variables to keep track of the original batch size
        self._train_batch_size = args.train_batch_size

        # very last
        self._memory_tracker.stop_and_update_metrics()

        # torch.compile
        if args.torch_compile and not is_torch_compile_available():
            raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")

    def add_callback(self, callback):
        """
        Add a callback to the current list of [`~transformers.TrainerCallback`].

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
               first case, will instantiate a member of that class.
        """
        self.callback_handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it.

        If the callback is not found, returns `None` (and no error is raised).

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
               first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            [`~transformers.TrainerCallback`]: The callback removed, if found.
        """
        return self.callback_handler.pop_callback(callback)

    def remove_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`].

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
               first case, will remove the first member of that class found in the list of callbacks.
        """
        self.callback_handler.remove_callback(callback)
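
    # Illustrative sketch of callback management on an existing `trainer` instance (assumes a custom
    # `MyCallback(TrainerCallback)` defined by the user; not taken from this module):
    #
    #     trainer.add_callback(MyCallback)            # pass the class, Trainer instantiates it
    #     trainer.add_callback(MyCallback())          # or pass an already-built instance
    #     popped = trainer.pop_callback(MyCallback)   # detach it and get the instance back
    #     trainer.remove_callback(PrinterCallback)    # drop one of the default callbacks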

    def _move_model_to_device(self, model, device):
        model = model.to(device)
        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
            model.tie_weights()

    def _set_signature_columns_if_needed(self):
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
            # Labels may be named label or label_ids, the default data collator handles that.
            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return dataset
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        if len(ignored_columns) > 0:
            dset_description = "" if description is None else f"in the {description} set"
            logger.info(
                f"The following columns {dset_description} don't have a corresponding argument in "
                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
                f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
                " you can safely ignore this message."
            )

        columns = [k for k in signature_columns if k in dataset.column_names]

        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            dataset.set_format(
                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
            )
            return dataset
        else:
            return dataset.remove_columns(ignored_columns)

    def _get_collator_with_removed_columns(
        self, data_collator: Callable, description: Optional[str] = None
    ) -> Callable:
        """Wrap the data collator in a callable removing unused columns."""
        if not self.args.remove_unused_columns:
            return data_collator
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        remove_columns_collator = RemoveColumnsCollator(
            data_collator=data_collator,
            signature_columns=signature_columns,
            logger=logger,
            description=description,
            model_name=self.model.__class__.__name__,
        )
        return remove_columns_collator

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        generator = None
        if self.args.world_size <= 1:
            generator = torch.Generator()
            # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
            # `args.seed`) if data_seed isn't provided.
            # Further on in this method, we default to `args.seed` instead.
            if self.args.data_seed is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
            else:
                seed = self.args.data_seed
            generator.manual_seed(seed)

        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed

        # Build the sampler.
        if self.args.group_by_length:
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            if self.args.world_size <= 1:
                return LengthGroupedSampler(
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
                    dataset=self.train_dataset,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
                )
            else:
                return DistributedLengthGroupedSampler(
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
                    dataset=self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    seed=seed,
                )

        else:
            if self.args.world_size <= 1:
                return RandomSampler(self.train_dataset, generator=generator)
            elif (
                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
            ):
                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    batch_size=self.args.per_device_train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=seed,
                )
            else:
                return DistributedSampler(
                    self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=seed,
                )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        if isinstance(train_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                train_dataset = IterableDatasetShard(
                    train_dataset,
                    batch_size=self._train_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )

            return DataLoader(
                train_dataset,
                batch_size=self._train_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        train_sampler = self._get_train_sampler()

        return DataLoader(
            train_dataset,
            batch_size=self._train_batch_size,
            sampler=train_sampler,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            worker_init_fn=seed_worker,
        )
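
    # Illustrative sketch of the "subclass and override" pattern mentioned in the docstring above; the sequential
    # sampler and the reduced keyword set are arbitrary example choices, not a recommendation:
    #
    #     class MyTrainer(Trainer):
    #         def get_train_dataloader(self) -> DataLoader:
    #             return DataLoader(
    #                 self.train_dataset,
    #                 batch_size=self._train_batch_size,
    #                 sampler=SequentialSampler(self.train_dataset),
    #                 collate_fn=self.data_collator,
    #                 drop_last=self.args.dataloader_drop_last,
    #             )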

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
        # Deprecated code
        if self.args.use_legacy_prediction_loop:
            if is_torch_tpu_available():
                return SequentialDistributedSampler(
                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
                )
            elif is_sagemaker_mp_enabled():
                return SequentialDistributedSampler(
                    eval_dataset,
                    num_replicas=smp.dp_size(),
                    rank=smp.dp_rank(),
                    batch_size=self.args.per_device_eval_batch_size,
                )
            elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                return SequentialDistributedSampler(eval_dataset)
            else:
                return SequentialSampler(eval_dataset)

        if self.args.world_size <= 1:
            return SequentialSampler(eval_dataset)
        else:
            return ShardSampler(
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
            )

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                eval_dataset = IterableDatasetShard(
                    eval_dataset,
                    batch_size=self.args.per_device_eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                eval_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`, *optional*):
                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description="test")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")

        if isinstance(test_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                test_dataset = IterableDatasetShard(
                    test_dataset,
                    batch_size=self.args.eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                test_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
        `create_scheduler`).
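
        Example (an illustrative sketch of the `optimizers` alternative; `model` and `training_args` are assumed to
        already exist):

        ```python
        from torch.optim import SGD
        from torch.optim.lr_scheduler import LambdaLR

        optimizer = SGD(model.parameters(), lr=1e-3)
        scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 0.95**step)
        trainer = Trainer(model=model, args=training_args, optimizers=(optimizer, scheduler))
        ```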
        """
        self.create_optimizer()
        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
            # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
            optimizer = self.optimizer.optimizer
        else:
            optimizer = self.optimizer
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
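            # Two parameter groups: weight decay applies to everything except bias and LayerNorm weights
            # (the names collected in `decay_parameters` above); the second group uses weight_decay=0.0.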
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    skipped = 0
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                            logger.info(f"skipped {module}: {skipped/2**20}M params")
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                    logger.info(f"skipped: {skipped/2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer

    @staticmethod
    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
        """
        Returns the optimizer class and optimizer parameters based on the training arguments.

        Args:
            args (`transformers.training_args.TrainingArguments`):
                The training arguments for the training session.
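
        Example (illustrative; the argument values below are made up, not defaults):

        ```python
        from transformers import Trainer, TrainingArguments

        args = TrainingArguments(output_dir="out", optim="adamw_torch", learning_rate=2e-5)
        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
        # optimizer_cls is `torch.optim.AdamW`; optimizer_kwargs contains `lr`, `betas` and `eps`
        ```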

        """

        # parse args.optim_args
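        # (illustrative: an `optim_args` string like "momentum_dtype=bfloat16,use_kahan_summation=True" becomes
        #  {"momentum_dtype": "bfloat16", "use_kahan_summation": "True"}; values are kept as strings here)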
        optim_args = {}
        if args.optim_args:
            for mapping in args.optim_args.replace(" ", "").split(","):
                key, value = mapping.split("=")
                optim_args[key] = value

        optimizer_kwargs = {"lr": args.learning_rate}

        adam_kwargs = {
            "betas": (args.adam_beta1, args.adam_beta2),
            "eps": args.adam_epsilon,
        }
        if args.optim == OptimizerNames.ADAFACTOR:
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif args.optim == OptimizerNames.ADAMW_HF:
            from .optimization import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
            from torch.optim import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
            if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
                optimizer_kwargs.update({"fused": True})
        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
            try:
                from torch_xla.amp.syncfree import AdamW

                optimizer_cls = AdamW
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
            try:
                from apex.optimizers import FusedAdam

                optimizer_cls = FusedAdam
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
        elif args.optim in [
            OptimizerNames.ADAMW_BNB,
            OptimizerNames.ADAMW_8BIT,
            OptimizerNames.PAGED_ADAMW,
            OptimizerNames.PAGED_ADAMW_8BIT,
            OptimizerNames.LION,
            OptimizerNames.LION_8BIT,
            OptimizerNames.PAGED_LION,
            OptimizerNames.PAGED_LION_8BIT,
        ]:
            try:
                from bitsandbytes.optim import AdamW, Lion

                is_paged = False
                optim_bits = 32
                optimizer_cls = None
                additional_optim_kwargs = adam_kwargs
                if "paged" in args.optim:
                    is_paged = True
                if "8bit" in args.optim:
                    optim_bits = 8
                if "adam" in args.optim:
                    optimizer_cls = AdamW
                elif "lion" in args.optim:
                    optimizer_cls = Lion
                    additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}

                bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
                optimizer_kwargs.update(additional_optim_kwargs)
                optimizer_kwargs.update(bnb_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!")
        elif args.optim == OptimizerNames.ADAMW_BNB:
            try:
                from bitsandbytes.optim import Adam8bit

                optimizer_cls = Adam8bit
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
        elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
            try:
                from torchdistx.optimizers import AnyPrecisionAdamW

                optimizer_cls = AnyPrecisionAdamW
                optimizer_kwargs.update(adam_kwargs)

                # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
                optimizer_kwargs.update(
                    {
                        "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
                        "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
                        "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
                        "compensation_buffer_dtype": getattr(
                            torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
                        ),
                    }
                )
            except ImportError:
                raise ValueError("Please install https://github.com/pytorch/torchdistx")
        elif args.optim == OptimizerNames.SGD:
            optimizer_cls = torch.optim.SGD
        elif args.optim == OptimizerNames.ADAGRAD:
            optimizer_cls = torch.optim.Adagrad
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
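
        Example (illustrative arithmetic, not a default): with `args.warmup_ratio=0.1` and no explicit
        `args.warmup_steps`, a run of `num_training_steps=1000` gets `args.get_warmup_steps(1000) == 100`
        warmup steps before the scheduled decay starts.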
        """
        if self.lr_scheduler is None:
            self.lr_scheduler = get_scheduler(
                self.args.lr_scheduler_type,
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )
        return self.lr_scheduler

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
        dataloader.dataset does not exist or has no length, estimates as best it can.
        """
        try:
            dataset = dataloader.dataset
            # Special case for IterableDatasetShard, we need to dig deeper
            if isinstance(dataset, IterableDatasetShard):
                return len(dataloader.dataset.dataset)
            return len(dataloader.dataset)
        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
            return len(dataloader) * self.args.per_device_train_batch_size

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """HP search setup code"""
        self._trial = trial

        if self.hp_search_backend is None or trial is None:
            return
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            params = self.hp_space(trial)
        elif self.hp_search_backend == HPSearchBackend.RAY:
            params = trial
            params.pop("wandb", None)
        elif self.hp_search_backend == HPSearchBackend.SIGOPT:
            params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
        elif self.hp_search_backend == HPSearchBackend.WANDB:
            params = trial

        for key, value in params.items():
            if not hasattr(self.args, key):
                logger.warning(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
                    " `TrainingArguments`."
                )
                continue
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info(f"Trial: {trial.params}")
        if self.hp_search_backend == HPSearchBackend.SIGOPT:
            logger.info(f"SigOpt Assignments: {trial.assignments}")
        if self.hp_search_backend == HPSearchBackend.WANDB:
            logger.info(f"W&B Sweep parameters: {trial}")
        if self.is_deepspeed_enabled:
            if self.args.deepspeed is None:
                raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
            # Rebuild the deepspeed config to reflect the updated training parameters
            from accelerate.utils import DeepSpeedPlugin

            from transformers.deepspeed import HfTrainerDeepSpeedConfig

            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
            self.args.hf_deepspeed_config.trainer_config_process(self.args)
            self.accelerator.state.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)

    def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
        if self.hp_search_backend is None or trial is None:
            return
        self.objective = self.compute_objective(metrics.copy())
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            import optuna

            trial.report(self.objective, step)
            if trial.should_prune():
                self.callback_handler.on_train_end(self.args, self.state, self.control)
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            from ray import tune

            if self.control.should_save:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

    def _tune_save_checkpoint(self):
        from ray import tune

        if not self.use_tune_checkpoints:
            return
        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
            self.save_model(output_dir, _internal_call=True)
            if self.args.should_save:
                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

    def call_model_init(self, trial=None):
        model_init_argcount = number_of_arguments(self.model_init)
        if model_init_argcount == 0:
            model = self.model_init()
        elif model_init_argcount == 1:
            model = self.model_init(trial)
        else:
            raise RuntimeError("model_init should have 0 or 1 argument.")

        if model is None:
            raise RuntimeError("model_init should not return None.")

        return model

    def torch_jit_model_eval(self, model, dataloader, training=False):
        if not training:
            if dataloader is None:
                logger.warning("failed to use PyTorch jit mode because the current dataloader is None.")
                return model
            example_batch = next(iter(dataloader))
            example_batch = self._prepare_inputs(example_batch)
            try:
                jit_model = model.eval()
                with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]):
                    if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"):
                        if isinstance(example_batch, dict):
                            jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
                        else:
                            jit_model = torch.jit.trace(
                                jit_model,
                                example_kwarg_inputs={key: example_batch[key] for key in example_batch},
                                strict=False,
                            )
                    else:
                        jit_inputs = []
                        for key in example_batch:
                            example_tensor = torch.ones_like(example_batch[key])
                            jit_inputs.append(example_tensor)
                        jit_inputs = tuple(jit_inputs)
                        jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
                jit_model = torch.jit.freeze(jit_model)
                with torch.no_grad():
                    jit_model(**example_batch)
                    jit_model(**example_batch)
                model = jit_model
                self.use_cpu_amp = False
                self.use_cuda_amp = False
            except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
                logger.warning(f"failed to use PyTorch jit mode due to: {e}.")

        return model

    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
        if not is_ipex_available():
            raise ImportError(
                "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
                " to https://github.com/intel/intel-extension-for-pytorch."
            )

        import intel_extension_for_pytorch as ipex

        if not training:
            model.eval()
            dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype
            # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings
            model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
        else:
            if not model.training:
                model.train()
            model, self.optimizer = ipex.optimize(
                model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
            )

        return model

    def _wrap_model(self, model, training=True, dataloader=None):
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP
        if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
            model = nn.DataParallel(model)

        if self.args.jit_mode_eval:
            start_time = time.time()
            model = self.torch_jit_model_eval(model, dataloader, training)
            self.jit_compilation_time = round(time.time() - start_time, 4)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)
        # Distributed training using PyTorch FSDP
        elif self.fsdp is not None and self.args.fsdp_config["xla"]:
            try:
                from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                from torch_xla.distributed.fsdp import checkpoint_module
                from torch_xla.distributed.fsdp.wrap import (
                    size_based_auto_wrap_policy,
                    transformer_auto_wrap_policy,
                )
            except ImportError:
                raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
            auto_wrap_policy = None
            auto_wrapper_callable = None
            if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                auto_wrap_policy = functools.partial(
                    size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                )
            elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                transformer_cls_to_wrap = set()
                for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                    transformer_cls = get_module_class_from_name(model, layer_class)
                    if transformer_cls is None:
                        raise Exception("Could not find the transformer layer class to wrap in the model.")
                    else:
                        transformer_cls_to_wrap.add(transformer_cls)
                auto_wrap_policy = functools.partial(
                    transformer_auto_wrap_policy,
                    # Transformer layer class to wrap
                    transformer_layer_cls=transformer_cls_to_wrap,
                )
            fsdp_kwargs = self.args.xla_fsdp_config
            if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                # Apply gradient checkpointing to auto-wrapped sub-modules if specified
                def auto_wrapper_callable(m, *args, **kwargs):
                    return FSDP(checkpoint_module(m), *args, **kwargs)

            # Wrap the base model with an outer FSDP wrapper
            self.model = model = FSDP(
                model,
                auto_wrap_policy=auto_wrap_policy,
                auto_wrapper_callable=auto_wrapper_callable,
                **fsdp_kwargs,
            )

            # Patch `xm.optimizer_step` so it does not reduce gradients in this case,
            # as FSDP does not need gradient reduction over sharded parameters.
            def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
                loss = optimizer.step(**optimizer_args)
                if barrier:
                    xm.mark_step()
                return loss

            xm.optimizer_step = patched_optimizer_step
        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            if is_torch_neuroncore_available():
                return model
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb

            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

        return model

    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Main training entry point.

        Args:
            resume_from_checkpoint (`str` or `bool`, *optional*):
                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
            ignore_keys_for_eval (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions for evaluation during the training.
            kwargs:
                Additional keyword arguments used to hide deprecated arguments.
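
        Example (illustrative sketch; `trainer` is assumed to be an already constructed [`Trainer`]):

        ```python
        # Train, resuming from the last checkpoint in `args.output_dir` if one exists.
        train_result = trainer.train(resume_from_checkpoint=True)
        print(train_result.metrics["train_loss"])
        ```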
        """
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        args = self.args

        self.is_in_train = True

        # do_train is not a reliable argument, as it might not be set and .train() still called, so
        # the following is a workaround:
        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
            self._move_model_to_device(self.model, args.device)

        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
                "instead.",
                FutureWarning,
            )
        if len(kwargs) > 0:
            raise TypeError(f"train() got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)
        self._train_batch_size = self.args.train_batch_size

        # Model re-init
        model_reloaded = False
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
            self.model = self.call_model_init(trial)
            model_reloaded = True
            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Load potential model checkpoint
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:
            self._load_from_checkpoint(resume_from_checkpoint)

        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
            if self.place_model_on_device:
                self._move_model_to_device(self.model, args.device)
            self.model_wrapped = self.model

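        # `find_executable_batch_size` wraps the inner loop so that, when `args.auto_find_batch_size` is enabled,
        # training is retried with a smaller batch size if the initial one runs out of memory.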
        inner_training_loop = find_executable_batch_size(
            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
        )
        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        self._train_batch_size = batch_size
        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
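        # (illustrative example: train_batch_size=8, gradient_accumulation_steps=4 and world_size=2
        #  give total_train_batch_size=64, i.e. 64 samples contribute to each optimizer update)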
        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size

        len_dataloader = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
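            # (illustrative example: a dataloader of 1,000 batches with gradient_accumulation_steps=4
            #  yields num_update_steps_per_epoch=250)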
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            num_examples = self.num_examples(train_dataloader)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
                )
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
            else:
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            max_steps = args.max_steps
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
            )

        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps and args.logging_steps < 1:
            args.logging_steps = math.ceil(max_steps * args.logging_steps)
        if args.eval_steps and args.eval_steps < 1:
            args.eval_steps = math.ceil(max_steps * args.eval_steps)
        if args.save_steps and args.save_steps < 1:
            args.save_steps = math.ceil(max_steps * args.save_steps)

        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torch.distributed.launch)."
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

        delay_optimizer_creation = (
            self.sharded_ddp is not None
            and self.sharded_ddp != ShardedDDPOption.SIMPLE
            or is_sagemaker_mp_enabled()
            or self.fsdp is not None
        )

        if self.is_deepspeed_enabled:
            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

        if not delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None

        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        model = self._wrap_model(self.model_wrapped)

        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint, model)

        # as the model is wrapped, don't use `accelerator.prepare`
        # this is for unhandled cases such as
        # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
        use_accelerator_prepare = True if model is self.model else False

        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        # prepare using `accelerator` prepare
        if use_accelerator_prepare:
            model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                self.model, self.optimizer, self.lr_scheduler
            )

        if self.is_fsdp_enabled:
            self.model = model

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

        # backward compatibility
        if self.is_deepspeed_enabled:
            self.deepspeed = self.model_wrapped

        # deepspeed ckpt loading
        if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
            deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

        # important: at this point:
        # self.model         is the Transformers Model
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.

        # Train!
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples:,}")
        logger.info(f"  Num Epochs = {num_train_epochs:,}")
        logger.info(f"  Instantaneous batch size per device = {self._train_batch_size:,}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps:,}")
        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        steps_trained_progress_bar = None

        # Check if continuing training from a checkpoint
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                if skip_first_batches is None:
                    logger.info(
                        f"  Will skip the first {epochs_trained} epochs then the first"
                        f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
                        " you can install the latest version of Accelerate with `pip install -U accelerate`. You can"
                        " also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
                        " training on data already seen by your model."
                    )
                else:
                    logger.info(
                        f"  Will skip the first {epochs_trained} epochs then the first"
                        f" {steps_trained_in_current_epoch} batches in the first epoch."
                    )
                if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
                    steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                    steps_trained_progress_bar.set_description("Skipping the first batches")

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        if self.hp_name is not None and self._trial is not None:
            # use self._trial because the SigOpt/Optuna HPO only calls `_hp_search_setup(trial)` instead of passing
            # the trial parameter to `train` when using DDP.
            self.state.trial_name = self.hp_name(self._trial)
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
        # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()

        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not args.ignore_data_skip:
            for epoch in range(epochs_trained):
                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
                    train_dataloader.sampler, RandomSampler
                )
                if is_torch_less_than_1_11 or not is_random_sampler:
                    # We just need to begin an iteration to create the randomization of the sampler.
                    # That was before PyTorch 1.11 however...
                    for _ in train_dataloader:
                        break
                else:
                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
                    # AT THE VERY END!
                    _ = list(train_dataloader.sampler)

        total_batched_samples = 0
        for epoch in range(epochs_trained, num_train_epochs):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
            elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
                train_dataloader.dataset.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if args.past_index >= 0:
                self._past = None

            steps_in_epoch = (
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

            rng_to_sync = False
            steps_skipped = 0
            if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
                steps_skipped = steps_trained_in_current_epoch
                steps_trained_in_current_epoch = 0
                rng_to_sync = True

            step = -1
            for step, inputs in enumerate(epoch_iterator):
                total_batched_samples += 1
                if rng_to_sync:
                    self._load_rng_state(resume_from_checkpoint)
                    rng_to_sync = False

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
                    continue
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None

                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                with self.accelerator.accumulate(model):
                    tr_loss_step = self.training_step(model, inputs)

                if (
                    args.logging_nan_inf_filter
                    and not is_torch_tpu_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                else:
                    tr_loss += tr_loss_step

                self.current_flos += float(self.floating_point_ops(inputs))

                # should this be under the accumulate context manager?
                # the `or` condition of `steps_in_epoch <= args.gradient_accumulation_steps` is not covered
                # in accelerate
                if total_batched_samples % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    steps_in_epoch <= args.gradient_accumulation_steps
                    and (step + 1) == steps_in_epoch
                ):
                    # Gradient clipping
                    if args.max_grad_norm is not None and args.max_grad_norm > 0:
                        # deepspeed does its own clipping

                        if self.do_grad_scaling:
                            # Reduce gradients first for XLA
                            if is_torch_tpu_available():
                                gradients = xm._fetch_gradients(self.optimizer)
                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)

                        if is_sagemaker_mp_enabled() and args.fp16:
                            self.optimizer.clip_master_grads(args.max_grad_norm)
                        elif hasattr(self.optimizer, "clip_grad_norm"):
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
                            model.clip_grad_norm_(args.max_grad_norm)
                        elif self.use_apex:
                            # Revert to normal clipping otherwise, handling Apex or full precision
                            nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer),
                                args.max_grad_norm,
                            )
                        else:
                            self.accelerator.clip_grad_norm_(
                                model.parameters(),
                                args.max_grad_norm,
                            )

                    # Optimizer step
                    optimizer_was_run = True
                    if is_torch_tpu_available():
                        if self.do_grad_scaling:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            xm.optimizer_step(self.optimizer)
                    elif self.do_grad_scaling:
                        scale_before = self.scaler.get_scale()
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        scale_after = self.scaler.get_scale()
                        optimizer_was_run = scale_before <= scale_after
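                        # (note: `GradScaler` skips the underlying `optimizer.step()` and lowers the scale when
                        #  inf/nan gradients are found, so a decreased scale here means the step was skipped)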
                    else:
                        self.optimizer.step()

                    if optimizer_was_run:
                        # Delay optimizer scheduling until metrics are generated
                        if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                            self.lr_scheduler.step()
                    model.zero_grad()
                    self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break
            if step < 0:
                logger.warning(
                    "There seem to be no samples in your epoch_iterator, stopping training at step"
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True
2040

2041
            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
2042
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
2043

2044
            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
Sylvain Gugger's avatar
Sylvain Gugger committed
2053
            if self.control.should_training_stop:
2054
                break
Julien Chaumond's avatar
Julien Chaumond committed
2055

2056
        if args.past_index and hasattr(self, "_past"):
2057
2058
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
2061
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
2065
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
2066
                dist.barrier()
2067
2068
            elif is_sagemaker_mp_enabled():
                smp.barrier()
2069

2070
            self._load_best_model()
2071

        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        train_loss = self._total_loss_scalar / self.state.global_step

2076
        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
2077
2078
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
2079
        metrics["train_loss"] = train_loss
Sylvain Gugger's avatar
Sylvain Gugger committed
2080

2081
        self.is_in_train = False
2082

2083
2084
        self._memory_tracker.stop_and_update_metrics(metrics)

2085
2086
        self.log(metrics)

        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

2090
2091
        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
            for checkpoint in checkpoints_sorted:
                if checkpoint != self.state.best_model_checkpoint:
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint)

2097
2098
2099
        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        return TrainOutput(self.state.global_step, train_loss, metrics)
2100

    def _get_output_dir(self, trial):
        if self.hp_search_backend is not None and trial is not None:
            if self.hp_search_backend == HPSearchBackend.OPTUNA:
                run_id = trial.number
            elif self.hp_search_backend == HPSearchBackend.RAY:
                from ray import tune

                run_id = tune.get_trial_id()
            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
                run_id = trial.id
            elif self.hp_search_backend == HPSearchBackend.WANDB:
                import wandb

                run_id = wandb.run.id
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            run_dir = os.path.join(self.args.output_dir, run_name)
        else:
            run_dir = self.args.output_dir
        return run_dir
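    # Illustrative layout only: with the optuna backend and no custom `hp_name`, checkpoints for trial 7 would end up
    # under `<output_dir>/run-7/checkpoint-<global_step>` (the `checkpoint-*` folder is created later by
    # `_save_checkpoint`).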

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        if model is None:
            model = self.model

        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)

        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
        safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
        safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)

        if not any(
Qingyang Wu's avatar
Qingyang Wu committed
2133
            os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file]
2134
        ):
2135
2136
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

2137
        logger.info(f"Loading model from {resume_from_checkpoint}.")
2138

2139
2140
        if os.path.isfile(config_file):
            config = PretrainedConfig.from_json_file(config_file)
            checkpoint_version = config.transformers_version
            if checkpoint_version is not None and checkpoint_version != __version__:
                logger.warning(
                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                    f"Transformers but your current version is {__version__}. This is not recommended and could "
                    "yield to errors or unwanted behaviors."
                )

2149
        if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
2150
            # If the model is on the GPU, it still works!
            if is_sagemaker_mp_enabled():
                if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
                    # If the 'user_content.pt' file exists, load with the new smp api.
                    # Checkpoint must have been saved with the new smp api.
                    smp.resume_from_checkpoint(
                        path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
                    )
                else:
                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                    # Checkpoint must have been saved with the old smp api.
                    if hasattr(self.args, "fp16") and self.args.fp16 is True:
                        logger.warning(
                            "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
                        )
2165
                    state_dict = torch.load(weights_file, map_location="cpu")
                    # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                    state_dict["_smp_is_partial"] = False
                    load_result = model.load_state_dict(state_dict, strict=True)
                    # release memory
                    del state_dict
2171
2172
            elif self.is_fsdp_enabled:
                self.accelerator.state.fsdp_plugin.load_model(self.accelerator, model, resume_from_checkpoint)
2173
2174
            else:
                # We load the model state dict on the CPU to avoid an OOM error.
                if self.args.save_safetensors and os.path.isfile(safe_weights_file):
                    state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
                else:
                    state_dict = torch.load(weights_file, map_location="cpu")

2180
2181
2182
                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                # which takes *args instead of **kwargs
                load_result = model.load_state_dict(state_dict, False)
2183
2184
                # release memory
                del state_dict
2185
                self._issue_warnings_after_load(load_result)
2186
2187
        else:
            # We load the sharded checkpoint
2188
2189
2190
            load_result = load_sharded_checkpoint(
                model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
            )
2191
            if not is_sagemaker_mp_enabled():
2192
                self._issue_warnings_after_load(load_result)

    def _load_best_model(self):
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
2197
        best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
2198
        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
2199
        if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
2200
2201
            if self.is_deepspeed_enabled:
                deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
2202
            else:
                if is_sagemaker_mp_enabled():
                    if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                        # If the 'user_content.pt' file exists, load with the new smp api.
                        # Checkpoint must have been saved with the new smp api.
                        smp.resume_from_checkpoint(
                            path=self.state.best_model_checkpoint,
                            tag=WEIGHTS_NAME,
                            partial=False,
                            load_optimizer=False,
                        )
                    else:
                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                        # Checkpoint must have been saved with the old smp api.
                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                        else:
                            state_dict = torch.load(best_model_path, map_location="cpu")

2221
2222
                        state_dict["_smp_is_partial"] = False
                        load_result = model.load_state_dict(state_dict, strict=True)
                elif self.is_fsdp_enabled:
                    self.accelerator.state.fsdp_plugin.load_model(
                        self.accelerator, model, self.state.best_model_checkpoint
                    )
2227
                else:
                    if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
                        # If training a base 8-bit model with PEFT & LoRA, assume the adapter has been saved properly.
                        if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                            if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
                                model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
                                # load_adapter has no return value, so build an empty load_result; update this if that changes.
                                from torch.nn.modules.module import _IncompatibleKeys

                                load_result = _IncompatibleKeys([], [])
                            else:
                                logger.warning(
                                    "The intermediate checkpoints of PEFT may not be saved correctly, "
                                    "using `TrainerCallback` to save adapter_model.bin in corresponding folders, "
                                    "here are some examples https://github.com/huggingface/peft/issues/96"
                                )
                        else:
                            # We can't do pure 8bit training using transformers.
                            logger.warning("Could not loading a quantized checkpoint.")
2246
                    else:
                        # We load the model state dict on the CPU to avoid an OOM error.
                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                        else:
                            state_dict = torch.load(best_model_path, map_location="cpu")
2252

                        # If the model is on the GPU, it still works!
                        # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                        # which takes *args instead of **kwargs
                        load_result = model.load_state_dict(state_dict, False)
2257
                if not is_sagemaker_mp_enabled():
2258
                    self._issue_warnings_after_load(load_result)
2259
        elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
            load_result = load_sharded_checkpoint(
                model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
            )
            if not is_sagemaker_mp_enabled():
2264
                self._issue_warnings_after_load(load_result)
        else:
            logger.warning(
                f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
                "on multiple nodes, you should activate `--save_on_each_node`."
            )

2271
    def _issue_warnings_after_load(self, load_result):
2272
        if len(load_result.missing_keys) != 0:
2273
2274
2275
            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                self.model._keys_to_ignore_on_save
            ):
2276
2277
                self.model.tie_weights()
            else:
2278
                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
2279
        if len(load_result.unexpected_keys) != 0:
2280
2281
2282
            logger.warning(
                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
            )
2283

2284
    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
Sylvain Gugger's avatar
Sylvain Gugger committed
2285
        if self.control.should_log:
2286
2287
2288
            if is_torch_tpu_available():
                xm.mark_step()

Sylvain Gugger's avatar
Sylvain Gugger committed
2289
            logs: Dict[str, float] = {}

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

2294
2295
2296
            # reset tr_loss to zero
            tr_loss -= tr_loss

2297
            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
2298
            logs["learning_rate"] = self._get_learning_rate()
2299

2300
            self._total_loss_scalar += tr_loss_scalar
2301
            self._globalstep_last_logged = self.state.global_step
Teven's avatar
Teven committed
2302
            self.store_flos()

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
2308
            if isinstance(self.eval_dataset, dict):
2309
                metrics = {}
2310
                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
2311
                    dataset_metrics = self.evaluate(
                        eval_dataset=eval_dataset,
                        ignore_keys=ignore_keys_for_eval,
                        metric_key_prefix=f"eval_{eval_dataset_name}",
                    )
2316
                    metrics.update(dataset_metrics)
2317
2318
            else:
                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
2319
            self._report_to_hp_search(trial, self.state.global_step, metrics)
2320

            # Run delayed LR scheduler now that metrics are populated
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                self.lr_scheduler.step(metrics[self.args.metric_for_best_model])

        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def _load_rng_state(self, checkpoint):
        # Load RNG states from `checkpoint`
        if checkpoint is None:
            return

2334
2335
2336
        if self.args.world_size > 1:
            process_index = self.args.process_index
            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
2337
            if not os.path.isfile(rng_file):
2338
                logger.info(
2339
                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
                )
                return
        else:
            rng_file = os.path.join(checkpoint, "rng_state.pth")
2345
            if not os.path.isfile(rng_file):
                logger.info(
                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
                    "fashion, reproducibility is not guaranteed."
                )
                return

        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        if torch.cuda.is_available():
2357
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
2358
                torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
2359
            else:
2360
                try:
2361
                    torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
2362
                except Exception as e:
2363
                    logger.info(
2364
2365
2366
                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
2367
2368
2369
        if is_torch_tpu_available():
            xm.set_rng_state(checkpoint_rng_state["xla"])

Sylvain Gugger's avatar
Sylvain Gugger committed
2370
    def _save_checkpoint(self, model, trial, metrics=None):
2371
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
2372
        # want to save except FullyShardedDDP.
2373
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
2374

2375
        # Save model checkpoint
2376
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
2377

raghavanone's avatar
raghavanone committed
2378
        if self.hp_search_backend is None and trial is None:
2379
            self.store_flos()
2380

raghavanone's avatar
raghavanone committed
2381
        run_dir = self._get_output_dir(trial=trial)
2382
        output_dir = os.path.join(run_dir, checkpoint_folder)
Sylvain Gugger's avatar
Sylvain Gugger committed
2383
        self.save_model(output_dir, _internal_call=True)
2384
        if self.is_deepspeed_enabled:
2385
            # under zero3, the model file itself doesn't get saved since it's bogus, unless the deepspeed
            # config `stage3_gather_16bit_weights_on_model_save` is True
            self.model_wrapped.save_checkpoint(output_dir)
2388
2389

        # Save optimizer and scheduler
2390
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
2391
            self.optimizer.consolidate_state_dict()
2392

        if self.fsdp:
            # FSDP has a different interface for saving optimizer states.
            # Needs to be called on all ranks to gather all states.
            # full_optim_state_dict will be deprecated after Pytorch 2.2!
            full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)

2399
2400
        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
2401
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
2402
            with warnings.catch_warnings(record=True) as caught_warnings:
2403
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
2404
                reissue_pt_warnings(caught_warnings)
Sylvain Gugger's avatar
Sylvain Gugger committed
2405
        elif is_sagemaker_mp_enabled():
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
            if self.args.should_save:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling:
                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
2421
        elif self.args.should_save and not self.is_deepspeed_enabled:
2422
            # deepspeed.save_checkpoint above saves model/optim/sched
            if self.fsdp:
                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
            else:
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

2428
            with warnings.catch_warnings(record=True) as caught_warnings:
2429
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
2430
            reissue_pt_warnings(caught_warnings)
2431
            if self.do_grad_scaling:
2432
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
2433
2434

        # Determine the new best metric / best model checkpoint
Sylvain Gugger's avatar
Sylvain Gugger committed
2435
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
2451
        if self.args.should_save:
2452
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
2453

        # Save RNG state in non-distributed training
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
2461
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                rng_states["cuda"] = torch.cuda.random.get_rng_state()

        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()

2470
2471
2472
        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)
2473

2474
        if self.args.world_size <= 1:
2475
2476
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
2477
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
2478

2479
2480
2481
        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

2482
        # Maybe delete some older checkpoints.
2483
        if self.args.should_save:
2484
2485
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

2486
    def _load_optimizer_and_scheduler(self, checkpoint):
Sylvain Gugger's avatar
Sylvain Gugger committed
2487
        """If optimizer and scheduler states exist, load them."""
2488
        if checkpoint is None:
2489
2490
            return

2491
        if self.is_deepspeed_enabled:
2492
            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
2493
2494
            return

        checkpoint_file_exists = (
            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
            if is_sagemaker_mp_enabled()
            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
        )
        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
            # Load in optimizer and scheduler states
            if is_torch_tpu_available():
                # On TPU we have to take some extra precautions to properly load the states on the right device.
2504
                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
Sylvain Gugger's avatar
Sylvain Gugger committed
2505
                with warnings.catch_warnings(record=True) as caught_warnings:
2506
                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
                reissue_pt_warnings(caught_warnings)

                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

                self.optimizer.load_state_dict(optimizer_state)
                self.lr_scheduler.load_state_dict(lr_scheduler_state)
            else:
2515
                if is_sagemaker_mp_enabled():
                    if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
                        # Optimizer checkpoint was saved with smp >= 1.10
                        def opt_load_hook(mod, opt):
                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
2520

                    else:
                        # Optimizer checkpoint was saved with smp < 1.10
                        def opt_load_hook(mod, opt):
                            if IS_SAGEMAKER_MP_POST_1_10:
                                opt.load_state_dict(
                                    smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
                                )
                            else:
                                opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
2530
2531
2532

                    self.model_wrapped.register_post_step_hook(opt_load_hook)
                else:
                    # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
                    # In distributed training, however, we load directly on each GPU and risk GPU OOM, since a CPU
                    # OOM would be even more likely there (the optimizer state would be loaded num_gpu times on CPU).
                    map_location = self.args.device if self.args.world_size > 1 else "cpu"
                    if self.fsdp:
                        full_osd = None
                        # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
                        if self.args.process_index == 0:
                            full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))
                        # call scatter_full_optim_state_dict on all ranks
                        sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)
                        self.optimizer.load_state_dict(sharded_osd)
                    else:
                        self.optimizer.load_state_dict(
                            torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
                        )
Sylvain Gugger's avatar
Sylvain Gugger committed
2549
                with warnings.catch_warnings(record=True) as caught_warnings:
2550
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
Sylvain Gugger's avatar
Sylvain Gugger committed
2551
                reissue_pt_warnings(caught_warnings)
2552
                if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
2553
                    self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
Sylvain Gugger's avatar
Sylvain Gugger committed
2554

    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
2562
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
2563
        **kwargs,
2564
2565
    ) -> BestRun:
        """
        Launch a hyperparameter search using `optuna`, `Ray Tune`, or `SigOpt`. The optimized quantity is determined
        by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
        and the sum of all metrics otherwise.

2570
        <Tip warning={true}>
Sylvain Gugger's avatar
Sylvain Gugger committed
2571

        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
        optimizer/scheduler.
2576
2577

        </Tip>
Sylvain Gugger's avatar
Sylvain Gugger committed
2578

2579
        Args:
2580
            hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
2581
                A function that defines the hyperparameter search space. Will default to
Sylvain Gugger's avatar
Sylvain Gugger committed
2582
                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
2583
2584
                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
                method. Will default to [`~trainer_utils.default_compute_objective`].
            n_trials (`int`, *optional*, defaults to 20):
2588
                The number of trial runs to test.
2589
            direction (`str`, *optional*, defaults to `"minimize"`):
                Whether to optimize for a greater or lower objective value. Can be `"minimize"` or `"maximize"`; pick
                `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
2592
            backend (`str` or [`~trainer_utils.HPSearchBackend`], *optional*):
2593
2594
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
                on which one is installed. If all are installed, will default to optuna.
2595
2596
2597
            hp_name (`Callable[["optuna.Trial"], str]]`, *optional*):
                A function that defines the trial/run name. Will default to None.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more
                information see:
2600

                - the documentation of
                  [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
2603
2604
                - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
                - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)
2605
2606

        Returns:
2607
2608
            [`trainer_utils.BestRun`]: All the information about the best run. For the Ray backend, the experiment
            summary can be found in the `run_summary` attribute of the returned object.
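
        Example (a minimal sketch: it assumes `optuna` is installed, that `trainer` was created with a `model_init`,
        and the search space below is purely illustrative):

        ```python
        def my_hp_space(trial):
            return {
                "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
                "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32]),
            }


        best_run = trainer.hyperparameter_search(hp_space=my_hp_space, n_trials=10, direction="minimize")
        ```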
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray should be installed. "
2615
2616
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`. "
2617
                    "To install sigopt run `pip install sigopt`."
2618
2619
2620
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
Sylvain Gugger's avatar
Sylvain Gugger committed
2621
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
2622
        if backend == HPSearchBackend.RAY and not is_ray_tune_available():
2623
            raise RuntimeError(
Sylvain Gugger's avatar
Sylvain Gugger committed
2624
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
2625
            )
2626
2627
        if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
            raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
2628
2629
        if backend == HPSearchBackend.WANDB and not is_wandb_available():
            raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
2630
        self.hp_search_backend = backend
        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

2636
        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
2637
        self.hp_name = hp_name
2638
2639
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        backend_dict = {
            HPSearchBackend.OPTUNA: run_hp_search_optuna,
            HPSearchBackend.RAY: run_hp_search_ray,
            HPSearchBackend.SIGOPT: run_hp_search_sigopt,
2644
            HPSearchBackend.WANDB: run_hp_search_wandb,
2645
2646
        }
        best_run = backend_dict[backend](self, n_trials, direction, **kwargs)

        self.hp_search_backend = None
        return best_run

Sylvain Gugger's avatar
Sylvain Gugger committed
2651
    def log(self, logs: Dict[str, float]) -> None:
2652
        """
2653
        Log `logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
2658
            logs (`Dict[str, float]`):
2659
2660
                The values to log.
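
        Example (a sketch of injecting custom behavior in a subclass; `MyTrainer` and the extra key are illustrative):

        ```python
        class MyTrainer(Trainer):
            def log(self, logs):
                logs["my_custom_value"] = 0.0  # anything extra you want reported alongside the usual logs
                super().log(logs)
        ```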
        """
2661
        if self.state.epoch is not None:
2662
            logs["epoch"] = round(self.state.epoch, 2)
2663

2664
2665
        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)
2666
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
Julien Chaumond's avatar
Julien Chaumond committed
2667

2668
2669
    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
2670
        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
2671
        """
2672
2673
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
2674
2675
2676
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
2677
            kwargs = {"device": self.args.device}
2678
            if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
2679
                # NLP models inputs are int/uint and those get adjusted to the right dtype of the
2680
2681
                # embedding. Other models such as wav2vec2's inputs are already float and thus
                # may need special handling to match the dtypes of the model
2682
                kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
2683
2684
2685
            return data.to(**kwargs)
        return data
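    # For instance, an `inputs` mapping like `{"input_ids": tensor, "labels": tensor}` (or any nesting of
    # dicts/lists/tuples of tensors) is recursively moved to `self.args.device` by the helper above.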

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
2687
        """
2688
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
2689
2690
        handling potential state.
        """
2691
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )
2697
2698
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past
2699

2700
2701
        return inputs

    def compute_loss_context_manager(self):
        """
        A helper wrapper to group together context managers.
        """
2706
        return self.autocast_smart_context_manager()
2707

2708
    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
2709
        """
2710
        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
2711
2712
        arguments, depending on the situation.
        """
2713
        if self.use_cuda_amp or self.use_cpu_amp:
2714
            if is_torch_greater_or_equal_than_1_10:
2715
                ctx_manager = (
2716
                    torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
2717
                    if self.use_cpu_amp
2718
                    else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
2719
                )
2720
            else:
2721
                ctx_manager = torch.cuda.amp.autocast()
        else:
            ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()

        return ctx_manager
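    # Typical usage of the helpers above (mirrored by `training_step` below):
    #
    #     with self.compute_loss_context_manager():
    #         loss = self.compute_loss(model, inputs)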

2727
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
2728
        """
2729
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
2734
            model (`nn.Module`):
2735
                The model to train.
2736
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
2737
2738
2739
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
2740
                argument `labels`. Check your model's documentation for all accepted arguments.
2741
2742

        Return:
2743
            `torch.Tensor`: The tensor with training loss on this batch.
2744
2745
        """
        model.train()
2746
        inputs = self._prepare_inputs(inputs)
2747

Sylvain Gugger's avatar
Sylvain Gugger committed
2748
        if is_sagemaker_mp_enabled():
2749
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

2752
        with self.compute_loss_context_manager():
Sylvain Gugger's avatar
Sylvain Gugger committed
2753
            loss = self.compute_loss(model, inputs)
2754

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
2757

2758
        if self.do_grad_scaling:
2759
            self.scaler.scale(loss).backward()
2760
        elif self.use_apex:
2761
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
2764
            self.accelerator.backward(loss)
Julien Chaumond's avatar
Julien Chaumond committed
2765

2766
        return loss.detach() / self.args.gradient_accumulation_steps
Julien Chaumond's avatar
Julien Chaumond committed
2767

2768
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
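
        Example (a sketch of a subclass computing a weighted cross-entropy loss; `MyTrainer` and the class weights
        are illustrative):

        ```python
        import torch
        from torch import nn


        class MyTrainer(Trainer):
            def compute_loss(self, model, inputs, return_outputs=False):
                labels = inputs.pop("labels")
                outputs = model(**inputs)
                logits = outputs.logits
                loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0], device=logits.device))
                loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
                return (loss, outputs) if return_outputs else loss
        ```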
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
2780
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
Sylvain Gugger's avatar
Sylvain Gugger committed
2783

2784
        if labels is not None:
            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
Sylvain Gugger's avatar
Sylvain Gugger committed
2789
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
Sylvain Gugger's avatar
Sylvain Gugger committed
2795
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
2796
2797
2798
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss
Sylvain Gugger's avatar
Sylvain Gugger committed
2799

2800
2801
    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
        machines) main process.
2804
        """
2805
        return self.args.local_process_index == 0
Lysandre Debut's avatar
Lysandre Debut committed
2806

2807
2808
    def is_world_process_zero(self) -> bool:
        """
Sylvain Gugger's avatar
Sylvain Gugger committed
2809
        Whether or not this process is the global main process (when training in a distributed fashion on several
2810
        machines, this is only going to be `True` for one process).
Julien Chaumond's avatar
Julien Chaumond committed
2811
        """
2812
2813
2814
        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
        # process index.
        if is_sagemaker_mp_enabled():
Sylvain Gugger's avatar
Sylvain Gugger committed
2815
            return smp.rank() == 0
Lysandre Debut's avatar
Lysandre Debut committed
2816
        else:
Sylvain Gugger's avatar
Sylvain Gugger committed
2817
            return self.args.process_index == 0
Julien Chaumond's avatar
Julien Chaumond committed
2818

Sylvain Gugger's avatar
Sylvain Gugger committed
2819
    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
Julien Chaumond's avatar
Julien Chaumond committed
2820
        """
2821
        Will save the model, so you can reload it using `from_pretrained()`.
Julien Chaumond's avatar
Julien Chaumond committed
2822

2823
        Will only save from the main process.
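
        Example (a minimal sketch; the path and model class below are placeholders for your own):

        ```python
        trainer.save_model("path/to/saved_model")

        from transformers import AutoModelForSequenceClassification

        model = AutoModelForSequenceClassification.from_pretrained("path/to/saved_model")
        ```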
Julien Chaumond's avatar
Julien Chaumond committed
2824
        """

        if output_dir is None:
            output_dir = self.args.output_dir

2829
        if is_torch_tpu_available():
2830
            self._save_tpu(output_dir)
        elif is_sagemaker_mp_enabled():
            # Calling the state_dict needs to be done on the wrapped model and on all processes.
2833
            os.makedirs(output_dir, exist_ok=True)
Sylvain Gugger's avatar
Sylvain Gugger committed
2834
            state_dict = self.model_wrapped.state_dict()
2835
            if self.args.should_save:
Sylvain Gugger's avatar
Sylvain Gugger committed
2836
                self._save(output_dir, state_dict=state_dict)
2837
2838
2839
            if IS_SAGEMAKER_MP_POST_1_10:
                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
                Path(os.path.join(output_dir, "user_content.pt")).touch()
2840
        elif (
2841
2842
2843
            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
            or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
            or self.fsdp is not None
2844
            or self.is_fsdp_enabled
2845
        ):
2846
            if self.is_fsdp_enabled:
2847
2848
2849
                self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
            else:
                state_dict = self.model.state_dict()
2850

2851
2852
                if self.args.should_save:
                    self._save(output_dir, state_dict=state_dict)
2853
        elif self.is_deepspeed_enabled:
2854
            # this takes care of everything as long as we aren't under zero3
2855
            if self.args.should_save:
                self._save(output_dir)

            if is_deepspeed_zero3_enabled():
                # It's too complicated to try to override different places where the weights dump gets
                # saved, so since under zero3 the file is bogus, simply delete it. The user should
                # either use the deepspeed checkpoint to resume training, or recover the full weights with
                # zero_to_fp32.py stored in the checkpoint.
2863
                if self.args.should_save:
                    file = os.path.join(output_dir, WEIGHTS_NAME)
                    if os.path.isfile(file):
                        # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
                        os.remove(file)

2869
                # now save the real model if stage3_gather_16bit_weights_on_model_save=True
2870
2871
                # if false it will not be saved.
                # This must be called on all ranks
2872
                if not self.model_wrapped.save_16bit_model(output_dir, WEIGHTS_NAME):
2873
                    logger.warning(
                        "deepspeed.save_16bit_model didn't save the model, since"
                        " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                        " zero_to_fp32.py to recover weights"
2877
                    )
2878
                    self.model_wrapped.save_checkpoint(output_dir)
2879

2880
        elif self.args.should_save:
2881
            self._save(output_dir)
Julien Chaumond's avatar
Julien Chaumond committed
2882

        # Push to the Hub when `save_model` is called by the user.
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")

2887
2888
    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
2889
        logger.info(f"Saving model checkpoint to {output_dir}")
2890
2891
2892

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
2893
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
2898
        if not isinstance(self.model, PreTrainedModel):
2899
2900
2901
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                unwrap_model(self.model).save_pretrained(
                    output_dir,
2902
                    is_main_process=self.args.should_save,
2903
2904
2905
                    state_dict=self.model.state_dict(),
                    save_function=xm.save,
                )
2906
2907
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
2908
2909
                state_dict = self.model.state_dict()
                xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
2910
        else:
2911
            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
2912
        if self.tokenizer is not None and self.args.should_save:
2913
            self.tokenizer.save_pretrained(output_dir)
2914

2915
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
2916
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
2919
        logger.info(f"Saving model checkpoint to {output_dir}")
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
2923
2924
2925
            if state_dict is None:
                state_dict = self.model.state_dict()

2926
            if isinstance(unwrap_model(self.model), PreTrainedModel):
2927
2928
2929
                unwrap_model(self.model).save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
2930
2931
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
2936
        else:
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )

2941
        if self.tokenizer is not None:
2942
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
2945
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
2946

2947
    def store_flos(self):
2948
        # Storing the number of floating-point operations that went into the model
2949
        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
2950
2951
2952
            self.state.total_flos += (
                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
            )
2953
2954
            self.current_flos = 0
        else:
Teven's avatar
Teven committed
2955
            self.state.total_flos += self.current_flos
2956
            self.current_flos = 0
Julien Chaumond's avatar
Julien Chaumond committed
2957

2958
2959
2960
    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        ordering_and_checkpoint_path = []

2963
        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
2969
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
2970
                if regex_match is not None and regex_match.groups() is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
2975
2976
        # Make sure we don't delete the best model.
        if self.state.best_model_checkpoint is not None:
2977
            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
2978
2979
            for i in range(best_model_index, len(checkpoints_sorted) - 2):
                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
Julien Chaumond's avatar
Julien Chaumond committed
2980
2981
        return checkpoints_sorted
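        # Illustrative example (not upstream documentation): with checkpoint-500, checkpoint-1500
        # and checkpoint-1000 on disk, the numeric sort above yields
        # ["checkpoint-500", "checkpoint-1000", "checkpoint-1500"] (not a lexicographic order),
        # and if checkpoint-1000 is the best checkpoint it is shuffled towards the end of the list
        # so that checkpoint rotation, which deletes from the front, keeps it.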

    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
        # we don't do to allow resuming.
        save_total_limit = self.args.save_total_limit
        if (
            self.state.best_model_checkpoint is not None
            and self.args.save_total_limit == 1
            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
        ):
            save_total_limit = 2

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
            shutil.rmtree(checkpoint, ignore_errors=True)
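        # Illustrative example (not upstream documentation): with `save_total_limit=2` and
        # checkpoints 500/1000/1500 on disk, checkpoint-500 is deleted here. With
        # `save_total_limit=1` and `load_best_model_at_end=True`, the limit is temporarily raised
        # to 2 above so that both the best checkpoint and the most recent one survive.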

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and return the resulting metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default).

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        start_time = time.time()

        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.log(output.metrics)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics

    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the
                `model.forward()` method are automatically removed. Has to implement the method `__len__`.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "test_bleu" if the prefix is "test" (default).

        <Tip>

        If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
        padding in a token classification task) the predictions will be padded (on the right) to allow for
        concatenation into one array. The padding index is -100.

        </Tip>

        Returns: *NamedTuple* A namedtuple with the following keys:

            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        test_dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()

        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )
        total_batch_size = self.args.eval_batch_size * self.args.world_size
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with and without labels.
        """
        args = self.args

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train, init deepspeed here
        if self.is_deepspeed_enabled and self.model_wrapped is self.model:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
            model = self.accelerator.prepare(self.model)
            self.model_wrapped = self.deepspeed = model

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        inputs_host = None

        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        all_inputs = None
        # Will be useful when we have an iterable dataset so we don't know its length.

        observed_num_examples = 0
        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            if is_torch_tpu_available():
                xm.mark_step()

            # Update containers on host
            if loss is not None:
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if labels is not None:
                labels = self._pad_across_processes(labels)
            if inputs_decode is not None:
                inputs_decode = self._pad_across_processes(inputs_decode)
                inputs_decode = self._nested_gather(inputs_decode)
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            if logits is not None:
                logits = self._pad_across_processes(logits)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                logits = self._nested_gather(logits)
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels = self._nested_gather(labels)
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
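            # For example, with `eval_accumulation_steps=20` the accumulated host tensors are moved
            # to CPU numpy arrays every 20 batches instead of only once at the end, which bounds
            # accelerator memory usage during long evaluations.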
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                if preds_host is not None:
                    logits = nested_numpify(preds_host)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
                if inputs_host is not None:
                    inputs_decode = nested_numpify(inputs_host)
                    all_inputs = (
                        inputs_decode
                        if all_inputs is None
                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                    )
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                    )

                # Set back to None to begin a new accumulation
                losses_host, preds_host, inputs_host, labels_host = None, None, None, None

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
        if inputs_host is not None:
            inputs_decode = nested_numpify(inputs_host)
            all_inputs = (
                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
            )
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples
        if num_samples == 0 and observed_num_examples > 0:
            num_samples = observed_num_examples

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
        # samples has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
        if all_inputs is not None:
            all_inputs = nested_truncate(all_inputs, num_samples)

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
        if hasattr(self, "jit_compilation_time"):
            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)

    def _nested_gather(self, tensors, name=None):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) from all processes and concatenate them.
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            if name is None:
                name = "nested_gather"
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
            self.args.distributed_state is None and self.local_rank != -1
        ):
            tensors = distributed_concat(tensors)
        return tensors

    # Copied from Accelerate.
    def _pad_across_processes(self, tensor, pad_index=-100):
        """
        Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
        they can safely be gathered.
        """
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )

        if len(tensor.shape) < 2:
            return tensor
        # Gather all sizes
        size = torch.tensor(tensor.shape, device=tensor.device)[None]
        sizes = self._nested_gather(size).cpu()

        max_size = max(s[1] for s in sizes)
        # When extracting XLA graphs for compilation, max_size is 0,
        # so use inequality to avoid errors.
        if tensor.shape[1] >= max_size:
            return tensor

        # Then pad to the maximum size
        old_size = tensor.shape
        new_size = list(old_size)
        new_size[1] = max_size
        new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
        new_tensor[:, : old_size[1]] = tensor
        return new_tensor
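        # Illustrative example (not upstream documentation): if this process holds logits of shape
        # (8, 12) while another process holds (8, 7), the smaller tensor is padded with `pad_index`
        # (-100 by default) up to (8, 12) so a cross-process gather can stack them safely.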

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or is `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels or loss_without_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels or loss_without_labels:
                    with self.compute_loss_context_manager():
                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                    loss = loss.mean().detach()

                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        logits = outputs[1:]
                else:
                    loss = None
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point
        operations for every backward + forward pass. If using another model, either implement such a method in the
        model or subclass and override this method.

        Args:
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            `int`: The number of floating-point operations.
        """
        if hasattr(self.model, "floating_point_ops"):
            return self.model.floating_point_ops(inputs)
        else:
            return 0

    def init_git_repo(self, at_init: bool = False):
        """
        Initializes a git repo in `self.args.hub_model_id`.

        Args:
            at_init (`bool`, *optional*, defaults to `False`):
                Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
                `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
                out.
        """
        if not self.is_world_process_zero():
            return
        if self.args.hub_model_id is None:
            repo_name = Path(self.args.output_dir).absolute().name
        else:
            repo_name = self.args.hub_model_id
        if "/" not in repo_name:
            repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)

        # Make sure the repo exists.
        create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)
        try:
            self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)
        except EnvironmentError:
            if self.args.overwrite_output_dir and at_init:
                # Try again after wiping output_dir
                shutil.rmtree(self.args.output_dir)
                self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)
            else:
                raise

        self.repo.git_pull()

        # By default, ignore the checkpoint folders
        if (
            not os.path.exists(os.path.join(self.args.output_dir, ".gitignore"))
            and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS
        ):
            with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
                writer.writelines(["checkpoint-*/"])

        # Add "*.sagemaker" to .gitignore if using SageMaker
        if os.environ.get("SM_TRAINING_ENV"):
            self._add_sm_patterns_to_gitignore()

        self.push_in_progress = None

    def create_model_card(
        self,
        language: Optional[str] = None,
        license: Optional[str] = None,
        tags: Union[str, List[str], None] = None,
        model_name: Optional[str] = None,
        finetuned_from: Optional[str] = None,
        tasks: Union[str, List[str], None] = None,
        dataset_tags: Union[str, List[str], None] = None,
        dataset: Union[str, List[str], None] = None,
        dataset_args: Union[str, List[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            language (`str`, *optional*):
                The language of the model (if applicable)
            license (`str`, *optional*):
                The license of the model. Will default to the license of the pretrained model used, if the original
                model given to the `Trainer` comes from a repo on the Hub.
            tags (`str` or `List[str]`, *optional*):
                Some tags to be included in the metadata of the model card.
            model_name (`str`, *optional*):
                The name of the model.
            finetuned_from (`str`, *optional*):
                The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
                of the original model given to the `Trainer` (if it comes from the Hub).
            tasks (`str` or `List[str]`, *optional*):
                One or several task identifiers, to be included in the metadata of the model card.
            dataset_tags (`str` or `List[str]`, *optional*):
                One or several dataset tags, to be included in the metadata of the model card.
            dataset (`str` or `List[str]`, *optional*):
                One or several dataset identifiers, to be included in the metadata of the model card.
            dataset_args (`str` or `List[str]`, *optional*):
                One or several dataset arguments, to be included in the metadata of the model card.
        """
        if not self.is_world_process_zero():
            return

        training_summary = TrainingSummary.from_trainer(
            self,
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
        )
        model_card = training_summary.to_model_card()
        with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
            f.write(model_card)

    def _push_from_checkpoint(self, checkpoint_folder):
        # Only push from one node.
        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
            return
        # If we haven't finished the last push, we don't do this one.
        if self.push_in_progress is not None and not self.push_in_progress.is_done:
            return

        output_dir = self.args.output_dir
        # To avoid a new synchronization of all model weights, we just copy the files from the checkpoint folder
        modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
        for modeling_file in modeling_files:
            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Same for the training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        try:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Temporarily move the checkpoint just saved for the push
                tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
                # subfolder.
                if os.path.isdir(tmp_checkpoint):
                    shutil.rmtree(tmp_checkpoint)
                shutil.move(checkpoint_folder, tmp_checkpoint)

            if self.args.save_strategy == IntervalStrategy.STEPS:
                commit_message = f"Training in progress, step {self.state.global_step}"
            else:
                commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
            push_work = self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
            # Return type of `Repository.push_to_hub` is either None or a tuple.
            if push_work is not None:
                self.push_in_progress = push_work[1]
        except Exception as e:
            logger.error(f"Error when pushing to hub: {e}")
        finally:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Move back the checkpoint to its place
                shutil.move(tmp_checkpoint, checkpoint_folder)

    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.

        Parameters:
            commit_message (`str`, *optional*, defaults to `"End of training"`):
                Message to commit while pushing.
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the `git push` has finished.
            kwargs:
                Additional keyword arguments passed along to [`~Trainer.create_model_card`].

        Returns:
            The URL of the commit of your model in the given repository if `blocking=True`; if `blocking=False`, a
            tuple with the URL of the commit and an object to track the progress of the push.
        """
        # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
        # it might fail.
        if not hasattr(self, "repo"):
            self.init_git_repo()

        model_name = kwargs.pop("model_name", None)
        if model_name is None and self.args.should_save:
            if self.args.hub_model_id is None:
                model_name = Path(self.args.output_dir).name
            else:
                model_name = self.args.hub_model_id.split("/")[-1]

        # Needs to be executed on all processes for TPU training, but will only save on the process determined by
        # self.args.should_save.
        self.save_model(_internal_call=True)

        # Only push from one node.
        if not self.is_world_process_zero():
            return

        # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
            self.push_in_progress._process.kill()
            self.push_in_progress = None

        git_head_commit_url = self.repo.push_to_hub(
            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
        )
        # push separately the model card to be independent from the rest of the model
        if self.args.should_save:
            self.create_model_card(model_name=model_name, **kwargs)
            try:
                self.repo.push_to_hub(
                    commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
                )
            except EnvironmentError as exc:
                logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")

        return git_head_commit_url

    #
    # Deprecated code
    #

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with and without labels.
        """
        args = self.args

        if not has_length(dataloader):
            raise ValueError("dataloader must implement a working __len__")

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train, init deepspeed here
        if self.is_deepspeed_enabled and self.model_wrapped is self.model:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
            model = self.accelerator.prepare(self.model)
            self.model_wrapped = self.deepspeed = model

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info(f"***** Running {description} *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Batch size = {batch_size}")
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
        inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = max(1, args.world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
            # a batch size to the sampler)
            make_multiple_of = None
            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
                make_multiple_of = dataloader.sampler.batch_size
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            if loss is not None:
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                if not prediction_loss_only:
                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
                    inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host, inputs_host = None, None, None, None

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
            inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
        inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)

    def _gather_and_numpify(self, tensors, name):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)

    def _add_sm_patterns_to_gitignore(self) -> None:
        """Add SageMaker Checkpointing patterns to .gitignore file."""
        # Make sure we only do this on the main process
        if not self.is_world_process_zero():
            return

        patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]

        # Get current .gitignore content
        if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
            with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f:
                current_content = f.read()
        else:
            current_content = ""

        # Add the patterns to .gitignore
        content = current_content
        for pattern in patterns:
            if pattern not in content:
                if content.endswith("\n"):
                    content += pattern
                else:
                    content += f"\n{pattern}"

        # Write the .gitignore file if it has changed
        if content != current_content:
            with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
                logger.debug(f"Writing .gitignore file. Content: {content}")
                f.write(content)

        self.repo.git_add(".gitignore")

        # avoid race condition with git status
        time.sleep(0.5)

        if not self.repo.is_repo_clean():
            self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
            self.repo.git_push()