trainer.py 190 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or fine-tune it on a new task.
"""

19
import contextlib
20
import functools
21
import glob
22
import inspect
23
import math
Julien Chaumond's avatar
Julien Chaumond committed
24
import os
25
import random
Julien Chaumond's avatar
Julien Chaumond committed
26
27
import re
import shutil
28
import sys
29
import time
30
import warnings
31
from collections.abc import Mapping
Julien Chaumond's avatar
Julien Chaumond committed
32
from pathlib import Path
33
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
Julien Chaumond's avatar
Julien Chaumond committed
34

35
36
from tqdm.auto import tqdm

Julien Chaumond's avatar
Julien Chaumond committed
37

38
# Integrations must be imported before ML frameworks:
39
40
# isort: off
from .integrations import (
41
    default_hp_search_backend,
42
    get_reporting_integration_callbacks,
43
    hp_params,
44
    is_fairscale_available,
45
    is_optuna_available,
46
    is_ray_tune_available,
47
    is_sigopt_available,
48
    is_wandb_available,
49
50
    run_hp_search_optuna,
    run_hp_search_ray,
51
    run_hp_search_sigopt,
52
    run_hp_search_wandb,
53
)
54

55
56
# isort: on

57
58
import numpy as np
import torch
Lai Wei's avatar
Lai Wei committed
59
import torch.distributed as dist
60
from huggingface_hub import Repository, create_repo
61
62
from packaging import version
from torch import nn
63
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
64
65
from torch.utils.data.distributed import DistributedSampler

66
67
from . import __version__
from .configuration_utils import PretrainedConfig
68
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
69
from .debug_utils import DebugOption, DebugUnderflowOverflow
70
from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
71
from .dependency_versions_check import dep_version_check
Sylvain Gugger's avatar
Sylvain Gugger committed
72
from .modelcard import TrainingSummary
73
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
74
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
75
from .optimization import Adafactor, get_scheduler
76
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
77
from .tokenization_utils_base import PreTrainedTokenizerBase
Sylvain Gugger's avatar
Sylvain Gugger committed
78
79
80
81
82
83
84
85
86
87
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
88
    DistributedLengthGroupedSampler,
89
    DistributedSamplerWithLoop,
90
    DistributedTensorGatherer,
91
    IterableDatasetShard,
Sylvain Gugger's avatar
Sylvain Gugger committed
92
    LabelSmoother,
93
    LengthGroupedSampler,
Sylvain Gugger's avatar
Sylvain Gugger committed
94
    SequentialDistributedSampler,
95
    ShardSampler,
Sylvain Gugger's avatar
Sylvain Gugger committed
96
97
    distributed_broadcast_scalars,
    distributed_concat,
98
    find_batch_size,
99
    get_model_param_count,
100
    get_module_class_from_name,
101
    get_parameter_names,
Sylvain Gugger's avatar
Sylvain Gugger committed
102
103
104
    nested_concat,
    nested_detach,
    nested_numpify,
105
    nested_truncate,
Sylvain Gugger's avatar
Sylvain Gugger committed
106
107
108
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
109
110
111
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
112
    EvalLoopOutput,
113
    EvalPrediction,
114
    FSDPOption,
115
    HPSearchBackend,
116
117
    HubStrategy,
    IntervalStrategy,
118
    PredictionOutput,
119
    RemoveColumnsCollator,
120
    ShardedDDPOption,
121
    TrainerMemoryTracker,
122
123
124
    TrainOutput,
    default_compute_objective,
    default_hp_space,
125
    denumpify_detensorize,
126
    enable_full_determinism,
127
    find_executable_batch_size,
128
    get_last_checkpoint,
129
    has_length,
130
    number_of_arguments,
131
    seed_worker,
132
    set_seed,
133
    speed_metrics,
134
)
135
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
136
137
from .utils import (
    CONFIG_NAME,
138
139
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
140
    WEIGHTS_INDEX_NAME,
141
    WEIGHTS_NAME,
142
    can_return_loss,
143
    find_labels,
144
    get_full_repo_name,
145
    is_accelerate_available,
146
147
148
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
149
    is_ipex_available,
150
    is_safetensors_available,
151
152
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
153
    is_torch_compile_available,
154
    is_torch_neuroncore_available,
155
156
    is_torch_tpu_available,
    logging,
157
    strtobool,
158
)
159
from .utils.generic import ContextManagers
Julien Chaumond's avatar
Julien Chaumond committed
160
161


162
# Native CPU mixed precision ("cpu_amp" half-precision backend) is only offered when the
# torch>=1.10 check from `.pytorch_utils` passes.
_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10

# Callbacks every Trainer starts with; reporting-integration callbacks and any
# user-supplied callbacks are appended after these.
DEFAULT_CALLBACKS = [DefaultFlowCallback]

# Progress callback added unless `disable_tqdm` is set: the notebook widget when running
# inside a notebook, the regular progress callback everywhere else.
if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
else:
    DEFAULT_PROGRESS_CALLBACK = ProgressCallback
171

172
173
if is_apex_available():
    from apex import amp
174

175
176
if is_datasets_available():
    import datasets
Julien Chaumond's avatar
Julien Chaumond committed
177

178
if is_torch_tpu_available(check_device=False):
Lysandre Debut's avatar
Lysandre Debut committed
179
180
181
182
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

183
if is_fairscale_available():
184
    dep_version_check("fairscale")
185
    import fairscale
186
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
187
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
188
    from fairscale.nn.wrap import auto_wrap
189
190
191
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

192

Sylvain Gugger's avatar
Sylvain Gugger committed
193
194
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
195
196
197
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
Sylvain Gugger's avatar
Sylvain Gugger committed
198
199

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
200
201
else:
    IS_SAGEMAKER_MP_POST_1_10 = False
Sylvain Gugger's avatar
Sylvain Gugger committed
202

203

204
205
206
207
if is_safetensors_available():
    import safetensors.torch


208
209
210
211
212
213
214
skip_first_batches = None
if is_accelerate_available():
    from accelerate import __version__ as accelerate_version

    if version.parse(accelerate_version) >= version.parse("0.16"):
        from accelerate import skip_first_batches

215
216
    from accelerate import Accelerator

217

218
219
220
if TYPE_CHECKING:
    import optuna

Lysandre Debut's avatar
Lysandre Debut committed
221
# Module-level logger. `logging` here is the package's own `.utils.logging` module
# (imported from `.utils` above), which provides `get_logger`; it is not the stdlib module.
logger = logging.get_logger(__name__)
Julien Chaumond's avatar
Julien Chaumond committed
222
223


224
225
226
227
228
229
230
231
# File names used for checkpointing.
# -- trainer files --
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
# -- optimization files --
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"


Julien Chaumond's avatar
Julien Chaumond committed
232
233
class Trainer:
    """
Sylvain Gugger's avatar
Sylvain Gugger committed
234
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
235
236

    Args:
237
238
        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
Sylvain Gugger's avatar
Sylvain Gugger committed
239

240
            <Tip>
Sylvain Gugger's avatar
Sylvain Gugger committed
241

Sylvain Gugger's avatar
Sylvain Gugger committed
242
243
244
            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
            models.
245
246
247
248

            </Tip>

        args ([`TrainingArguments`], *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
249
250
            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
251
        data_collator (`DataCollator`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
252
253
            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
254
255
            [`DataCollatorWithPadding`] otherwise.
        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
256
            The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
257
258
            `model.forward()` method are automatically removed.

Sylvain Gugger's avatar
Sylvain Gugger committed
259
260
261
262
263
            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
            distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a
            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
            sets the seed of the RNGs used.
264
        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
265
             The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
266
267
             `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
             dataset prepending the dictionary key to the metric name.
268
        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
269
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
270
271
            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
            interrupted training or reuse the fine-tuned model.
272
        model_init (`Callable[[], PreTrainedModel]`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
273
274
            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
            from a new instance of the model as given by this function.
275

276
277
278
            The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
            be able to choose different architectures according to hyper parameters (such as layer count, sizes of
            inner layers, dropout probabilities etc).
279
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
280
281
            The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
            a dictionary string to metric values.
282
        callbacks (List of [`TrainerCallback`], *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
283
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
284
            detailed in [here](callback).
Sylvain Gugger's avatar
Sylvain Gugger committed
285

286
287
            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
Sylvain Gugger's avatar
Sylvain Gugger committed
288
289
            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
290
291
292
293
294
295
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocess the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.
296

297
298
    Important attributes:

Sylvain Gugger's avatar
Sylvain Gugger committed
299
300
        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
          subclass.
301
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
302
          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
Sylvain Gugger's avatar
Sylvain Gugger committed
303
304
          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
305
306
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
307
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
308
309
          to `False` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False` .
Sylvain Gugger's avatar
Sylvain Gugger committed
310
311
        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
          in `train`)
312

Julien Chaumond's avatar
Julien Chaumond committed
313
314
    """

315
    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
316

Julien Chaumond's avatar
Julien Chaumond committed
317
318
    def __init__(
        self,
319
        model: Union[PreTrainedModel, nn.Module] = None,
320
        args: TrainingArguments = None,
Julien Chaumond's avatar
Julien Chaumond committed
321
322
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
323
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
324
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
325
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
Julien Chaumond's avatar
Julien Chaumond committed
326
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
Sylvain Gugger's avatar
Sylvain Gugger committed
327
        callbacks: Optional[List[TrainerCallback]] = None,
328
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
329
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
Julien Chaumond's avatar
Julien Chaumond committed
330
    ):
Sylvain Gugger's avatar
Sylvain Gugger committed
331
        if args is None:
332
333
334
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
Sylvain Gugger's avatar
Sylvain Gugger committed
335
336
        self.args = args
        # Seed must be set before instantiating the model when using model
337
        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
338
        self.hp_name = None
339
        self.deepspeed = None
340
        self.is_in_train = False
341

342
343
344
        # create accelerator object
        self.accelerator = Accelerator()

345
346
347
348
        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

349
        # set the correct log level depending on the node
350
        log_level = args.get_process_log_level()
351
352
        logging.set_verbosity(log_level)

353
354
355
        # force device and distributed setup init explicitly
        args._setup_devices

356
357
358
359
360
361
362
363
364
        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
Sylvain Gugger's avatar
Sylvain Gugger committed
365
366
367
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
                    " overwrite your model when calling the `train` method. This will become a fatal error in the next"
                    " release.",
368
369
370
                    FutureWarning,
                )
            self.model_init = model_init
371

372
373
374
375
376
377
378
379
        if model.__class__.__name__ in MODEL_MAPPING_NAMES:
            raise ValueError(
                f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
                "computes hidden states and does not accept any labels. You should choose a model with a head "
                "suitable for your task like any of the `AutoModelForXxx` listed at "
                "https://huggingface.co/docs/transformers/model_doc/auto."
            )

380
381
382
383
384
        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = False

385
386
387
388
389
390
391
392
393
394
395
396
397
        if (
            getattr(model, "hf_device_map", None) is not None
            and len([device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]) > 1
            and not self.is_model_parallel
        ):
            self.is_model_parallel = True

            # warn users
            logger.info(
                "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
                " to `True` to avoid any unexpected behavior such as device placement mismatching."
            )

398
        # At this stage the model is already loaded
399
400
        if getattr(model, "is_loaded_in_kbit", False):
            if getattr(model, "_is_kbit_training_enabled", False):
401
402
403
404
405
406
407
408
409
410
411
                logger.info(
                    "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
                    " inside the model such as adapters using `peft` library and freeze the model weights. Please"
                    " check "
                    " the examples in https://github.com/huggingface/peft for more details."
                )
            else:
                raise ValueError(
                    "The model you want to train is loaded in 8-bit precision.  if you want to fine-tune an 8-bit"
                    " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
                )
412

413
414
415
416
417
418
419
        # Setup Sharded DDP training
        self.sharded_ddp = None
        if len(args.sharded_ddp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
420
421
422
423
            if len(args.fsdp) > 0:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
                )
424
            if args.parallel_mode != ParallelMode.DISTRIBUTED:
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
                raise ImportError(
                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
                )
            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.SIMPLE
            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

440
441
442
443
444
445
        self.fsdp = None
        if len(args.fsdp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
446
            if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
447
448
                raise ValueError("Using fsdp only works in distributed training.")

449
450
451
            # dep_version_check("torch>=1.12.0")
            # Would have to update setup.py with torch>=1.12.0
            # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
452
            # below is the current alternative.
453
            if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
454
                raise ValueError("FSDP requires PyTorch >= 1.12.0")
455

456
            from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy
457
458
459
460
461

            if FSDPOption.FULL_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.FULL_SHARD
            elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
                self.fsdp = ShardingStrategy.SHARD_GRAD_OP
462
463
            elif FSDPOption.NO_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.NO_SHARD
464

465
            self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
466
467
468
            if "backward_prefetch" in self.args.fsdp_config and "backward_pos" in self.args.fsdp_config.get(
                "backward_prefetch", []
            ):
469
470
                self.backward_prefetch = BackwardPrefetch.BACKWARD_POST

Seung-Moo Yang's avatar
Seung-Moo Yang committed
471
472
473
            self.forward_prefetch = False
            if self.args.fsdp_config.get("forward_prefect", False):
                self.forward_prefetch = True
474

475
476
477
478
            self.limit_all_gathers = False
            if self.args.fsdp_config.get("limit_all_gathers", False):
                self.limit_all_gathers = True

479
        # one place to sort out whether to place the model on device or not
480
481
482
483
        # postpone switching model to cuda when:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        #    and we only use deepspeed for training at the moment
484
        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
485
        # 4. Sharded DDP - same as MP
486
        # 5. FSDP - same as MP
487
        self.place_model_on_device = args.place_model_on_device
488
489
        if (
            self.is_model_parallel
490
            or args.deepspeed
491
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
492
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
493
            or (self.fsdp is not None)
494
        ):
495
496
            self.place_model_on_device = False

497
498
        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
Julien Chaumond's avatar
Julien Chaumond committed
499
500
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
501
        self.tokenizer = tokenizer
502

503
        if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False):
Sylvain Gugger's avatar
Sylvain Gugger committed
504
            self._move_model_to_device(model, args.device)
Stas Bekman's avatar
Stas Bekman committed
505
506
507

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
        if self.is_model_parallel:
508
            self.args._n_gpu = 1
509
510
511
512
513

        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
        self.model_wrapped = model
        self.model = model

Julien Chaumond's avatar
Julien Chaumond committed
514
        self.compute_metrics = compute_metrics
515
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
516
        self.optimizer, self.lr_scheduler = optimizers
Sylvain Gugger's avatar
Sylvain Gugger committed
517
518
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
519
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
Sylvain Gugger's avatar
Sylvain Gugger committed
520
521
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
        if is_torch_tpu_available() and self.optimizer is not None:
            for param in self.model.parameters():
                model_device = param.device
                break
            for param_group in self.optimizer.param_groups:
                if len(param_group["params"]) > 0:
                    optimizer_device = param_group["params"][0].device
                    break
            if model_device != optimizer_device:
                raise ValueError(
                    "The model and the optimizer parameters are not on the same device, which probably means you"
                    " created an optimizer around your model **before** putting on the device and passing it to the"
                    " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                    " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
                )
537
        if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
538
539
540
            self.optimizer is not None or self.lr_scheduler is not None
        ):
            raise RuntimeError(
541
                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
542
543
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
544
545
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
546
547
548
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
549
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
Sylvain Gugger's avatar
Sylvain Gugger committed
550

551
552
553
        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

554
555
        # Create clone of distant repo and output directory if needed
        if self.args.push_to_hub:
556
            self.init_git_repo(at_init=True)
557
558
559
            # In case of pull, we need to make sure every process has the latest.
            if is_torch_tpu_available():
                xm.rendezvous("init git repo")
560
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
561
562
                dist.barrier()

563
        if self.args.should_save:
Julien Chaumond's avatar
Julien Chaumond committed
564
            os.makedirs(self.args.output_dir, exist_ok=True)
565

566
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
Sylvain Gugger's avatar
Sylvain Gugger committed
567
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
568

569
570
571
        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

572
        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
573
574
575
576
            raise ValueError(
                "The train_dataset does not implement __len__, max_steps has to be specified. "
                "The number of steps needs to be known in advance for the learning rate scheduler."
            )
577

578
579
580
581
582
583
584
        if (
            train_dataset is not None
            and isinstance(train_dataset, torch.utils.data.IterableDataset)
            and args.group_by_length
        ):
            raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")

585
        self._signature_columns = None
586

587
588
        # Mixed precision setup
        self.use_apex = False
589
590
        self.use_cuda_amp = False
        self.use_cpu_amp = False
591

592
593
594
595
596
        # Mixed precision setup for SageMaker Model Parallel
        if is_sagemaker_mp_enabled():
            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
            if args.bf16:
                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613

            if IS_SAGEMAKER_MP_POST_1_10:
                # When there's mismatch between SMP config and trainer argument, use SMP config as truth
                if args.fp16 != smp.state.cfg.fp16:
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
                        f"but FP16 provided in trainer argument is {args.fp16},"
                        f"setting to {smp.state.cfg.fp16}"
                    )
                    args.fp16 = smp.state.cfg.fp16
            else:
                # smp < 1.10 does not support fp16 in trainer.
                if hasattr(smp.state.cfg, "fp16"):
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
                    )
614

615
        if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
616
            if args.half_precision_backend == "auto":
617
                if args.device == torch.device("cpu"):
618
619
620
621
622
623
                    if args.fp16:
                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                    elif _is_native_cpu_amp_available:
                        args.half_precision_backend = "cpu_amp"
                    else:
                        raise ValueError("Tried to use cpu amp but native cpu amp is not available")
624
                else:
625
                    args.half_precision_backend = "cuda_amp"
626

627
            logger.info(f"Using {args.half_precision_backend} half precision backend")
628

629
        self.do_grad_scaling = False
630
        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
631
            # deepspeed and SageMaker Model Parallel manage their own half precision
632
633
634
635
636
637
638
639
640
641
642
643
644
            if self.sharded_ddp is not None:
                if args.half_precision_backend == "cuda_amp":
                    self.use_cuda_amp = True
                    self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
                    #  bf16 does not need grad scaling
                    self.do_grad_scaling = self.amp_dtype == torch.float16
                    if self.do_grad_scaling:
                        if self.sharded_ddp is not None:
                            self.scaler = ShardedGradScaler()
                        elif self.fsdp is not None:
                            from torch.distributed.fsdp.sharded_grad_scaler import (
                                ShardedGradScaler as FSDPShardedGradScaler,
                            )
645

646
647
648
                            self.scaler = FSDPShardedGradScaler()
                        elif is_torch_tpu_available():
                            from torch_xla.amp import GradScaler
649

650
651
652
653
654
655
656
                            self.scaler = GradScaler()
                        else:
                            self.scaler = torch.cuda.amp.GradScaler()
                elif args.half_precision_backend == "cpu_amp":
                    self.use_cpu_amp = True
                    self.amp_dtype = torch.bfloat16
            elif args.half_precision_backend == "apex":
657
658
                if not is_apex_available():
                    raise ImportError(
Sylvain Gugger's avatar
Sylvain Gugger committed
659
660
                        "Using FP16 with APEX but APEX is not installed, please refer to"
                        " https://www.github.com/nvidia/apex."
661
662
663
                    )
                self.use_apex = True

664
665
666
667
668
669
670
671
672
673
674
675
        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
        if (
            is_sagemaker_mp_enabled()
            and self.use_cuda_amp
            and args.max_grad_norm is not None
            and args.max_grad_norm > 0
        ):
            raise ValueError(
                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
                "along 'max_grad_norm': 0 in your hyperparameters."
            )

Sylvain Gugger's avatar
Sylvain Gugger committed
676
677
678
679
680
681
        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

682
683
684
685
686
        self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
        )

Sylvain Gugger's avatar
Sylvain Gugger committed
687
        self.control = TrainerControl()
688
689
690
        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
691
        self.hp_search_backend = None
692
        self.use_tune_checkpoints = False
693
        default_label_names = find_labels(self.model.__class__)
694
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
695
        self.can_return_loss = can_return_loss(self.model.__class__)
Sylvain Gugger's avatar
Sylvain Gugger committed
696
697
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

698
699
700
        # Internal variables to keep track of the original batch size
        self._train_batch_size = args.train_batch_size

701
702
703
        # very last
        self._memory_tracker.stop_and_update_metrics()

704
705
        # torch.compile
        if args.torch_compile and not is_torch_compile_available():
706
            raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
707

Sylvain Gugger's avatar
Sylvain Gugger committed
708
709
    def add_callback(self, callback):
        """
        Register one more callback with this trainer's callback handler.

        Args:
           callback (`type` or [`~transformer.TrainerCallback`]):
               Either a [`~transformer.TrainerCallback`] subclass or an instance of one. When a class is passed, an
               instance of it is created and registered.
        """
        handler = self.callback_handler
        handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Detach a callback from the trainer's callback list and hand it back to the caller.

        When no matching callback is registered, `None` is returned and no error is raised.

        Args:
           callback (`type` or [`~transformer.TrainerCallback`]):
               Either a [`~transformer.TrainerCallback`] subclass or an instance of one. For a class, the first
               registered callback of that class is the one popped.

        Returns:
            [`~transformer.TrainerCallback`]: The callback that was removed, if one was found.
        """
        removed = self.callback_handler.pop_callback(callback)
        return removed

    def remove_callback(self, callback):
        """
        Unregister a callback from this trainer's callback handler without returning it.

        Args:
           callback (`type` or [`~transformer.TrainerCallback`]):
               Either a [`~transformer.TrainerCallback`] subclass or an instance of one. For a class, the first
               registered callback of that class is the one removed.
        """
        handler = self.callback_handler
        handler.remove_callback(callback)
Julien Chaumond's avatar
Julien Chaumond committed
745

Sylvain Gugger's avatar
Sylvain Gugger committed
746
747
748
749
750
751
    def _move_model_to_device(self, model, device):
        """
        Move `model` to `device` (via `model.to`), then re-tie its weights when training on TPU.

        Moving a model to an XLA device disconnects the tied weights, so they are tied again when the model
        exposes a `tie_weights` method. Nothing is returned.
        """
        moved = model.to(device)
        on_tpu = self.args.parallel_mode == ParallelMode.TPU
        if on_tpu and hasattr(moved, "tie_weights"):
            moved.tie_weights()

752
    def _set_signature_columns_if_needed(self):
        """
        Lazily compute and cache the input names the model's `forward` accepts.

        The list is stored in `self._signature_columns` and is only computed while that attribute is still `None`.
        It holds the parameter names of `self.model.forward`, followed by the label column names (`"label"`,
        `"label_ids"` and `self.label_names`). Duplicates are removed and the order is deterministic.
        """
        if self._signature_columns is not None:
            return
        # Inspect model forward signature to keep only the arguments it accepts.
        forward_params = list(inspect.signature(self.model.forward).parameters.keys())
        # Labels may be named label or label_ids, the default data collator handles that.
        label_columns = ["label", "label_ids", *self.label_names]
        # `dict.fromkeys` drops duplicates (e.g. a label name that is also a `forward` argument) and keeps the
        # first-seen order. Iterating a `set` instead would give an order that varies between interpreter runs.
        self._signature_columns = list(dict.fromkeys(forward_params + label_columns))
759

760
761
762
763
    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        """
        Drop the columns of a `datasets.Dataset` that the model's `forward` does not accept.

        Args:
            dataset (`datasets.Dataset`):
                The dataset to filter.
            description (`str`, *optional*):
                Name of the split (e.g. "training"). It is only used in the log message.

        Returns:
            `dataset` unchanged when `args.remove_unused_columns` is `False`. Otherwise, a dataset restricted to the
            accepted columns: through `set_format` (in place) for `datasets < 1.4.0`, or through `remove_columns`
            for newer versions.
        """
        if not self.args.remove_unused_columns:
            return dataset
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        accepted = set(signature_columns)
        dataset_columns = set(dataset.column_names)
        # Follow the dataset's own column order so the log message and the removal are deterministic. A set
        # difference would give an order that varies between runs.
        ignored_columns = [name for name in dataset.column_names if name not in accepted]
        if len(ignored_columns) > 0:
            dset_description = "" if description is None else f" in the {description} set"
            model_name = self.model.__class__.__name__
            ignored = ", ".join(ignored_columns)
            logger.info(
                f"The following columns{dset_description} don't have a corresponding argument in "
                f"`{model_name}.forward` and have been ignored: {ignored}."
                f" If {ignored} are not expected by `{model_name}.forward`, "
                "you can safely ignore this message."
            )

        # Accepted columns present in the dataset, in signature order, each listed once.
        columns = [k for k in dict.fromkeys(signature_columns) if k in dataset_columns]

        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            dataset.set_format(
                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
            )
            return dataset
        return dataset.remove_columns(ignored_columns)
785

786
787
788
789
790
791
792
    def _get_collator_with_removed_columns(
        self, data_collator: Callable, description: Optional[str] = None
    ) -> Callable:
        """
        Return `data_collator`, wrapped so that unused columns are dropped before collation.

        When `args.remove_unused_columns` is disabled, the collator is returned untouched. Otherwise it is wrapped
        in a `RemoveColumnsCollator` configured with the accepted `forward` argument names of the model.
        """
        if self.args.remove_unused_columns:
            self._set_signature_columns_if_needed()
            return RemoveColumnsCollator(
                data_collator=data_collator,
                signature_columns=self._signature_columns,
                logger=logger,
                description=description,
                model_name=self.model.__class__.__name__,
            )
        return data_collator

804
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        """
        Build the sampler used by the training dataloader.

        Returns `None` when there is no train dataset or it has no length.
        - With `args.group_by_length`: a (distributed) length-grouped sampler.
        - Otherwise, in a single process: a `RandomSampler`.
        - Otherwise, with several processes: a distributed sampler, which is a looping variant on TPU / SageMaker
          model parallel when `dataloader_drop_last` is off.
        """
        args = self.args
        dataset = self.train_dataset
        if dataset is None or not has_length(dataset):
            return None

        single_process = args.world_size <= 1

        generator = None
        if single_process:
            generator = torch.Generator()
            # For backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
            # `args.seed`) if data_seed isn't provided. The distributed samplers below use `args.seed` directly.
            if args.data_seed is None:
                generator_seed = int(torch.empty((), dtype=torch.int64).random_().item())
            else:
                generator_seed = args.data_seed
            generator.manual_seed(generator_seed)

        # Seed handed to the distributed samplers.
        seed = args.seed if args.data_seed is None else args.data_seed

        if args.group_by_length:
            lengths = None
            if is_datasets_available() and isinstance(dataset, datasets.Dataset):
                if args.length_column_name in dataset.column_names:
                    lengths = dataset[args.length_column_name]
            model_input_name = None if self.tokenizer is None else self.tokenizer.model_input_names[0]
            grouped_batch_size = args.train_batch_size * args.gradient_accumulation_steps
            if single_process:
                return LengthGroupedSampler(
                    grouped_batch_size,
                    dataset=dataset,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
                )
            return DistributedLengthGroupedSampler(
                grouped_batch_size,
                dataset=dataset,
                num_replicas=args.world_size,
                rank=args.process_index,
                lengths=lengths,
                model_input_name=model_input_name,
                seed=seed,
            )

        if single_process:
            return RandomSampler(dataset, generator=generator)

        looping_modes = [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
        if args.parallel_mode in looping_modes and not args.dataloader_drop_last:
            # Use a loop for TPUs when drop_last is False to have all batches have the same size.
            return DistributedSamplerWithLoop(
                dataset,
                batch_size=args.per_device_train_batch_size,
                num_replicas=args.world_size,
                rank=args.process_index,
                seed=seed,
            )
        return DistributedSampler(
            dataset,
            num_replicas=args.world_size,
            rank=args.process_index,
            seed=seed,
        )
874
875
876

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        An `IterableDataset` is loaded without a sampler, and is sharded across processes when
        `args.world_size > 1`. Any other dataset is loaded through the sampler built by `_get_train_sampler`, which
        is `None` when the dataset does not implement `__len__`.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        args = self.args
        dataset = self.train_dataset
        collator = self.data_collator
        # A `datasets.Dataset` drops unused columns up-front. Other datasets get a collator that drops them instead.
        uses_hf_dataset = is_datasets_available() and isinstance(dataset, datasets.Dataset)
        if uses_hf_dataset:
            dataset = self._remove_unused_columns(dataset, description="training")
        else:
            collator = self._get_collator_with_removed_columns(collator, description="training")

        shared_kwargs = {
            "batch_size": self._train_batch_size,
            "collate_fn": collator,
            "num_workers": args.dataloader_num_workers,
            "pin_memory": args.dataloader_pin_memory,
        }

        if not isinstance(dataset, torch.utils.data.IterableDataset):
            return DataLoader(
                dataset,
                sampler=self._get_train_sampler(),
                drop_last=args.dataloader_drop_last,
                worker_init_fn=seed_worker,
                **shared_kwargs,
            )

        if args.world_size > 1:
            dataset = IterableDatasetShard(
                dataset,
                batch_size=self._train_batch_size,
                drop_last=args.dataloader_drop_last,
                num_processes=args.world_size,
                process_index=args.process_index,
            )
        return DataLoader(dataset, **shared_kwargs)

925
    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
        """
        Build the sampler used for evaluation and prediction dataloaders.

        With `args.use_legacy_prediction_loop`, a `SequentialDistributedSampler` is used on TPU, under SageMaker
        model parallelism, or in distributed mode; otherwise a `SequentialSampler` is used. Outside the legacy loop,
        a `ShardSampler` is used when `args.world_size > 1`, and a `SequentialSampler` otherwise.
        """
        args = self.args
        if args.use_legacy_prediction_loop:
            # Deprecated code
            if is_torch_tpu_available():
                return SequentialDistributedSampler(
                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
                )
            if is_sagemaker_mp_enabled():
                return SequentialDistributedSampler(
                    eval_dataset,
                    num_replicas=smp.dp_size(),
                    rank=smp.dp_rank(),
                    batch_size=args.per_device_eval_batch_size,
                )
            if args.parallel_mode == ParallelMode.DISTRIBUTED:
                return SequentialDistributedSampler(eval_dataset)
            return SequentialSampler(eval_dataset)

        if args.world_size > 1:
            return ShardSampler(
                eval_dataset,
                batch_size=args.per_device_eval_batch_size,
                num_processes=args.world_size,
                process_index=args.process_index,
            )
        return SequentialSampler(eval_dataset)
Lysandre Debut's avatar
Lysandre Debut committed
953

Julien Chaumond's avatar
Julien Chaumond committed
954
    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. Datasets that are not an `IterableDataset`
                are read through the sampler from `_get_eval_sampler` and should implement `__len__`.

        Raises:
            ValueError: If neither `eval_dataset` nor `self.eval_dataset` is set.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.data_collator

        # A `datasets.Dataset` drops unused columns up-front. Other datasets get a collator that drops them instead.
        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                # Shard with the same batch size the DataLoader below uses, as get_train_dataloader and
                # get_test_dataloader do.
                eval_dataset = IterableDatasetShard(
                    eval_dataset,
                    batch_size=self.args.eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                eval_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`):
                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. Datasets that are not an `IterableDataset` are
                read through the sampler from `_get_eval_sampler` and should implement `__len__`.
        """
        collator = self.data_collator
        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description="test")
        else:
            collator = self._get_collator_with_removed_columns(collator, description="test")

        args = self.args
        # Test batches use the same batch size as evaluation.
        common = {
            "batch_size": args.eval_batch_size,
            "collate_fn": collator,
            "num_workers": args.dataloader_num_workers,
            "pin_memory": args.dataloader_pin_memory,
        }

        if isinstance(test_dataset, torch.utils.data.IterableDataset):
            if args.world_size > 1:
                test_dataset = IterableDatasetShard(
                    test_dataset,
                    batch_size=args.eval_batch_size,
                    drop_last=args.dataloader_drop_last,
                    num_processes=args.world_size,
                    process_index=args.process_index,
                )
            return DataLoader(test_dataset, **common)

        return DataLoader(
            test_dataset,
            sampler=self._get_eval_sampler(test_dataset),
            drop_last=args.dataloader_drop_last,
            **common,
        )
Lysandre Debut's avatar
Lysandre Debut committed
1051

1052
    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
        the Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
        `create_scheduler`) in a subclass.

        Args:
            num_training_steps (`int`):
                Total number of training steps, forwarded to `create_scheduler`.
        """
        self.create_optimizer()
        # With smp >= 1.10 and fp16, `self.optimizer` wraps the real optimizer (exposed as `.optimizer`). The
        # scheduler is attached to that inner optimizer.
        unwrap = IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16
        scheduled_optimizer = self.optimizer.optimizer if unwrap else self.optimizer
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=scheduled_optimizer)
1067
1068
1069
1070
1071

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.

        When `self.optimizer` is `None`, the trainable parameters are split into two groups:
        - Parameters whose names are returned by `get_parameter_names(model, ALL_LAYERNORM_LAYERS)` and do not
          contain `"bias"` get `args.weight_decay`.
        - All other parameters get a weight decay of 0.

        The optimizer class and keyword arguments come from `Trainer.get_optimizer_cls_and_kwargs`. Under SageMaker
        Model Parallelism the optimizer is wrapped in `smp.DistributedOptimizer`.

        Returns:
            The optimizer stored in `self.optimizer`.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            # A set keeps the membership tests below O(1) per parameter instead of scanning a list.
            decay_parameters = {
                name for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) if "bias" not in name
            }
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

                # For the 8-bit bitsandbytes options, `get_optimizer_cls_and_kwargs` returns
                # `bitsandbytes.optim.AdamW` / `Lion` with `optim_bits=8` rather than a class named "Adam8bit".
                # Both cases are matched so the embedding override below is applied to them.
                optimizer_module = getattr(optimizer_cls, "__module__", None) or ""
                is_bnb_8bit = getattr(optimizer_cls, "__name__", None) == "Adam8bit" or (
                    optimizer_module.startswith("bitsandbytes") and optimizer_kwargs.get("optim_bits") == 8
                )
                if is_bnb_8bit:
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    skipped = 0
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            # Parameters sharing storage (same `data_ptr()`) are counted once.
                            module_params = sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                            skipped += module_params
                            logger.info(f"skipped {module}: {module_params/2**20}M params")
                            # Keep 32-bit optimizer state for embedding weights.
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                    logger.info(f"skipped: {skipped/2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer

1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
    @staticmethod
    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
        """
        Returns the optimizer class and optimizer parameters based on the training arguments.

        Args:
            args (`transformers.training_args.TrainingArguments`):
                The training arguments for the training session.

        """
1134
1135
1136
1137
1138
1139
1140
1141

        # parse args.optim_args
        optim_args = {}
        if args.optim_args:
            for mapping in args.optim_args.replace(" ", "").split(","):
                key, value = mapping.split("=")
                optim_args[key] = value

1142
        optimizer_kwargs = {"lr": args.learning_rate}
1143

1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
        adam_kwargs = {
            "betas": (args.adam_beta1, args.adam_beta2),
            "eps": args.adam_epsilon,
        }
        if args.optim == OptimizerNames.ADAFACTOR:
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif args.optim == OptimizerNames.ADAMW_HF:
            from .optimization import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
1156
        elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
1157
1158
1159
1160
            from torch.optim import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
1161
1162
            if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
                optimizer_kwargs.update({"fused": True})
1163
1164
1165
1166
1167
1168
1169
1170
        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
            try:
                from torch_xla.amp.syncfree import AdamW

                optimizer_cls = AdamW
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
1171
1172
1173
1174
1175
1176
1177
1178
        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
            try:
                from apex.optimizers import FusedAdam

                optimizer_cls = FusedAdam
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
        elif args.optim in [
            OptimizerNames.ADAMW_BNB,
            OptimizerNames.ADAMW_8BIT,
            OptimizerNames.PAGED_ADAMW,
            OptimizerNames.PAGED_ADAMW_8BIT,
            OptimizerNames.LION,
            OptimizerNames.LION_8BIT,
            OptimizerNames.PAGED_LION,
            OptimizerNames.PAGED_LION_8BIT,
        ]:
            try:
                from bitsandbytes.optim import AdamW, Lion

                is_paged = False
                optim_bits = 32
                optimizer_cls = None
                additional_optim_kwargs = adam_kwargs
                if "paged" in args.optim:
                    is_paged = True
                if "8bit" in args.optim:
                    optim_bits = 8
                if "adam" in args.optim:
                    optimizer_cls = AdamW
                elif "lion" in args.optim:
                    optimizer_cls = Lion
                    additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}

                bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
                optimizer_kwargs.update(additional_optim_kwargs)
                optimizer_kwargs.update(bnb_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!")
1211
1212
1213
1214
1215
1216
1217
1218
        elif args.optim == OptimizerNames.ADAMW_BNB:
            try:
                from bitsandbytes.optim import Adam8bit

                optimizer_cls = Adam8bit
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
        elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
            try:
                from torchdistx.optimizers import AnyPrecisionAdamW

                optimizer_cls = AnyPrecisionAdamW
                optimizer_kwargs.update(adam_kwargs)

                # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
                optimizer_kwargs.update(
                    {
                        "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
                        "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
                        "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
                        "compensation_buffer_dtype": getattr(
                            torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
                        ),
                    }
                )
            except ImportError:
                raise ValueError("Please install https://github.com/pytorch/torchdistx")
1239
1240
1241
1242
        elif args.optim == OptimizerNames.SGD:
            optimizer_cls = torch.optim.SGD
        elif args.optim == OptimizerNames.ADAGRAD:
            optimizer_cls = torch.optim.Adagrad
1243
1244
1245
1246
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs

1247
    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Build the learning-rate scheduler once and cache it on `self.lr_scheduler`.

        If a scheduler already exists it is returned as-is. Otherwise a new one is created for `optimizer` when it is
        given, falling back to `self.optimizer` (which must then have been set up before this call).

        Args:
            num_training_steps (int): The number of training steps to do; also used to derive the warmup length.
            optimizer (`torch.optim.Optimizer`, *optional*): Optimizer to attach the scheduler to.

        Returns:
            The scheduler stored in `self.lr_scheduler`.
        """
        if self.lr_scheduler is not None:
            return self.lr_scheduler

        target_optimizer = optimizer if optimizer is not None else self.optimizer
        warmup_steps = self.args.get_warmup_steps(num_training_steps)
        self.lr_scheduler = get_scheduler(
            self.args.lr_scheduler_type,
            optimizer=target_optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
        )
        return self.lr_scheduler
Julien Chaumond's avatar
Julien Chaumond committed
1263

1264
    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Return how many samples `dataloader` iterates over, read from its dataset when possible.

        For an `IterableDatasetShard`, the length of the dataset it wraps is reported. When the dataloader has no
        `dataset` attribute or the dataset has no length, the count is estimated as the number of batches times
        `args.per_device_train_batch_size`.
        """
        try:
            dataset = dataloader.dataset
            if isinstance(dataset, IterableDatasetShard):
                # The shard wraps the real dataset; report the size of the inner one.
                dataset = dataset.dataset
            return len(dataset)
        except (NameError, AttributeError, TypeError):
            # No usable dataset length: approximate from the number of batches.
            n_batches = len(dataloader)
            return n_batches * self.args.per_device_train_batch_size
1277

1278
    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
Patrick von Platen's avatar
Patrick von Platen committed
1279
        """HP search setup code"""
1280
1281
        self._trial = trial

1282
1283
        if self.hp_search_backend is None or trial is None:
            return
1284
1285
1286
1287
1288
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            params = self.hp_space(trial)
        elif self.hp_search_backend == HPSearchBackend.RAY:
            params = trial
            params.pop("wandb", None)
1289
1290
        elif self.hp_search_backend == HPSearchBackend.SIGOPT:
            params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
1291
1292
        elif self.hp_search_backend == HPSearchBackend.WANDB:
            params = trial
1293

1294
1295
        for key, value in params.items():
            if not hasattr(self.args, key):
1296
                logger.warning(
Sylvain Gugger's avatar
Sylvain Gugger committed
1297
1298
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
                    " `TrainingArguments`."
1299
                )
1300
                continue
1301
1302
1303
1304
1305
1306
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
1307
            logger.info(f"Trial: {trial.params}")
1308
1309
        if self.hp_search_backend == HPSearchBackend.SIGOPT:
            logger.info(f"SigOpt Assignments: {trial.assignments}")
1310
1311
        if self.hp_search_backend == HPSearchBackend.WANDB:
            logger.info(f"W&B Sweep parameters: {trial}")
1312
1313
        if self.args.deepspeed:
            # Rebuild the deepspeed config to reflect the updated training parameters
1314
            from transformers.deepspeed import HfTrainerDeepSpeedConfig
1315

1316
1317
            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
            self.args.hf_deepspeed_config.trainer_config_process(self.args)
1318

1319
    def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
1320
1321
        if self.hp_search_backend is None or trial is None:
            return
1322
        self.objective = self.compute_objective(metrics.copy())
1323
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
1324
1325
            import optuna

1326
            trial.report(self.objective, step)
1327
            if trial.should_prune():
1328
                self.callback_handler.on_train_end(self.args, self.state, self.control)
1329
1330
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
1331
1332
            from ray import tune

1333
            if self.control.should_save:
1334
                self._tune_save_checkpoint()
1335
1336
            tune.report(objective=self.objective, **metrics)

1337
    def _tune_save_checkpoint(self):
        """
        Save a Ray Tune checkpoint of the current training state.

        Ray is imported first in all cases, and the method returns without saving when `self.use_tune_checkpoints`
        is falsy. Otherwise the model is saved under `<tune checkpoint dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>`.
        When `args.should_save` is set, the trainer state, optimizer state and LR-scheduler state are also written to
        that directory.
        """
        from ray import tune

        if not self.use_tune_checkpoints:
            return

        step = self.state.global_step
        with tune.checkpoint_dir(step=step) as tune_dir:
            target_dir = os.path.join(tune_dir, f"{PREFIX_CHECKPOINT_DIR}-{step}")
            self.save_model(target_dir, _internal_call=True)
            if not self.args.should_save:
                return
            self.state.save_to_json(os.path.join(target_dir, TRAINER_STATE_NAME))
            torch.save(self.optimizer.state_dict(), os.path.join(target_dir, OPTIMIZER_NAME))
            torch.save(self.lr_scheduler.state_dict(), os.path.join(target_dir, SCHEDULER_NAME))
1349

1350
    def call_model_init(self, trial=None):
        """
        Instantiate a fresh model through the user-supplied `self.model_init`.

        `model_init` may take either no argument or a single argument. In the latter case it receives `trial`.

        Raises:
            RuntimeError: If `model_init` takes any other number of arguments, or returns `None`.
        """
        arity = number_of_arguments(self.model_init)
        if arity not in (0, 1):
            raise RuntimeError("model_init should have 0 or 1 argument.")

        # Only forward the trial when model_init declares a parameter for it.
        model = self.model_init(trial) if arity == 1 else self.model_init()
        if model is None:
            raise RuntimeError("model_init should not return None.")
        return model

1364
1365
1366
1367
1368
1369
    def torch_jit_model_eval(self, model, dataloader, training=False):
        """
        Try to replace `model` with a frozen TorchScript trace for evaluation.

        Tracing is only attempted when `training` is falsy. One batch drawn from `dataloader` and passed through
        `self._prepare_inputs` serves as the example input. If tracing, freezing, or the warm-up calls raise one of
        the caught exception types, a warning is logged and the original model is returned. On success, CPU and
        CUDA AMP are switched off on the trainer (`self.use_cpu_amp` / `self.use_cuda_amp`).

        Args:
            model (`nn.Module`): The model to trace.
            dataloader: Source of the example batch. When `None`, a warning is logged and no tracing happens.
            training (`bool`, *optional*, defaults to `False`): When truthy, `model` is returned unchanged.

        Returns:
            The traced and frozen model, or `model` if tracing was skipped or failed.
        """
        if training:
            return model
        if dataloader is None:
            logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
            return model

        example_batch = self._prepare_inputs(next(iter(dataloader)))

        try:
            # eval() switches `model` itself to eval mode, even if tracing later fails.
            jit_model = model.eval()
            with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]):
                torch_release = version.parse(version.parse(torch.__version__).base_version)
                if torch_release >= version.parse("1.14.0"):
                    # Trace with keyword inputs; non-dict mappings are copied into a plain dict first.
                    if isinstance(example_batch, dict):
                        kwarg_inputs = example_batch
                    else:
                        kwarg_inputs = {key: example_batch[key] for key in example_batch}
                    jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=kwarg_inputs, strict=False)
                else:
                    # Older torch: positional example inputs, all-ones tensors shaped like the batch, in key order.
                    positional_inputs = tuple(torch.ones_like(example_batch[key]) for key in example_batch)
                    jit_model = torch.jit.trace(jit_model, positional_inputs, strict=False)
            jit_model = torch.jit.freeze(jit_model)
            with torch.no_grad():
                # Two warm-up forward passes (presumably so the JIT can specialize before real evaluation).
                for _ in range(2):
                    jit_model(**example_batch)
            model = jit_model
            self.use_cpu_amp = False
            self.use_cuda_amp = False
        except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
            logger.warning(f"failed to use PyTorch jit mode due to: {e}.")

        return model

1402
1403
1404
    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
        """
        Run Intel Extension for PyTorch (IPEX) `ipex.optimize` on `model`.

        Args:
            model (`nn.Module`): The model to optimize.
            training (`bool`, *optional*, defaults to `False`): Selects the train or eval path.
                - Truthy: the model is put in train mode and optimized in place jointly with `self.optimizer`, which
                  is replaced by the optimizer IPEX returns.
                - Falsy: the model is put in eval mode and optimized alone.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Target dtype passed to `ipex.optimize`.
                On the eval path outside of training with `args.bf16_full_eval` set, `torch.bfloat16` is used instead.

        Returns:
            The optimized model.

        Raises:
            ImportError: If IPEX is not installed or its version does not match the installed PyTorch.
        """
        if not is_ipex_available():
            raise ImportError(
                "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
                " to https://github.com/intel/intel-extension-for-pytorch."
            )

        import intel_extension_for_pytorch as ipex

        if training:
            if not model.training:
                model.train()
            model, self.optimizer = ipex.optimize(
                model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
            )
            return model

        model.eval()
        if not self.is_in_train and self.args.bf16_full_eval:
            dtype = torch.bfloat16
        # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings
        model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
        return model

1425
    def _wrap_model(self, model, training=True, dataloader=None):
        """
        Wrap `model` for the configured acceleration / parallelism backend and return the result.

        Wrappers are considered in this fixed order, and several steps return early:

        1. IPEX.
        2. SageMaker model parallel (returns).
        3. DeepSpeed (returns the engine).
        4. Already-wrapped check (returns).
        5. apex AMP (training only).
        6. `nn.DataParallel`.
        7. TorchScript JIT for eval.
        8. When `training` is falsy, return here.
        9. Sharded DDP / FSDP (torch or XLA) / SageMaker DP / DDP.
        10. `torch.compile`.

        Args:
            model (`nn.Module`): The model to wrap.
            training (`bool`, *optional*, defaults to `True`): Whether the model is prepared for training. When
                falsy, apex AMP, the distributed wrappers and `torch.compile` are skipped.
            dataloader (`DataLoader`, *optional*): Passed to `torch_jit_model_eval` to obtain an example batch.

        Returns:
            The wrapped model (possibly `model` itself).

        Side effects:
            - `self.model` may be reassigned (fairscale FullyShardedDDP, FSDP).
            - `self.optimizer` may be reassigned (IPEX training, apex).
            - In the XLA FSDP path, `xm.optimizer_step` is globally patched.
        """
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # DeepSpeed has already initialized its own DDP and AMP
        if self.deepspeed:
            return self.deepspeed

        # train/eval could be run multiple times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization); models loaded in 8-bit are not wrapped
        if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
            model = nn.DataParallel(model)

        if self.args.jit_mode_eval:
            start_time = time.time()
            model = self.torch_jit_model_eval(model, dataloader, training)
            self.jit_compilation_time = round(time.time() - start_time, 4)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)
        # Distributed training using PyTorch FSDP
        elif self.fsdp is not None:
            if not self.args.fsdp_config["xla"]:
                # PyTorch FSDP!
                from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
                from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
                from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

                if FSDPOption.OFFLOAD in self.args.fsdp:
                    cpu_offload = CPUOffload(offload_params=True)
                else:
                    cpu_offload = CPUOffload(offload_params=False)

                auto_wrap_policy = None

                if FSDPOption.AUTO_WRAP in self.args.fsdp:
                    if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                        auto_wrap_policy = functools.partial(
                            size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                        )
                    elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                        transformer_cls_to_wrap = set()
                        for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                            transformer_cls = get_module_class_from_name(model, layer_class)
                            if transformer_cls is None:
                                raise Exception("Could not find the transformer layer class to wrap in the model.")
                            else:
                                transformer_cls_to_wrap.add(transformer_cls)
                        auto_wrap_policy = functools.partial(
                            transformer_auto_wrap_policy,
                            # Transformer layer class to wrap
                            transformer_layer_cls=transformer_cls_to_wrap,
                        )
                mixed_precision_policy = None
                dtype = None
                if self.args.fp16:
                    dtype = torch.float16
                elif self.args.bf16:
                    dtype = torch.bfloat16
                if dtype is not None:
                    mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
                if type(model) != FSDP:
                    # XXX: Breaking the self.model convention but I see no way around it for now.
                    # Only forward the optional FSDP kwargs that the installed torch version accepts.
                    signature = inspect.signature(FSDP.__init__).parameters.keys()
                    kwargs = {}
                    for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]:
                        if arg in signature:
                            kwargs[arg] = getattr(self, arg)
                    self.model = model = FSDP(
                        model,
                        sharding_strategy=self.fsdp,
                        cpu_offload=cpu_offload,
                        auto_wrap_policy=auto_wrap_policy,
                        mixed_precision=mixed_precision_policy,
                        device_id=self.args.device,
                        **kwargs,
                    )
            else:
                try:
                    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                    from torch_xla.distributed.fsdp import checkpoint_module
                    from torch_xla.distributed.fsdp.wrap import (
                        size_based_auto_wrap_policy,
                        transformer_auto_wrap_policy,
                    )
                except ImportError as err:
                    raise ImportError(
                        "Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0."
                    ) from err
                auto_wrap_policy = None
                auto_wrapper_callable = None
                if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                    auto_wrap_policy = functools.partial(
                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                    )
                elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                    transformer_cls_to_wrap = set()
                    for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                        transformer_cls = get_module_class_from_name(model, layer_class)
                        if transformer_cls is None:
                            raise Exception("Could not find the transformer layer class to wrap in the model.")
                        else:
                            transformer_cls_to_wrap.add(transformer_cls)
                    auto_wrap_policy = functools.partial(
                        transformer_auto_wrap_policy,
                        # Transformer layer class to wrap
                        transformer_layer_cls=transformer_cls_to_wrap,
                    )
                fsdp_kwargs = self.args.xla_fsdp_config
                if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                    # Apply gradient checkpointing to auto-wrapped sub-modules if specified
                    def auto_wrapper_callable(m, *args, **kwargs):
                        return FSDP(checkpoint_module(m), *args, **kwargs)

                # Wrap the base model with an outer FSDP wrapper
                self.model = model = FSDP(
                    model,
                    auto_wrap_policy=auto_wrap_policy,
                    auto_wrapper_callable=auto_wrapper_callable,
                    **fsdp_kwargs,
                )

                # Patch `xm.optimizer_step` so that it does not reduce gradients in this case,
                # as FSDP does not need gradient reduction over sharded parameters.
                # `optimizer_args` defaults to None (not a shared mutable `{}`) since this function replaces
                # `xm.optimizer_step` for the whole process.
                def patched_optimizer_step(optimizer, barrier=False, optimizer_args=None):
                    loss = optimizer.step(**(optimizer_args or {}))
                    if barrier:
                        xm.mark_step()
                    return loss

                xm.optimizer_step = patched_optimizer_step
        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
            if is_torch_neuroncore_available():
                # NOTE: this early return also skips the torch.compile step below.
                return model
            # DDP is only applied when at least one parameter is trainable.
            if any(p.requires_grad for p in model.parameters()):
                model = nn.parallel.DistributedDataParallel(
                    model,
                    device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
                    output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
                    **kwargs,
                )

        # torch.compile() needs to be called after wrapping the model with FSDP or DDP
        # to ensure that it accounts for the graph breaks required by those wrappers
        if self.args.torch_compile:
            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)

        return model

1623
1624
    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Main training entry point.

        Args:
            resume_from_checkpoint (`str` or `bool`, *optional*):
                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
            ignore_keys_for_eval (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions for evaluation during the training.
            kwargs:
                Additional keyword arguments used to hide deprecated arguments. Only the deprecated `model_path`
                (an alias of `resume_from_checkpoint`) is accepted.

        Returns:
            Whatever `self._inner_training_loop` returns.

        Raises:
            TypeError: If any keyword argument other than the deprecated `model_path` is passed.
            ValueError: If `resume_from_checkpoint=True` but no checkpoint is found in `args.output_dir`.
        """
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        args = self.args

        self.is_in_train = True

        # do_train is not a reliable argument, as it might not be set and .train() still called, so
        # the following is a workaround:
        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
            self._move_model_to_device(self.model, args.device)

        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
                "instead.",
                FutureWarning,
            )
        if len(kwargs) > 0:
            raise TypeError(f"train() got unexpected keyword arguments: {', '.join(kwargs)}.")

        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)
        self._train_batch_size = self.args.train_batch_size

        # Model re-init
        model_reloaded = False
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            if self.args.full_determinism:
                enable_full_determinism(self.args.seed)
            else:
                set_seed(self.args.seed)
            self.model = self.call_model_init(trial)
            model_reloaded = True
            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Load potential model checkpoint
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

        # Under SageMaker MP or DeepSpeed the checkpoint is not loaded here.
        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and args.deepspeed is None:
            self._load_from_checkpoint(resume_from_checkpoint)

        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
            if self.place_model_on_device:
                self._move_model_to_device(self.model, args.device)
            self.model_wrapped = self.model

        # find_executable_batch_size wraps the inner loop; it receives args.auto_find_batch_size to decide on
        # batch-size search.
        inner_training_loop = find_executable_batch_size(
            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
        )
        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        self._train_batch_size = batch_size
1713
        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
1714
        # Data loader and number of training steps
Julien Chaumond's avatar
Julien Chaumond committed
1715
        train_dataloader = self.get_train_dataloader()
1716
1717
1718
1719
1720

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
1721
        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
1722
1723
1724
1725
1726

        len_dataloader = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
1727
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
1728
            num_examples = self.num_examples(train_dataloader)
1729
1730
1731
1732
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
1733
                )
1734
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
1735
1736
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
1737
            else:
1738
1739
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
1740
1741
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
1742
            max_steps = args.max_steps
1743
1744
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
1745
            num_update_steps_per_epoch = max_steps
1746
            num_examples = total_train_batch_size * args.max_steps
1747
            num_train_samples = args.max_steps * total_train_batch_size
1748
1749
        else:
            raise ValueError(
Sylvain Gugger's avatar
Sylvain Gugger committed
1750
1751
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
1752
            )
Julien Chaumond's avatar
Julien Chaumond committed
1753

1754
1755
1756
1757
1758
1759
1760
1761
        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps and args.logging_steps < 1:
            args.logging_steps = math.ceil(max_steps * args.logging_steps)
        if args.eval_steps and args.eval_steps < 1:
            args.eval_steps = math.ceil(max_steps * args.eval_steps)
        if args.save_steps and args.save_steps < 1:
            args.save_steps = math.ceil(max_steps * args.save_steps)

1762
        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
1763
1764
1765
1766
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
Sylvain Gugger's avatar
Sylvain Gugger committed
1767
1768
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torch.distributed.launch)."
1769
1770
1771
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
1772

1773
        delay_optimizer_creation = (
1774
1775
1776
1777
            self.sharded_ddp is not None
            and self.sharded_ddp != ShardedDDPOption.SIMPLE
            or is_sagemaker_mp_enabled()
            or self.fsdp is not None
1778
        )
1779
        if args.deepspeed:
1780
            deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
1781
1782
                self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
            )
1783
1784
1785
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
1786
1787
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler
1788
        elif not delay_optimizer_creation:
1789
1790
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

1791
        self.state = TrainerState()
1792
        self.state.is_hyper_param_search = trial is not None
Julien Chaumond's avatar
Julien Chaumond committed
1793

1794
1795
1796
1797
        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

1798
        model = self._wrap_model(self.model_wrapped)
Julien Chaumond's avatar
Julien Chaumond committed
1799

1800
1801
1802
        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint, model)

1803
1804
1805
1806
        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

1807
1808
1809
        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

1810
1811
1812
1813
1814
        # prepare using `accelerator` prepare
        model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
            self.model, self.optimizer, self.lr_scheduler
        )

1815
1816
1817
        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

1818
1819
        # important: at this point:
        # self.model         is the Transformers Model
1820
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
1821

Julien Chaumond's avatar
Julien Chaumond committed
1822
1823
        # Train!
        logger.info("***** Running training *****")
1824
1825
        logger.info(f"  Num examples = {num_examples:,}")
        logger.info(f"  Num Epochs = {num_train_epochs:,}")
1826
        logger.info(f"  Instantaneous batch size per device = {self._train_batch_size:,}")
1827
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
1828
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1829
1830
        logger.info(f"  Total optimization steps = {max_steps:,}")
        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
Julien Chaumond's avatar
Julien Chaumond committed
1831

1832
        self.state.epoch = 0
1833
        start_time = time.time()
Julien Chaumond's avatar
Julien Chaumond committed
1834
1835
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
1836
        steps_trained_progress_bar = None
1837

Julien Chaumond's avatar
Julien Chaumond committed
1838
        # Check if continuing training from a checkpoint
1839
        if resume_from_checkpoint is not None and os.path.isfile(
1840
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
1841
        ):
1842
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
1843
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
1844
            if not args.ignore_data_skip:
1845
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
1846
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
1847
1848
            else:
                steps_trained_in_current_epoch = 0
1849
1850

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
1851
1852
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
1853
            if not args.ignore_data_skip:
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
                if skip_first_batches is None:
                    logger.info(
                        f"  Will skip the first {epochs_trained} epochs then the first"
                        f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
                        " you can install the latest version of Accelerate with `pip install -U accelerate`.You can"
                        " also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
                        " training on data already seen by your model."
                    )
                else:
                    logger.info(
                        f"  Will skip the first {epochs_trained} epochs then the first"
                        f" {steps_trained_in_current_epoch} batches in the first epoch."
                    )
                if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
1868
1869
                    steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                    steps_trained_progress_bar.set_description("Skipping the first batches")
1870

Sylvain Gugger's avatar
Sylvain Gugger committed
1871
1872
1873
1874
1875
        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
1876
1877
1878
1879
        if self.hp_name is not None and self._trial is not None:
            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
            # parameter to Train when using DDP.
            self.state.trial_name = self.hp_name(self._trial)
1880
1881
1882
1883
1884
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
1885
1886
1887
1888
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
Sylvain Gugger's avatar
Sylvain Gugger committed
1889
1890
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()
Julien Chaumond's avatar
Julien Chaumond committed
1891

1892
        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
1893
        tr_loss = torch.tensor(0.0).to(args.device)
1894
1895
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
1896
        self._globalstep_last_logged = self.state.global_step
1897
        model.zero_grad()
Sylvain Gugger's avatar
Sylvain Gugger committed
1898

1899
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1900

1901
        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
1902
        if not args.ignore_data_skip:
1903
            for epoch in range(epochs_trained):
1904
1905
1906
                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
                    train_dataloader.sampler, RandomSampler
                )
1907
                if is_torch_less_than_1_11 or not is_random_sampler:
1908
1909
1910
1911
1912
1913
1914
1915
                    # We just need to begin an iteration to create the randomization of the sampler.
                    # That was before PyTorch 1.11 however...
                    for _ in train_dataloader:
                        break
                else:
                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
                    # AT THE VERY END!
                    _ = list(train_dataloader.sampler)
1916

1917
        total_batched_samples = 0
1918
        for epoch in range(epochs_trained, num_train_epochs):
1919
1920
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
1921
            elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
1922
                train_dataloader.dataset.set_epoch(epoch)
1923

1924
            if is_torch_tpu_available():
1925
                parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
1926
                epoch_iterator = parallel_loader
1927
            else:
1928
                epoch_iterator = train_dataloader
1929

1930
            # Reset the past mems state at the beginning of each epoch if necessary.
1931
            if args.past_index >= 0:
1932
1933
                self._past = None

1934
            steps_in_epoch = (
1935
1936
1937
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
1938
            )
1939
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1940

1941
1942
1943
            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

1944
            rng_to_sync = False
1945
            steps_skipped = 0
1946
1947
            if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
1948
                steps_skipped = steps_trained_in_current_epoch
1949
1950
1951
                steps_trained_in_current_epoch = 0
                rng_to_sync = True

1952
            step = -1
Julien Chaumond's avatar
Julien Chaumond committed
1953
            for step, inputs in enumerate(epoch_iterator):
1954
                total_batched_samples += 1
1955
1956
1957
                if rng_to_sync:
                    self._load_rng_state(resume_from_checkpoint)
                    rng_to_sync = False
Julien Chaumond's avatar
Julien Chaumond committed
1958
1959
1960
1961

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
1962
1963
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
1964
1965
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
Julien Chaumond's avatar
Julien Chaumond committed
1966
                    continue
1967
1968
1969
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None
Julien Chaumond's avatar
Julien Chaumond committed
1970

1971
1972
                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1973

1974
                if (
1975
                    (total_batched_samples % args.gradient_accumulation_steps != 0)
1976
                    and args.parallel_mode == ParallelMode.DISTRIBUTED
1977
                    and args._no_sync_in_gradient_accumulation
Wing Lian's avatar
Wing Lian committed
1978
                    and hasattr(model, "no_sync")
1979
                ):
1980
                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
1981
                    with model.no_sync():
1982
                        tr_loss_step = self.training_step(model, inputs)
1983
                else:
1984
1985
                    tr_loss_step = self.training_step(model, inputs)

1986
1987
1988
1989
1990
1991
1992
                if (
                    args.logging_nan_inf_filter
                    and not is_torch_tpu_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
1993
1994
1995
                else:
                    tr_loss += tr_loss_step

1996
                self.current_flos += float(self.floating_point_ops(inputs))
Julien Chaumond's avatar
Julien Chaumond committed
1997

1998
1999
2000
2001
                # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
                if self.deepspeed:
                    self.deepspeed.step()

2002
                if total_batched_samples % args.gradient_accumulation_steps == 0 or (
Julien Chaumond's avatar
Julien Chaumond committed
2003
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
2004
                    steps_in_epoch <= args.gradient_accumulation_steps
2005
                    and (step + 1) == steps_in_epoch
Julien Chaumond's avatar
Julien Chaumond committed
2006
                ):
2007
                    # Gradient clipping
2008
                    if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
2009
2010
                        # deepspeed does its own clipping

2011
                        if self.do_grad_scaling:
2012
2013
2014
2015
                            # Reduce gradients first for XLA
                            if is_torch_tpu_available():
                                gradients = xm._fetch_gradients(self.optimizer)
                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
2016
2017
2018
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)

2019
2020
2021
                        if is_sagemaker_mp_enabled() and args.fp16:
                            self.optimizer.clip_master_grads(args.max_grad_norm)
                        elif hasattr(self.optimizer, "clip_grad_norm"):
2022
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
2023
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
2024
2025
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
2026
                            model.clip_grad_norm_(args.max_grad_norm)
2027
                        elif self.use_apex:
2028
                            # Revert to normal clipping otherwise, handling Apex or full precision
2029
                            nn.utils.clip_grad_norm_(
2030
2031
2032
2033
2034
2035
                                amp.master_params(self.optimizer),
                                args.max_grad_norm,
                            )
                        else:
                            self.accelerator.clip_grad_norm_(
                                model.parameters(),
2036
                                args.max_grad_norm,
2037
2038
2039
                            )

                    # Optimizer step
2040
                    optimizer_was_run = True
Stas Bekman's avatar
Stas Bekman committed
2041
                    if self.deepspeed:
2042
                        pass  # called outside the loop
Stas Bekman's avatar
Stas Bekman committed
2043
                    elif is_torch_tpu_available():
2044
2045
2046
2047
2048
                        if self.do_grad_scaling:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            xm.optimizer_step(self.optimizer)
2049
                    elif self.do_grad_scaling:
2050
                        scale_before = self.scaler.get_scale()
2051
                        self.scaler.step(self.optimizer)
2052
                        self.scaler.update()
2053
2054
                        scale_after = self.scaler.get_scale()
                        optimizer_was_run = scale_before <= scale_after
Lysandre Debut's avatar
Lysandre Debut committed
2055
                    else:
2056
                        self.optimizer.step()
Lysandre Debut's avatar
Lysandre Debut committed
2057

2058
                    if optimizer_was_run and not self.deepspeed:
2059
2060
2061
                        # Delay optimizer scheduling until metrics are generated
                        if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                            self.lr_scheduler.step()
2062

2063
                    model.zero_grad()
2064
                    self.state.global_step += 1
2065
                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
2066
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
2067

2068
                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
wulu473's avatar
wulu473 committed
2069
2070
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
Julien Chaumond's avatar
Julien Chaumond committed
2071

Sylvain Gugger's avatar
Sylvain Gugger committed
2072
                if self.control.should_epoch_stop or self.control.should_training_stop:
Julien Chaumond's avatar
Julien Chaumond committed
2073
                    break
2074
2075
            if step < 0:
                logger.warning(
Sylvain Gugger's avatar
Sylvain Gugger committed
2076
                    "There seems to be not a single sample in your epoch_iterator, stopping training at step"
2077
2078
2079
2080
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True
2081

2082
            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
2083
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
2084

2085
            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
2086
2087
2088
2089
2090
2091
2092
2093
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
Sylvain Gugger's avatar
Sylvain Gugger committed
2094
            if self.control.should_training_stop:
2095
                break
Julien Chaumond's avatar
Julien Chaumond committed
2096

2097
        if args.past_index and hasattr(self, "_past"):
2098
2099
            # Clean the state at the end of training
            delattr(self, "_past")
Julien Chaumond's avatar
Julien Chaumond committed
2100
2101

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
2102
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
2103
2104
2105
            # Wait for everyone to get here so we are sur the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
2106
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
2107
                dist.barrier()
2108
2109
            elif is_sagemaker_mp_enabled():
                smp.barrier()
2110

2111
            self._load_best_model()
2112

2113
2114
2115
2116
        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        train_loss = self._total_loss_scalar / self.state.global_step

2117
        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
2118
2119
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
2120
        metrics["train_loss"] = train_loss
Sylvain Gugger's avatar
Sylvain Gugger committed
2121

2122
        self.is_in_train = False
2123

2124
2125
        self._memory_tracker.stop_and_update_metrics(metrics)

2126
2127
        self.log(metrics)

raghavanone's avatar
raghavanone committed
2128
2129
2130
        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

2131
2132
        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
raghavanone's avatar
raghavanone committed
2133
2134
2135
2136
2137
            for checkpoint in checkpoints_sorted:
                if checkpoint != self.state.best_model_checkpoint:
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint)

2138
2139
2140
        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        return TrainOutput(self.state.global_step, train_loss, metrics)
2141

raghavanone's avatar
raghavanone committed
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
    def _get_output_dir(self, trial):
        if self.hp_search_backend is not None and trial is not None:
            if self.hp_search_backend == HPSearchBackend.OPTUNA:
                run_id = trial.number
            elif self.hp_search_backend == HPSearchBackend.RAY:
                from ray import tune

                run_id = tune.get_trial_id()
            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
                run_id = trial.id
            elif self.hp_search_backend == HPSearchBackend.WANDB:
                import wandb

                run_id = wandb.run.id
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            run_dir = os.path.join(self.args.output_dir, run_name)
        else:
            run_dir = self.args.output_dir
        return run_dir

2162
2163
2164
2165
    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        """
        Load model weights from the checkpoint directory `resume_from_checkpoint` into `model`.

        Args:
            resume_from_checkpoint (`str`):
                Path to a checkpoint directory. It must contain at least one of the `WEIGHTS_NAME`,
                `SAFE_WEIGHTS_NAME`, `WEIGHTS_INDEX_NAME` or `SAFE_WEIGHTS_INDEX_NAME` files.
            model (`nn.Module`, *optional*):
                The model to load the weights into. Defaults to `self.model`.

        Raises:
            `ValueError`: if none of the expected weight files exist in `resume_from_checkpoint`.

        If the checkpoint holds a single-file state dict (`WEIGHTS_NAME` or `SAFE_WEIGHTS_NAME`), that file is
        loaded directly. Otherwise the sharded checkpoint is loaded through `load_sharded_checkpoint`. A warning is
        logged when the checkpoint's config records a Transformers version different from the installed one.
        """
        if model is None:
            model = self.model

        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)

        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
        safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
        safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)

        if not any(
            os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file]
        ):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

        logger.info(f"Loading model from {resume_from_checkpoint}.")

        if os.path.isfile(config_file):
            config = PretrainedConfig.from_json_file(config_file)
            checkpoint_version = config.transformers_version
            if checkpoint_version is not None and checkpoint_version != __version__:
                logger.warning(
                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                    f"Transformers but your current version is {__version__}. This is not recommended and could "
                    "yield to errors or unwanted behaviors."
                )

        if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
            # If the model is on the GPU, it still works!
            if is_sagemaker_mp_enabled():
                if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
                    # If the 'user_content.pt' file exists, load with the new smp api.
                    # Checkpoint must have been saved with the new smp api.
                    smp.resume_from_checkpoint(
                        path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
                    )
                else:
                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                    # Checkpoint must have been saved with the old smp api.
                    # NOTE(review): this path only reads `weights_file`; a checkpoint holding only
                    # `safe_weights_file` would fail here. Confirm whether old-smp checkpoints can be safetensors.
                    if hasattr(self.args, "fp16") and self.args.fp16 is True:
                        logger.warning(
                            "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
                        )
                    state_dict = torch.load(weights_file, map_location="cpu")
                    # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                    state_dict["_smp_is_partial"] = False
                    load_result = model.load_state_dict(state_dict, strict=True)
                    # release memory
                    del state_dict
            else:
                # We load the model state dict on the CPU to avoid an OOM error.
                # Use the safetensors file when the user asked for safetensors and it exists. Also fall back to it
                # when it is the only single-file checkpoint present: `weights_file` does not exist in that case, so
                # `torch.load` would raise FileNotFoundError.
                use_safetensors = os.path.isfile(safe_weights_file) and (
                    self.args.save_safetensors or not os.path.isfile(weights_file)
                )
                if use_safetensors:
                    state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
                else:
                    state_dict = torch.load(weights_file, map_location="cpu")

                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                # which takes *args instead of **kwargs
                load_result = model.load_state_dict(state_dict, False)
                # release memory
                del state_dict
                self._issue_warnings_after_load(load_result)
        else:
            # We load the sharded checkpoint
            load_result = load_sharded_checkpoint(
                model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
            )
            if not is_sagemaker_mp_enabled():
                self._issue_warnings_after_load(load_result)
2232
2233
2234
2235

    def _load_best_model(self):
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
2236
        best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
2237
        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
2238
        if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
2239
            if self.deepspeed:
2240
2241
2242
2243
2244
                if self.model_wrapped is not None:
                    # this removes the pre-hooks from the previous engine
                    self.model_wrapped.destroy()
                    self.model_wrapped = None

2245
                # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
2246
2247
2248
2249
2250
                deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
                    self,
                    num_training_steps=self.args.max_steps,
                    resume_from_checkpoint=self.state.best_model_checkpoint,
                )
2251
2252
2253
2254
2255
2256
                self.model = deepspeed_engine.module
                self.model_wrapped = deepspeed_engine
                self.deepspeed = deepspeed_engine
                self.optimizer = optimizer
                self.lr_scheduler = lr_scheduler
            else:
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
                if is_sagemaker_mp_enabled():
                    if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                        # If the 'user_content.pt' file exists, load with the new smp api.
                        # Checkpoint must have been saved with the new smp api.
                        smp.resume_from_checkpoint(
                            path=self.state.best_model_checkpoint,
                            tag=WEIGHTS_NAME,
                            partial=False,
                            load_optimizer=False,
                        )
                    else:
                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                        # Checkpoint must have been saved with the old smp api.
2270
2271
2272
2273
2274
                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                        else:
                            state_dict = torch.load(best_model_path, map_location="cpu")

2275
2276
2277
                        state_dict["_smp_is_partial"] = False
                        load_result = model.load_state_dict(state_dict, strict=True)
                else:
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
                    if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
                        # If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly.
                        if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                            if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
                                model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
                                # Load_adapter has no return value present, modify it when appropriate.
                                from torch.nn.modules.module import _IncompatibleKeys

                                load_result = _IncompatibleKeys([], [])
                            else:
                                logger.warning(
                                    "The intermediate checkpoints of PEFT may not be saved correctly, "
                                    "using `TrainerCallback` to save adapter_model.bin in corresponding folders, "
                                    "here are some examples https://github.com/huggingface/peft/issues/96"
                                )
                        else:
                            # We can't do pure 8bit training using transformers.
                            logger.warning("Could not loading a quantized checkpoint.")
2296
                    else:
2297
2298
2299
2300
2301
                        # We load the model state dict on the CPU to avoid an OOM error.
                        if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                            state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                        else:
                            state_dict = torch.load(best_model_path, map_location="cpu")
2302

2303
2304
2305
2306
                        # If the model is on the GPU, it still works!
                        # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                        # which takes *args instead of **kwargs
                        load_result = model.load_state_dict(state_dict, False)
2307
                if not is_sagemaker_mp_enabled():
2308
                    self._issue_warnings_after_load(load_result)
2309
        elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
2310
2311
2312
2313
            load_result = load_sharded_checkpoint(
                model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
            )
            if not is_sagemaker_mp_enabled():
2314
                self._issue_warnings_after_load(load_result)
2315
2316
2317
2318
2319
2320
        else:
            logger.warning(
                f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
                "on multiple nodes, you should activate `--save_on_each_node`."
            )

2321
    def _issue_warnings_after_load(self, load_result):
2322
        if len(load_result.missing_keys) != 0:
2323
2324
2325
            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                self.model._keys_to_ignore_on_save
            ):
2326
2327
                self.model.tie_weights()
            else:
2328
                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
2329
        if len(load_result.unexpected_keys) != 0:
2330
2331
2332
            logger.warning(
                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
            )
2333

2334
    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
Sylvain Gugger's avatar
Sylvain Gugger committed
2335
        if self.control.should_log:
2336
2337
2338
            if is_torch_tpu_available():
                xm.mark_step()

Sylvain Gugger's avatar
Sylvain Gugger committed
2339
            logs: Dict[str, float] = {}
2340
2341
2342
2343

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

2344
2345
2346
            # reset tr_loss to zero
            tr_loss -= tr_loss

2347
            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
2348
            logs["learning_rate"] = self._get_learning_rate()
2349

2350
            self._total_loss_scalar += tr_loss_scalar
2351
            self._globalstep_last_logged = self.state.global_step
Teven's avatar
Teven committed
2352
            self.store_flos()
Sylvain Gugger's avatar
Sylvain Gugger committed
2353
2354
2355
2356
2357

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
2358
            if isinstance(self.eval_dataset, dict):
2359
                metrics = {}
2360
                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
2361
                    dataset_metrics = self.evaluate(
2362
2363
2364
2365
                        eval_dataset=eval_dataset,
                        ignore_keys=ignore_keys_for_eval,
                        metric_key_prefix=f"eval_{eval_dataset_name}",
                    )
2366
                    metrics.update(dataset_metrics)
2367
2368
            else:
                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
2369
            self._report_to_hp_search(trial, self.state.global_step, metrics)
2370

2371
2372
2373
2374
            # Run delayed LR scheduler now that metrics are populated
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                self.lr_scheduler.step(metrics[self.args.metric_for_best_model])

Sylvain Gugger's avatar
Sylvain Gugger committed
2375
2376
2377
2378
        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

2379
2380
2381
2382
2383
    def _load_rng_state(self, checkpoint):
        # Load RNG states from `checkpoint`
        if checkpoint is None:
            return

2384
2385
2386
        if self.args.world_size > 1:
            process_index = self.args.process_index
            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
2387
            if not os.path.isfile(rng_file):
2388
                logger.info(
2389
                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
2390
2391
2392
2393
2394
                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
                )
                return
        else:
            rng_file = os.path.join(checkpoint, "rng_state.pth")
2395
            if not os.path.isfile(rng_file):
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
                logger.info(
                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
                    "fashion, reproducibility is not guaranteed."
                )
                return

        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        if torch.cuda.is_available():
2407
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
2408
                torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
2409
            else:
2410
                try:
2411
                    torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
2412
                except Exception as e:
2413
                    logger.info(
2414
2415
2416
                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
2417
2418
2419
        if is_torch_tpu_available():
            xm.set_rng_state(checkpoint_rng_state["xla"])

Sylvain Gugger's avatar
Sylvain Gugger committed
2420
    def _save_checkpoint(self, model, trial, metrics=None):
        """
        Save a full training checkpoint into `<run_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>`.

        The checkpoint contains:
        - the model (via `save_model`, plus `deepspeed.save_checkpoint` when DeepSpeed is used);
        - the optimizer, scheduler and (if used) grad-scaler states;
        - the trainer state JSON;
        - this process's RNG states.

        When `metrics` are given and `args.metric_for_best_model` is set, `state.best_metric` and
        `state.best_model_checkpoint` are updated before the trainer state is written. Finally, the checkpoint is
        optionally pushed to the Hub and older checkpoints are rotated.

        Args:
            model:
                The (possibly wrapped) model being trained. It is not referenced in the body; `self.model` is what
                gets saved.
            trial:
                Hyperparameter-search trial, if any (used to pick the output directory).
            metrics (`Dict[str, float]`, *optional*):
                Evaluation metrics computed at this step.

        Raises:
            KeyError: if `args.metric_for_best_model` is not among the evaluation `metrics`.
        """
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir, _internal_call=True)
        if self.deepspeed:
            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
            # config `stage3_gather_16bit_weights_on_model_save` is True
            self.deepspeed.save_checkpoint(output_dir)

        # Save optimizer and scheduler
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            self.optimizer.consolidate_state_dict()

        if self.fsdp:
            # FSDP has a different interface for saving optimizer states.
            # Needs to be called on all ranks to gather all states.
            # full_optim_state_dict will be deprecated after Pytorch 2.2!
            full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)

        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            # Re-issue outside the recording context (as the other branches do): inside it, the re-issued warnings
            # would just be recorded into `caught_warnings` again and never reach the user.
            reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
            if self.args.should_save:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling:
                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
        elif self.args.should_save and not self.deepspeed:
            # deepspeed.save_checkpoint above saves model/optim/sched
            if self.fsdp:
                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
            else:
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            try:
                metric_value = metrics[metric_to_check]
            except KeyError as exc:
                raise KeyError(
                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not "
                    f"found in the evaluation metrics. The available evaluation metrics are: {list(metrics.keys())}."
                ) from exc

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.args.should_save:
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        # Save this process's RNG states (one file per process when world_size > 1, see below)
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                # `get_rng_state_all` returns the RNG states of all visible CUDA devices
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                # In non-distributed mode, save the RNG state of the current CUDA device
                rng_states["cuda"] = torch.cuda.random.get_rng_state()

        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()

        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)

        if self.args.world_size <= 1:
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

2536
    def _load_optimizer_and_scheduler(self, checkpoint):
Sylvain Gugger's avatar
Sylvain Gugger committed
2537
        """If optimizer and scheduler states exist, load them."""
2538
        if checkpoint is None:
2539
2540
            return

2541
        if self.deepspeed:
2542
            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
2543
2544
            return

2545
2546
2547
2548
2549
2550
        checkpoint_file_exists = (
            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
            if is_sagemaker_mp_enabled()
            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
        )
        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
Sylvain Gugger's avatar
Sylvain Gugger committed
2551
2552
2553
            # Load in optimizer and scheduler states
            if is_torch_tpu_available():
                # On TPU we have to take some extra precautions to properly load the states on the right device.
2554
                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
Sylvain Gugger's avatar
Sylvain Gugger committed
2555
                with warnings.catch_warnings(record=True) as caught_warnings:
2556
                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
Sylvain Gugger's avatar
Sylvain Gugger committed
2557
2558
2559
2560
2561
2562
2563
2564
                reissue_pt_warnings(caught_warnings)

                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

                self.optimizer.load_state_dict(optimizer_state)
                self.lr_scheduler.load_state_dict(lr_scheduler_state)
            else:
2565
                if is_sagemaker_mp_enabled():
2566
2567
2568
2569
                    if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
                        # Optimizer checkpoint was saved with smp >= 1.10
                        def opt_load_hook(mod, opt):
                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
2570

2571
2572
2573
2574
2575
2576
2577
2578
2579
                    else:
                        # Optimizer checkpoint was saved with smp < 1.10
                        def opt_load_hook(mod, opt):
                            if IS_SAGEMAKER_MP_POST_1_10:
                                opt.load_state_dict(
                                    smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
                                )
                            else:
                                opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
2580
2581
2582

                    self.model_wrapped.register_post_step_hook(opt_load_hook)
                else:
2583
2584
2585
2586
                    # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
                    # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
                    # likely to get OOM on CPU (since we load num_gpu times the optimizer state
                    map_location = self.args.device if self.args.world_size > 1 else "cpu"
Qingyang Wu's avatar
Qingyang Wu committed
2587
2588
2589
2590
2591
2592
2593
2594
2595
2596
2597
2598
                    if self.fsdp:
                        full_osd = None
                        # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
                        if self.args.process_index == 0:
                            full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))
                        # call scatter_full_optim_state_dict on all ranks
                        sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)
                        self.optimizer.load_state_dict(sharded_osd)
                    else:
                        self.optimizer.load_state_dict(
                            torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
                        )
Sylvain Gugger's avatar
Sylvain Gugger committed
2599
                with warnings.catch_warnings(record=True) as caught_warnings:
2600
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
Sylvain Gugger's avatar
Sylvain Gugger committed
2601
                reissue_pt_warnings(caught_warnings)
2602
                if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
2603
                    self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
Sylvain Gugger's avatar
Sylvain Gugger committed
2604

2605
2606
2607
2608
2609
2610
2611
    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
        **kwargs,
    ) -> BestRun:
        """
        Launch an hyperparameter search using `optuna`, `Ray Tune`, `SigOpt` or `wandb`. The optimized quantity is
        determined by `compute_objective`, which defaults to a function returning the evaluation loss when no metric
        is provided, the sum of all metrics otherwise.

        <Tip warning={true}>

        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
        optimizer/scheduler.

        </Tip>

        Args:
            hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
                A function that defines the hyperparameter search space. Will default to
                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
                method. Will default to [`~trainer_utils.default_compute_objective`].
            n_trials (`int`, *optional*, defaults to 20):
                The number of trial runs to test.
            direction (`str`, *optional*, defaults to `"minimize"`):
                Whether to optimize greater or lower objects. Can be `"minimize"` or `"maximize"`, you should pick
                `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
            backend (`str` or [`~trainer_utils.HPSearchBackend`], *optional*):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
                on which one is installed. If all are installed, will default to optuna. `wandb` can be selected
                explicitly.
            hp_name (`Callable[["optuna.Trial"], str]`, *optional*):
                A function that defines the trial/run name. Will default to None.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the selected backend's search function (e.g.
                `optuna.create_study` or `ray.tune.run`). For more information see:

                - the documentation of
                  [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
                - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
                - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)

        Returns:
            [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in
            `run_summary` attribute for Ray backend.

        Raises:
            RuntimeError: if no backend is available, the requested backend is not installed, or no `model_init`
                was provided.
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna, ray or sigopt should be installed. "
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`. "
                    "To install sigopt run `pip install sigopt`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_tune_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
            raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
        if backend == HPSearchBackend.WANDB and not is_wandb_available():
            raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")

        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        # Only enter "hp search mode" once every precondition has been validated.
        self.hp_search_backend = backend
        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.hp_name = hp_name
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        backend_dict = {
            HPSearchBackend.OPTUNA: run_hp_search_optuna,
            HPSearchBackend.RAY: run_hp_search_ray,
            HPSearchBackend.SIGOPT: run_hp_search_sigopt,
            HPSearchBackend.WANDB: run_hp_search_wandb,
        }
        try:
            best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
        finally:
            # Leave "hp search mode" even if the search raised, so later training/saving behaves normally.
            self.hp_search_backend = None
        return best_run

Sylvain Gugger's avatar
Sylvain Gugger committed
2701
    def log(self, logs: Dict[str, float]) -> None:
        """
        Record `logs` in the training state and forward them to the registered callbacks.

        When the current epoch is known, `logs` is updated in place with `"epoch"` rounded to two decimals. A copy of
        the logs tagged with the current global step is appended to `self.state.log_history`. Then the callbacks'
        `on_log` hook is invoked, and the control object it returns replaces `self.control`.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        current_epoch = self.state.epoch
        if current_epoch is not None:
            logs["epoch"] = round(current_epoch, 2)

        history_entry = dict(logs)
        history_entry["step"] = self.state.global_step
        self.state.log_history.append(history_entry)

        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
Julien Chaumond's avatar
Julien Chaumond committed
2717

2718
2719
    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
2720
        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
2721
        """
2722
2723
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
2724
2725
2726
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
2727
            kwargs = {"device": self.args.device}
2728
2729
            if self.deepspeed and (torch.is_floating_point(data) or torch.is_complex(data)):
                # NLP models inputs are int/uint and those get adjusted to the right dtype of the
2730
2731
                # embedding. Other models such as wav2vec2's inputs are already float and thus
                # may need special handling to match the dtypes of the model
2732
                kwargs.update({"dtype": self.args.hf_deepspeed_config.dtype()})
2733
2734
2735
            return data.to(**kwargs)
        return data

sgugger's avatar
Fix CI  
sgugger committed
2736
    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
2737
        """
2738
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
2739
2740
        handling potential state.
        """
2741
        inputs = self._prepare_input(inputs)
2742
2743
2744
2745
2746
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )
2747
2748
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past
2749

2750
2751
        return inputs

2752
2753
2754
2755
    def compute_loss_context_manager(self):
        """
        Return the context manager under which the loss is computed.

        This is a helper that groups together the context managers used around loss computation. At present it
        only returns the one built by `autocast_smart_context_manager`.
        """
        loss_ctx = self.autocast_smart_context_manager()
        return loss_ctx
2757

2758
    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
        """
        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
        arguments, depending on the situation.

        Args:
            cache_enabled (`bool`, *optional*, defaults to `True`):
                Forwarded to `autocast` when `is_torch_greater_or_equal_than_1_10` is true.

        Returns:
            - `torch.cpu.amp.autocast` when `self.use_cpu_amp` is set, or `torch.cuda.amp.autocast` when
              `self.use_cuda_amp` is set (with `dtype=self.amp_dtype`);
            - plain `torch.cuda.amp.autocast()` on torch < 1.10;
            - a no-op `contextlib.nullcontext()` when native AMP is disabled.
        """
        if not (self.use_cuda_amp or self.use_cpu_amp):
            # `contextlib.nullcontext` exists on every Python >= 3.7; the former 3.6 `contextlib.suppress()`
            # fallback was dead code.
            return contextlib.nullcontext()

        if not is_torch_greater_or_equal_than_1_10:
            # Older torch: fall back to CUDA autocast without the `dtype` / `cache_enabled` arguments.
            return torch.cuda.amp.autocast()

        if self.use_cpu_amp:
            return torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
        return torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)

2777
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Run the forward and backward passes for one batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The training loss on this batch, detached from the autograd graph.
        """
        model.train()
        batch = self._prepare_inputs(inputs)
        args = self.args

        if is_sagemaker_mp_enabled():
            # `smp_forward_backward` handles this step under SageMaker model parallel (presumably both passes).
            micro_batch_losses = smp_forward_backward(model, batch, args.gradient_accumulation_steps)
            return micro_batch_losses.reduce_mean().detach().to(args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, batch)

        if args.n_gpu > 1:
            # mean() to average on multi-gpu parallel training
            loss = loss.mean()

        # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
        divide_for_accumulation = args.gradient_accumulation_steps > 1 and not self.deepspeed
        if divide_for_accumulation:
            loss = loss / args.gradient_accumulation_steps

        if self.do_grad_scaling:
            scaled = self.scaler.scale(loss)
            scaled.backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as apex_scaled_loss:
                apex_scaled_loss.backward()
        elif self.deepspeed:
            # loss gets scaled under gradient_accumulation_steps in deepspeed
            loss = self.deepspeed.backward(loss)
        else:
            self.accelerator.backward(loss)

        return loss.detach()
Julien Chaumond's avatar
Julien Chaumond committed
2824

2825
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        When `self.label_smoother` is set and `inputs` contains `"labels"`, the labels are popped from `inputs`
        (mutating it) before the forward pass and the loss is computed by the label smoother instead.

        Subclass and override for custom behavior.

        Args:
            model: callable invoked as `model(**inputs)`.
            inputs (`dict`): keyword inputs for the model.
            return_outputs (`bool`, defaults to `False`): also return the raw model outputs.

        Returns:
            The loss, or `(loss, outputs)` when `return_outputs` is true.

        Raises:
            ValueError: if no label smoothing applies and the model returned a dict without a `"loss"` key.
        """
        smooth_labels = self.label_smoother is not None and "labels" in inputs
        labels = inputs.pop("labels") if smooth_labels else None

        outputs = model(**inputs)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is None:
            outputs_is_dict = isinstance(outputs, dict)
            if outputs_is_dict and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if outputs_is_dict else outputs[0]
        elif unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            # Causal LM architectures get their labels shifted by the smoother.
            loss = self.label_smoother(outputs, labels, shift_labels=True)
        else:
            loss = self.label_smoother(outputs, labels)

        if return_outputs:
            return loss, outputs
        return loss
Sylvain Gugger's avatar
Sylvain Gugger committed
2856

2857
2858
    def is_local_process_zero(self) -> bool:
        """
        Return `True` when this process is the main process of its machine, i.e. its local process index is 0.

        When training in a distributed fashion on several machines, each machine has one such process.
        """
        local_index = self.args.local_process_index
        return local_index == 0
Lysandre Debut's avatar
Lysandre Debut committed
2863

2864
2865
    def is_world_process_zero(self) -> bool:
        """
        Return `True` when this process is the global main process. In distributed training this is `True` for
        exactly one process across all machines.
        """
        # Under SageMaker ModelParallel, `args.process_index` is the data-parallel index (dp_process_index), not the
        # global process index, so the global rank is taken from smp instead.
        global_rank = smp.rank() if is_sagemaker_mp_enabled() else self.args.process_index
        return global_rank == 0
Julien Chaumond's avatar
Julien Chaumond committed
2875

Sylvain Gugger's avatar
Sylvain Gugger committed
2876
    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Will save the model, so you can reload it using `from_pretrained()`.

        Will only save from the main process.

        Args:
            output_dir (`str`, *optional*):
                Directory to save into. Defaults to `self.args.output_dir`.
            _internal_call (`bool`, defaults to `False`):
                `False` when called by the user; in that case the model is also pushed to the Hub if
                `args.push_to_hub` is set.
        """

        if output_dir is None:
            output_dir = self.args.output_dir

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif is_sagemaker_mp_enabled():
            # Calling the state_dict needs to be done on the wrapped model and on all processes.
            os.makedirs(output_dir, exist_ok=True)
            state_dict = self.model_wrapped.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
            if IS_SAGEMAKER_MP_POST_1_10:
                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
                Path(os.path.join(output_dir, "user_content.pt")).touch()
        elif (
            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
            or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
            or self.fsdp is not None
        ):
            # state_dict() is called on every process; only the saving process writes it.
            state_dict = self.model.state_dict()

            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
        elif self.deepspeed:
            # this takes care of everything as long as we aren't under zero3
            if self.args.should_save:
                self._save(output_dir)

            if is_deepspeed_zero3_enabled():
                # It's too complicated to try to override different places where the weights dump gets
                # saved, so since under zero3 the file is bogus, simply delete it. The user should
                # either use the deepspeed checkpoint to resume or, to recover full weights, use
                # zero_to_fp32.py stored in the checkpoint.
                if self.args.should_save:
                    # `_save` writes SAFE_WEIGHTS_NAME instead of WEIGHTS_NAME when `args.save_safetensors` is set,
                    # so both names must be removed; otherwise a bogus safetensors file would be left behind.
                    for dummy_weights_name in (WEIGHTS_NAME, SAFE_WEIGHTS_NAME):
                        file = os.path.join(output_dir, dummy_weights_name)
                        if os.path.isfile(file):
                            # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
                            os.remove(file)

                # now save the real model if stage3_gather_16bit_weights_on_model_save=True
                # if false it will not be saved.
                # This must be called on all ranks
                if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME):
                    logger.warning(
                        "deepspeed.save_16bit_model didn't save the model, since"
                        " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                        " zero_to_fp32.py to recover weights"
                    )
                    self.deepspeed.save_checkpoint(output_dir)

        elif self.args.should_save:
            self._save(output_dir)

        # Push to the Hub when `save_model` is called by the user.
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")

2940
2941
    def _save_tpu(self, output_dir: Optional[str] = None):
        """
        Save the model (and the tokenizer, if any) on TPU through `torch_xla`.

        Every ordinal goes through `xm.rendezvous("saving_checkpoint")` and the model is written with `xm.save`.
        The output directory and training arguments are written by the master ordinal only.

        Args:
            output_dir (`str`, *optional*): destination directory, defaults to `self.args.output_dir`.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        logger.info(f"Saving model checkpoint to {output_dir}")

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if isinstance(self.model, PreTrainedModel):
            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
        elif isinstance(unwrap_model(self.model), PreTrainedModel):
            unwrap_model(self.model).save_pretrained(
                output_dir,
                is_main_process=self.args.should_save,
                state_dict=self.model.state_dict(),
                save_function=xm.save,
            )
        else:
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            xm.save(self.model.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))

        if self.tokenizer is not None and self.args.should_save:
            self.tokenizer.save_pretrained(output_dir)
2967

2968
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """
        Write the model, the tokenizer (if any) and the training arguments to `output_dir`.

        If we are executing this function, we are the process zero, so no process check is done here.

        Args:
            output_dir (`str`, *optional*): destination directory, defaults to `self.args.output_dir`.
            state_dict (*optional*): state dict to save instead of the model's own `state_dict()`.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if isinstance(self.model, PreTrainedModel):
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )
        else:
            if state_dict is None:
                state_dict = self.model.state_dict()

            inner_model = unwrap_model(self.model)
            if isinstance(inner_model, PreTrainedModel):
                inner_model.save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
2999

3000
    def store_flos(self):
        """
        Add the floating-point operations counted since the last call (`self.current_flos`) to
        `self.state.total_flos`, then reset `self.current_flos` to 0.

        In `ParallelMode.DISTRIBUTED`, the per-process counters are gathered with `distributed_broadcast_scalars`
        and summed before being added.
        """
        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            gathered = distributed_broadcast_scalars([self.current_flos], device=self.args.device)
            new_flos = gathered.sum().item()
        else:
            new_flos = self.current_flos

        self.state.total_flos += new_flos
        self.current_flos = 0
Julien Chaumond's avatar
Julien Chaumond committed
3010

3011
3012
3013
    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        """
        List the `{checkpoint_prefix}-*` directories of `output_dir`, oldest first.

        Checkpoints are ordered by modification time when `use_mtime` is true, otherwise by the step number parsed
        from the path (`<checkpoint_prefix>-<digits>`); in the latter mode, directories without such a number are
        skipped. If `self.state.best_model_checkpoint` is among them, it is moved to the second-to-last position
        (it stays last if it is already the newest) so that rotation, which deletes from the front, keeps it.

        Args:
            output_dir (`str`, *optional*): directory to scan, defaults to `self.args.output_dir`.
            checkpoint_prefix (`str`): prefix of checkpoint directory names.
            use_mtime (`bool`, defaults to `False`): sort by modification time instead of step number.

        Returns:
            `List[str]`: checkpoint paths, oldest first, with the best checkpoint protected.
        """
        if output_dir is None:
            # Same fallback as `_save`/`_save_tpu`; `Path(None)` would otherwise raise a TypeError.
            output_dir = self.args.output_dir

        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]

        # The prefix is matched literally so regex metacharacters in it cannot alter the pattern.
        step_regex = re.compile(f".*{re.escape(checkpoint_prefix)}-([0-9]+)")
        ordering_and_checkpoint_path = []
        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = step_regex.match(path)
                if regex_match is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.group(1)), path))

        checkpoints_sorted = [path for _, path in sorted(ordering_and_checkpoint_path)]

        # Make sure we don't delete the best model: bubble it to the second-to-last slot so the newest checkpoint
        # stays last (for resuming) and the best one survives rotation.
        best_model_checkpoint = self.state.best_model_checkpoint
        if best_model_checkpoint is not None:
            best_path = str(Path(best_model_checkpoint))
            # The best checkpoint may live in another directory or have been removed; `.index` would raise then.
            if best_path in checkpoints_sorted:
                best_model_index = checkpoints_sorted.index(best_path)
                for i in range(best_model_index, len(checkpoints_sorted) - 2):
                    checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]

        return checkpoints_sorted

3035
    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
        """
        Delete the oldest checkpoints so that at most `args.save_total_limit` remain.

        Does nothing when `save_total_limit` is `None` or non-positive. Deletion starts from the front of the list
        returned by `_sorted_checkpoints`.

        Args:
            use_mtime (`bool`, defaults to `False`): order checkpoints by modification time instead of step number.
            output_dir (`str`, *optional*): directory holding the checkpoints, forwarded to `_sorted_checkpoints`.
        """
        save_total_limit = self.args.save_total_limit
        if save_total_limit is None or save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
        if len(checkpoints_sorted) <= save_total_limit:
            return

        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
        # we avoid in order to allow resuming: keep two checkpoints in that case.
        # The best path is normalized with `str(Path(...))` (as `_sorted_checkpoints` does) because the candidates
        # come from `Path.glob`; otherwise spellings like "./out/checkpoint-5" vs "out/checkpoint-5" never compare
        # equal and an extra checkpoint would be kept despite the limit.
        best_model_checkpoint = self.state.best_model_checkpoint
        if (
            best_model_checkpoint is not None
            and save_total_limit == 1
            and checkpoints_sorted[-1] != str(Path(best_model_checkpoint))
        ):
            save_total_limit = 2

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
        for checkpoint in checkpoints_sorted[:number_of_checkpoints_to_delete]:
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
            shutil.rmtree(checkpoint, ignore_errors=True)
Julien Chaumond's avatar
Julien Chaumond committed
3059

3060
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default)

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        start_time = time.time()

        if self.args.use_legacy_prediction_loop:
            run_loop = self.prediction_loop
        else:
            run_loop = self.evaluation_loop
        output = run_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=None if self.compute_metrics is not None else True,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )
        metrics = output.metrics

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        jit_key = f"{metric_key_prefix}_jit_compilation_time"
        if jit_key in metrics:
            # Shift the start time so JIT compilation is not counted in the speed metrics.
            start_time += metrics[jit_key]
        metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.log(metrics)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)

        self._memory_tracker.stop_and_update_metrics(metrics)

        return metrics

3131
    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the
                `model.forward()` method are automatically removed. Has to implement the method `__len__`
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "test_bleu" if the prefix is "test" (default)

        <Tip>

        If your predictions or labels have different sequence length (for instance because you're doing dynamic padding
        in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
        one array. The padding index is -100.

        </Tip>

        Returns: *NamedTuple* A namedtuple with the following keys:

            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        test_dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()

        if self.args.use_legacy_prediction_loop:
            run_loop = self.prediction_loop
        else:
            run_loop = self.evaluation_loop
        output = run_loop(
            test_dataloader,
            description="Prediction",
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )
        metrics = output.metrics

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        jit_key = f"{metric_key_prefix}_jit_compilation_time"
        if jit_key in metrics:
            # Shift the start time so JIT compilation is not counted in the speed metrics.
            start_time += metrics[jit_key]
        metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, metrics)
        self._memory_tracker.stop_and_update_metrics(metrics)

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=metrics)
Julien Chaumond's avatar
Julien Chaumond committed
3192

3193
    def evaluation_loop(
3194
3195
3196
3197
3198
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
3199
        metric_key_prefix: str = "eval",
3200
    ) -> EvalLoopOutput:
Julien Chaumond's avatar
Julien Chaumond committed
3201
        """
3202
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Julien Chaumond's avatar
Julien Chaumond committed
3203
3204
3205

        Works both with or without labels.
        """
3206
3207
3208
        args = self.args

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
Julien Chaumond's avatar
Julien Chaumond committed
3209

3210
        # if eval is called w/o train init deepspeed here
3211
        if args.deepspeed and not self.deepspeed:
3212
3213
            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
3214
3215
3216
            deepspeed_engine, _, _ = deepspeed_init(
                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
            )
3217
3218
3219
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
3220

3221
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
Julien Chaumond's avatar
Julien Chaumond committed
3222

3223
3224
3225
3226
3227
3228
3229
        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)
3230

3231
        batch_size = self.args.eval_batch_size
3232

3233
        logger.info(f"***** Running {description} *****")
3234
        if has_length(dataloader):
3235
3236
3237
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
3238
        logger.info(f"  Batch size = {batch_size}")
3239

Julien Chaumond's avatar
Julien Chaumond committed
3240
3241
        model.eval()

3242
3243
        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
3244
        eval_dataset = getattr(dataloader, "dataset", None)
3245

3246
        if is_torch_tpu_available():
3247
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
3248

3249
        if args.past_index >= 0:
3250
            self._past = None
3251

3252
3253
3254
3255
3256
        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
3257
3258
        inputs_host = None

3259
3260
3261
3262
        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
3263
        all_inputs = None
3264
3265
3266
3267
        # Will be useful when we have an iterable dataset so don't know its length.

        observed_num_examples = 0
        # Main evaluation loop
3268
        for step, inputs in enumerate(dataloader):
3269
3270
3271
3272
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
3273
3274
3275
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size
3276
3277

            # Prediction step
3278
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
3279
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
3280

3281
3282
3283
            if is_torch_tpu_available():
                xm.mark_step()

3284
            # Update containers on host
3285
            if loss is not None:
3286
                losses = self._nested_gather(loss.repeat(batch_size))
3287
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
3288
            if labels is not None:
3289
                labels = self._pad_across_processes(labels)
3290
3291
3292
3293
3294
3295
3296
3297
            if inputs_decode is not None:
                inputs_decode = self._pad_across_processes(inputs_decode)
                inputs_decode = self._nested_gather(inputs_decode)
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
3298
3299
3300
3301
            if logits is not None:
                logits = self._pad_across_processes(logits)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
3302
                logits = self._nested_gather(logits)
3303
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
3304
3305
3306
            if labels is not None:
                labels = self._nested_gather(labels)
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
3307
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
Julien Chaumond's avatar
Julien Chaumond committed
3308

3309
            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
3310
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
3311
3312
3313
3314
3315
3316
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                if preds_host is not None:
                    logits = nested_numpify(preds_host)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
3317
3318
3319
3320
3321
3322
3323
                if inputs_host is not None:
                    inputs_decode = nested_numpify(inputs_host)
                    all_inputs = (
                        inputs_decode
                        if all_inputs is None
                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                    )
3324
3325
3326
3327
3328
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                    )
3329
3330

                # Set back to None to begin a new accumulation
3331
                losses_host, preds_host, inputs_host, labels_host = None, None, None, None
3332

3333
        if args.past_index and hasattr(self, "_past"):
3334
3335
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")
Julien Chaumond's avatar
Julien Chaumond committed
3336

3337
        # Gather all remaining tensors and put them back on the CPU
3338
3339
3340
3341
3342
3343
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
3344
3345
3346
3347
3348
        if inputs_host is not None:
            inputs_decode = nested_numpify(inputs_host)
            all_inputs = (
                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
            )
3349
3350
3351
3352
3353
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # Number of samples
3354
        if has_length(eval_dataset):
3355
            num_samples = len(eval_dataset)
3356
3357
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
3358
        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
3359
3360
            num_samples = eval_dataset.num_examples
        else:
3361
3362
3363
3364
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples
3365
3366
        if num_samples == 0 and observed_num_examples > 0:
            num_samples = observed_num_examples
3367
3368
3369
3370
3371
3372
3373
3374
3375

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
        # samplers has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
3376
3377
        if all_inputs is not None:
            all_inputs = nested_truncate(all_inputs, num_samples)
3378
3379
3380

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
3381
3382
3383
3384
3385
3386
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
Julien Chaumond's avatar
Julien Chaumond committed
3387
3388
        else:
            metrics = {}
3389

3390
3391
3392
        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

3393
3394
        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
3395
3396
        if hasattr(self, "jit_compilation_time"):
            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
3397

3398
        # Prefix all keys with metric_key_prefix + '_'
3399
        for key in list(metrics.keys()):
3400
3401
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
Julien Chaumond's avatar
Julien Chaumond committed
3402

3403
        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
3404

3405
    def _nested_gather(self, tensors, name=None):
3406
3407
3408
3409
3410
3411
3412
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
3413
3414
            if name is None:
                name = "nested_gather"
3415
            tensors = nested_xla_mesh_reduce(tensors, name)
Sylvain Gugger's avatar
Sylvain Gugger committed
3416
3417
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
Zachary Mueller's avatar
Zachary Mueller committed
3418
3419
3420
        elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
            self.args.distributed_state is None and self.local_rank != -1
        ):
3421
            tensors = distributed_concat(tensors)
3422
        return tensors
3423

3424
3425
3426
3427
3428
3429
3430
3431
3432
3433
3434
3435
3436
3437
3438
3439
3440
3441
3442
3443
3444
3445
    # Copied from Accelerate.
    def _pad_across_processes(self, tensor, pad_index=-100):
        """
        Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
        they can safely be gathered.
        """
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )

        if len(tensor.shape) < 2:
            return tensor
        # Gather all sizes
        size = torch.tensor(tensor.shape, device=tensor.device)[None]
        sizes = self._nested_gather(size).cpu()

        max_size = max(s[1] for s in sizes)
3446
3447
3448
        # When extracting XLA graphs for compilation, max_size is 0,
        # so use inequality to avoid errors.
        if tensor.shape[1] >= max_size:
3449
3450
3451
3452
3453
3454
3455
3456
3457
            return tensor

        # Then pad to the maximum size
        old_size = tensor.shape
        new_size = list(old_size)
        new_size[1] = max_size
        new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
        new_tensor[:, : old_size[1]] = tensor
        return new_tensor
3458

3459
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Run a single evaluation step of `model` on `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions. When omitted, `self.model.config.keys_to_ignore_at_inference` is used if the
                model has a config, otherwise no key is ignored.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional). Logits and labels are `None` when `prediction_loss_only` is set.
        """
        label_names = self.label_names
        has_labels = len(label_names) > 0 and all(inputs.get(k) is not None for k in label_names)

        # CLIP-like models can produce a loss without labels. An explicit `return_loss` in `inputs` wins; when it is
        # missing or `None`, fall back on `self.can_return_loss` (whether `model.forward` defaults it to `True`).
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = len(label_names) == 0 and bool(return_loss)
        compute_loss = has_labels or loss_without_labels

        inputs = self._prepare_inputs(inputs)

        if ignore_keys is None:
            ignore_keys = (
                getattr(self.model.config, "keys_to_ignore_at_inference", [])
                if hasattr(self.model, "config")
                else []
            )

        # Labels may be popped from `inputs` while computing the loss (label smoothing for instance): grab them now.
        if compute_loss:
            labels = nested_detach(tuple(inputs.get(name) for name in label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if compute_loss:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb, logits_mb = raw_outputs[0], raw_outputs[1:]
                    loss = loss_mb.reduce_mean().detach().cpu()
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                logits = smp_nested_concat(logits_mb)
            elif compute_loss:
                with self.compute_loss_context_manager():
                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                loss = loss.mean().detach()
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                else:
                    logits = outputs[1:]
            else:
                loss = None
                with self.compute_loss_context_manager():
                    outputs = model(**inputs)
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                else:
                    logits = outputs
                # TODO: this needs to be fixed and made cleaner later.
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)
3563
3564
3565

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        Count the floating point operations of one forward + backward pass on `inputs`.

        Models inheriting from [`PreTrainedModel`] provide a `floating_point_ops` method, which is delegated to. Any
        other model yields 0; implement such a method on the model, or subclass and override this one, to get a real
        count.

        Args:
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            `int`: The number of floating-point operations.
        """
        if not hasattr(self.model, "floating_point_ops"):
            return 0
        return self.model.floating_point_ops(inputs)
3581

3582
    def init_git_repo(self, at_init: bool = False):
        """
        Initialize a local clone of the Hub repo `self.args.hub_model_id` inside `self.args.output_dir`.

        Only the main process does any work. Without a `hub_model_id`, the repo is named after `self.args.output_dir`;
        a name without a `/` is expanded to its full form with `get_full_repo_name`.

        Args:
            at_init (`bool`, *optional*, defaults to `False`):
                Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
                `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
                out.
        """
        if not self.is_world_process_zero():
            return

        args = self.args
        repo_name = Path(args.output_dir).absolute().name if args.hub_model_id is None else args.hub_model_id
        if "/" not in repo_name:
            repo_name = get_full_repo_name(repo_name, token=args.hub_token)

        # Make sure the remote repo exists (no-op when it already does) before cloning it.
        create_repo(repo_name, token=args.hub_token, private=args.hub_private_repo, exist_ok=True)

        def clone():
            return Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

        try:
            self.repo = clone()
        except EnvironmentError:
            if not (args.overwrite_output_dir and at_init):
                raise
            # Cloning failed but we may overwrite output_dir at init time: wipe it and try once more.
            shutil.rmtree(args.output_dir)
            self.repo = clone()

        self.repo.git_pull()

        # Unless every checkpoint is meant to be pushed, keep the checkpoint folders out of the repo by default.
        gitignore_path = os.path.join(args.output_dir, ".gitignore")
        if args.hub_strategy != HubStrategy.ALL_CHECKPOINTS and not os.path.exists(gitignore_path):
            with open(gitignore_path, "w", encoding="utf-8") as writer:
                writer.writelines(["checkpoint-*/"])

        # Add "*.sagemaker" to .gitignore if using SageMaker
        if os.environ.get("SM_TRAINING_ENV"):
            self._add_sm_patterns_to_gitignore()

        self.push_in_progress = None

Sylvain Gugger's avatar
Sylvain Gugger committed
3629
3630
3631
3632
    def create_model_card(
        self,
        language: Optional[str] = None,
        license: Optional[str] = None,
        tags: Union[str, List[str], None] = None,
        model_name: Optional[str] = None,
        finetuned_from: Optional[str] = None,
        tasks: Union[str, List[str], None] = None,
        dataset_tags: Union[str, List[str], None] = None,
        dataset: Union[str, List[str], None] = None,
        dataset_args: Union[str, List[str], None] = None,
    ):
        r"""
        Creates a draft of a model card using the information available to the `Trainer`.

        The card is written to `README.md` in `self.args.output_dir`. Only the main process writes it.

        Args:
            language (`str`, *optional*):
                The language of the model (if applicable).
            license (`str`, *optional*):
                The license of the model. Will default to the license of the pretrained model used, if the original
                model given to the `Trainer` comes from a repo on the Hub.
            tags (`str` or `List[str]`, *optional*):
                Some tags to be included in the metadata of the model card.
            model_name (`str`, *optional*):
                The name of the model.
            finetuned_from (`str`, *optional*):
                The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
                of the original model given to the `Trainer` (if it comes from the Hub).
            tasks (`str` or `List[str]`, *optional*):
                One or several task identifiers, to be included in the metadata of the model card.
            dataset_tags (`str` or `List[str]`, *optional*):
                One or several dataset tags, to be included in the metadata of the model card.
            dataset (`str` or `List[str]`, *optional*):
                One or several dataset identifiers, to be included in the metadata of the model card.
            dataset_args (`str` or `List[str]`, *optional*):
                One or several dataset arguments, to be included in the metadata of the model card.
        """
        if not self.is_world_process_zero():
            return

        training_summary = TrainingSummary.from_trainer(
            self,
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
        )
        model_card = training_summary.to_model_card()
        # Use an explicit encoding so the result does not depend on the platform default: the card is built from
        # user-supplied strings (model name, tags, ...) that may contain non-ASCII characters.
        with open(os.path.join(self.args.output_dir, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card)

3685
3686
3687
3688
3689
3690
3691
3692
3693
3694
    def _push_from_checkpoint(self, checkpoint_folder):
        # Only push from one node.
        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
            return
        # If we haven't finished the last push, we don't do this one.
        if self.push_in_progress is not None and not self.push_in_progress.is_done:
            return

        output_dir = self.args.output_dir
        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
3695
        modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
3696
3697
3698
3699
3700
3701
3702
3703
3704
3705
3706
3707
3708
3709
3710
3711
3712
3713
3714
3715
3716
3717
3718
        for modeling_file in modeling_files:
            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Same for the training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        try:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Temporarily move the checkpoint just saved for the push
                tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
                # subfolder.
                if os.path.isdir(tmp_checkpoint):
                    shutil.rmtree(tmp_checkpoint)
                shutil.move(checkpoint_folder, tmp_checkpoint)

            if self.args.save_strategy == IntervalStrategy.STEPS:
                commit_message = f"Training in progress, step {self.state.global_step}"
            else:
                commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
3719
3720
3721
3722
            push_work = self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
            # Return type of `Repository.push_to_hub` is either None or a tuple.
            if push_work is not None:
                self.push_in_progress = push_work[1]
3723
3724
        except Exception as e:
            logger.error(f"Error when pushing to hub: {e}")
3725
3726
3727
3728
3729
3730
        finally:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Move back the checkpoint to its place
                shutil.move(tmp_checkpoint, checkpoint_folder)

    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.

        The model is saved on every process (required for TPU training), but only the main process pushes. The model
        card is regenerated and pushed in a separate commit when `self.args.should_save` is set.

        Parameters:
            commit_message (`str`, *optional*, defaults to `"End of training"`):
                Message to commit while pushing.
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the `git push` has finished. A blocking push cancels
                any asynchronous push still in progress; its commits are pushed together with this one.
            kwargs:
                Additional keyword arguments passed along to [`~Trainer.create_model_card`].

        Returns:
            The url of the commit of your model in the given repository if `blocking=True`, a tuple with the url of
            the commit and an object to track the progress of the commit if `blocking=False`. Processes other than the
            main one return `None`.
        """
        # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
        # it might fail.
        if not hasattr(self, "repo"):
            self.init_git_repo()

        model_name = kwargs.pop("model_name", None)
        if model_name is None and self.args.should_save:
            if self.args.hub_model_id is None:
                model_name = Path(self.args.output_dir).name
            else:
                model_name = self.args.hub_model_id.split("/")[-1]

        # Needs to be executed on all processes for TPU training, but will only save on the process determined by
        # self.args.should_save.
        self.save_model(_internal_call=True)

        # Only push from one node.
        if not self.is_world_process_zero():
            return

        # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
            self.push_in_progress._process.kill()
            self.push_in_progress = None

        git_head_commit_url = self.repo.push_to_hub(
            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
        )
        # push separately the model card to be independent from the rest of the model
        if self.args.should_save:
            self.create_model_card(model_name=model_name, **kwargs)
            try:
                self.repo.push_to_hub(
                    commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
                )
            except EnvironmentError as exc:
                logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")

        return git_head_commit_url
Sylvain Gugger's avatar
Sylvain Gugger committed
3785

3786
3787
3788
3789
3790
3791
3792
3793
3794
3795
3796
    #
    # Deprecated code
    #

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Deprecated code path. Works both with or without labels. Per-step results are accumulated on device and handed
        to `DistributedTensorGatherer` objects every `args.eval_accumulation_steps` steps, and once more at the end.

        Args:
            dataloader (`DataLoader`):
                The data to iterate over; it must implement a working `__len__`.
            description (`str`):
                Label used in the log messages.
            prediction_loss_only (`bool`, *optional*):
                Only collect the loss. Defaults to `args.prediction_loss_only`.
            ignore_keys (`List[str]`, *optional*):
                Output keys to skip, forwarded to `prediction_step`.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                Prefix prepended (with an underscore) to every metric key.

        Returns:
            `EvalLoopOutput` holding the predictions, label ids, metrics and the number of examples.

        Raises:
            ValueError: If `dataloader` has no working `__len__`.
        """
        args = self.args

        if not has_length(dataloader):
            raise ValueError("dataloader must implement a working __len__")

        if prediction_loss_only is None:
            prediction_loss_only = args.prediction_loss_only

        # if eval is called w/o train init deepspeed here
        if args.deepspeed and not self.deepspeed:
            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
            self.model = engine.module
            self.model_wrapped = engine
            self.deepspeed = engine
            # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
            # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
            # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
            engine.optimizer.optimizer = None
            engine.lr_scheduler = None

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # For a full fp16/bf16 eval outside of `train`, cast the model to that dtype first and put it on device.
        full_eval_dtype = None
        if not self.is_in_train:
            if args.fp16_full_eval:
                full_eval_dtype = torch.float16
            elif args.bf16_full_eval:
                full_eval_dtype = torch.bfloat16
        if full_eval_dtype is not None:
            model = model.to(dtype=full_eval_dtype, device=args.device)

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info(f"***** Running {description} *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Batch size = {batch_size}")
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
        inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = max(1, args.world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            # With a batch size passed to the sampler, the actual number of eval samples can exceed num_examples in
            # distributed settings.
            sampler = getattr(dataloader, "sampler", None)
            multiple = sampler.batch_size if isinstance(sampler, SequentialDistributedSampler) else None
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=multiple)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=multiple)
            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=multiple)

        def accumulate(host, new):
            # Append a new step's output to the on-device buffer, padding ragged dims with -100.
            return new if host is None else nested_concat(host, new, padding_index=-100)

        def flush(losses_buf, preds_buf, labels_buf, inputs_buf):
            # Gather the on-device buffers across processes and hand them to the CPU-side gatherers.
            eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_buf, "eval_losses"))
            if not prediction_loss_only:
                preds_gatherer.add_arrays(self._gather_and_numpify(preds_buf, "eval_preds"))
                labels_gatherer.add_arrays(self._gather_and_numpify(labels_buf, "eval_label_ids"))
                inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_buf, "eval_inputs_ids"))

        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            if loss is not None:
                repeated = loss.repeat(batch_size)
                losses_host = repeated if losses_host is None else torch.cat((losses_host, repeated), dim=0)
            if logits is not None:
                preds_host = accumulate(preds_host, logits)
            if labels is not None:
                labels_host = accumulate(labels_host, labels)
            if inputs_decode is not None:
                inputs_host = accumulate(inputs_host, inputs_decode)
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Move everything to the CPU once enough accumulation steps have been done, then start a fresh buffer.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                flush(losses_host, preds_host, labels_host, inputs_host)
                losses_host = preds_host = labels_host = inputs_host = None

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        flush(losses_host, preds_host, labels_host, inputs_host)

        eval_loss = eval_losses_gatherer.finalize()
        if prediction_loss_only:
            preds = label_ids = inputs_ids = None
        else:
            preds = preds_gatherer.finalize()
            label_ids = labels_gatherer.finalize()
            inputs_ids = inputs_gatherer.finalize()

        if self.compute_metrics is None or preds is None or label_ids is None:
            metrics = {}
        else:
            eval_pred_kwargs = {"predictions": preds, "label_ids": label_ids}
            if args.include_inputs_for_metrics:
                eval_pred_kwargs["inputs"] = inputs_ids
            metrics = self.compute_metrics(EvalPrediction(**eval_pred_kwargs))

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        prefix = f"{metric_key_prefix}_"
        for key in [k for k in metrics if not k.startswith(prefix)]:
            metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
3935
3936
3937
3938
3939
3940
3941
3942
3943
3944
3945
3946

    def _gather_and_numpify(self, tensors, name):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
3947
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
3948
3949
3950
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)
3951
3952
3953
3954
3955
3956
3957
3958
3959
3960
3961
3962
3963
3964
3965
3966
3967
3968
3969
3970
3971
3972
3973
3974
3975
3976
3977
3978
3979
3980
3981
3982
3983
3984
3985
3986
3987
3988
3989

    def _add_sm_patterns_to_gitignore(self) -> None:
        """Add SageMaker Checkpointing patterns to .gitignore file."""
        # Make sure we only do this on the main process
        if not self.is_world_process_zero():
            return

        patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]

        # Get current .gitignore content
        if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
            with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f:
                current_content = f.read()
        else:
            current_content = ""

        # Add the patterns to .gitignore
        content = current_content
        for pattern in patterns:
            if pattern not in content:
                if content.endswith("\n"):
                    content += pattern
                else:
                    content += f"\n{pattern}"

        # Write the .gitignore file if it has changed
        if content != current_content:
            with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
                logger.debug(f"Writing .gitignore file. Content: {content}")
                f.write(content)

        self.repo.git_add(".gitignore")

        # avoid race condition with git status
        time.sleep(0.5)

        if not self.repo.is_repo_clean():
            self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
            self.repo.git_push()