trainer.py 169 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or fine-tune it on a new task.
"""

19
import contextlib
20
import functools
21
import glob
22
import inspect
23
import math
Julien Chaumond's avatar
Julien Chaumond committed
24
import os
25
import random
Julien Chaumond's avatar
Julien Chaumond committed
26
27
import re
import shutil
28
import sys
29
import time
30
import warnings
31
from collections.abc import Mapping
Julien Chaumond's avatar
Julien Chaumond committed
32
from pathlib import Path
33
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
Julien Chaumond's avatar
Julien Chaumond committed
34

35
36
from tqdm.auto import tqdm

Julien Chaumond's avatar
Julien Chaumond committed
37

38
39
# Integrations must be imported before ML frameworks:
from .integrations import (  # isort: split
40
    default_hp_search_backend,
41
    get_reporting_integration_callbacks,
42
    hp_params,
43
    is_fairscale_available,
44
    is_optuna_available,
45
    is_ray_tune_available,
46
    is_sigopt_available,
47
    is_wandb_available,
48
49
    run_hp_search_optuna,
    run_hp_search_ray,
50
    run_hp_search_sigopt,
51
    run_hp_search_wandb,
52
)
53
54
55

import numpy as np
import torch
Lai Wei's avatar
Lai Wei committed
56
import torch.distributed as dist
57
58
from packaging import version
from torch import nn
59
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
60
61
from torch.utils.data.distributed import DistributedSampler

62
63
from huggingface_hub import Repository

64
65
from . import __version__
from .configuration_utils import PretrainedConfig
66
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
67
from .debug_utils import DebugOption, DebugUnderflowOverflow
68
from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
69
from .dependency_versions_check import dep_version_check
Sylvain Gugger's avatar
Sylvain Gugger committed
70
from .modelcard import TrainingSummary
71
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
72
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
73
from .optimization import Adafactor, get_scheduler
74
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
75
from .tokenization_utils_base import PreTrainedTokenizerBase
Sylvain Gugger's avatar
Sylvain Gugger committed
76
77
78
79
80
81
82
83
84
85
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
86
    DistributedLengthGroupedSampler,
87
    DistributedSamplerWithLoop,
88
    DistributedTensorGatherer,
89
    IterableDatasetShard,
Sylvain Gugger's avatar
Sylvain Gugger committed
90
    LabelSmoother,
91
    LengthGroupedSampler,
Sylvain Gugger's avatar
Sylvain Gugger committed
92
    SequentialDistributedSampler,
93
    ShardSampler,
Sylvain Gugger's avatar
Sylvain Gugger committed
94
95
    distributed_broadcast_scalars,
    distributed_concat,
96
    find_batch_size,
97
    get_module_class_from_name,
98
    get_parameter_names,
Sylvain Gugger's avatar
Sylvain Gugger committed
99
100
101
    nested_concat,
    nested_detach,
    nested_numpify,
102
    nested_truncate,
Sylvain Gugger's avatar
Sylvain Gugger committed
103
104
105
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
106
107
108
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
109
    EvalLoopOutput,
110
    EvalPrediction,
111
    FSDPOption,
112
    HPSearchBackend,
113
114
    HubStrategy,
    IntervalStrategy,
115
    PredictionOutput,
116
    RemoveColumnsCollator,
117
    ShardedDDPOption,
118
    TrainerMemoryTracker,
119
120
121
    TrainOutput,
    default_compute_objective,
    default_hp_space,
122
    denumpify_detensorize,
123
    enable_full_determinism,
124
    find_executable_batch_size,
125
    get_last_checkpoint,
126
    has_length,
127
    number_of_arguments,
128
    seed_worker,
129
    set_seed,
130
    speed_metrics,
131
)
132
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
133
134
from .utils import (
    CONFIG_NAME,
135
    WEIGHTS_INDEX_NAME,
136
    WEIGHTS_NAME,
137
    find_labels,
138
139
140
141
    get_full_repo_name,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
142
    is_ipex_available,
143
144
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
145
    is_torch_tensorrt_fx_available,
146
    is_torch_tpu_available,
147
    is_torchdynamo_available,
148
149
    logging,
)
150
from .utils.generic import ContextManagers
Julien Chaumond's avatar
Julien Chaumond committed
151
152


153
_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10
154

Sylvain Gugger's avatar
Sylvain Gugger committed
155
DEFAULT_CALLBACKS = [DefaultFlowCallback]
156
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
Sylvain Gugger's avatar
Sylvain Gugger committed
157

158
159
160
161
# When `is_in_notebook()` is true, rebind the module-level default progress callback to the
# notebook-specific implementation. The import is kept inside the guard so it only happens in that case.
if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
162

163
164
# Optional dependency: `amp` is only bound in this module when `is_apex_available()` is true,
# so any use of `amp` must be guarded by the same availability check.
if is_apex_available():
    from apex import amp
165

166
167
# Optional dependency: the `datasets` module is only bound here when `is_datasets_available()` is true.
if is_datasets_available():
    import datasets
Julien Chaumond's avatar
Julien Chaumond committed
168

169
if is_torch_tpu_available(check_device=False):
Lysandre Debut's avatar
Lysandre Debut committed
170
171
172
173
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

174
if is_fairscale_available():
175
    dep_version_check("fairscale")
176
    import fairscale
177
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
178
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
179
    from fairscale.nn.wrap import auto_wrap
180
181
182
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

183

Sylvain Gugger's avatar
Sylvain Gugger committed
184
185
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
186
187
188
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
Sylvain Gugger's avatar
Sylvain Gugger committed
189
190

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
191
192
else:
    IS_SAGEMAKER_MP_POST_1_10 = False
Sylvain Gugger's avatar
Sylvain Gugger committed
193

194

195
196
197
# Imported for static type checkers only: `TYPE_CHECKING` is False at runtime, so `optuna`
# is never actually imported here and may only be referenced in (string) annotations.
if TYPE_CHECKING:
    import optuna

Lysandre Debut's avatar
Lysandre Debut committed
198
logger = logging.get_logger(__name__)
Julien Chaumond's avatar
Julien Chaumond committed
199
200


201
202
203
204
205
206
207
208
# Names of the files used for checkpointing.
# Each constant is a bare file name (no directory component).
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"


Julien Chaumond's avatar
Julien Chaumond committed
209
210
class Trainer:
    """
Sylvain Gugger's avatar
Sylvain Gugger committed
211
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
212
213

    Args:
214
215
        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
Sylvain Gugger's avatar
Sylvain Gugger committed
216

217
            <Tip>
Sylvain Gugger's avatar
Sylvain Gugger committed
218

Sylvain Gugger's avatar
Sylvain Gugger committed
219
220
221
            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
            models.
222
223
224
225

            </Tip>

        args ([`TrainingArguments`], *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
226
227
            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
228
        data_collator (`DataCollator`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
229
230
            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
231
232
            [`DataCollatorWithPadding`] otherwise.
        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
233
            The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
234
235
            `model.forward()` method are automatically removed.

Sylvain Gugger's avatar
Sylvain Gugger committed
236
237
238
239
240
            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
            distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
            sets the seed of the RNGs used.
241
        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]], *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
242
             The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
243
244
             `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
             dataset prepending the dictionary key to the metric name.
245
        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
246
247
248
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
            interrupted training or reuse the fine-tuned model.
249
        model_init (`Callable[[], PreTrainedModel]`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
250
251
            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
            from a new instance of the model as given by this function.
252

253
254
255
            The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
            be able to choose different architectures according to hyper parameters (such as layer count, sizes of
            inner layers, dropout probabilities etc).
256
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
257
258
            The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
            a dictionary string to metric values.
259
        callbacks (List of [`TrainerCallback`], *optional*):
Sylvain Gugger's avatar
Sylvain Gugger committed
260
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
261
            detailed in [here](callback).
Sylvain Gugger's avatar
Sylvain Gugger committed
262

263
264
            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
Sylvain Gugger's avatar
Sylvain Gugger committed
265
266
            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
267
268
269
270
271
272
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.
273

274
275
    Important attributes:

Sylvain Gugger's avatar
Sylvain Gugger committed
276
277
        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
          subclass.
278
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
279
          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
Sylvain Gugger's avatar
Sylvain Gugger committed
280
281
          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
282
283
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
284
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
285
286
          to `False` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False` .
Sylvain Gugger's avatar
Sylvain Gugger committed
287
288
        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
          in `train`)
289

Julien Chaumond's avatar
Julien Chaumond committed
290
291
    """

292
    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
293

Julien Chaumond's avatar
Julien Chaumond committed
294
295
    def __init__(
        self,
296
        model: Union[PreTrainedModel, nn.Module] = None,
297
        args: TrainingArguments = None,
Julien Chaumond's avatar
Julien Chaumond committed
298
299
300
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
301
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
302
        model_init: Callable[[], PreTrainedModel] = None,
Julien Chaumond's avatar
Julien Chaumond committed
303
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
Sylvain Gugger's avatar
Sylvain Gugger committed
304
        callbacks: Optional[List[TrainerCallback]] = None,
305
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
306
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
Julien Chaumond's avatar
Julien Chaumond committed
307
    ):
Sylvain Gugger's avatar
Sylvain Gugger committed
308
        if args is None:
309
310
311
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
Sylvain Gugger's avatar
Sylvain Gugger committed
312
313
        self.args = args
        # Seed must be set before instantiating the model when using model
314
        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
315
        self.hp_name = None
316
        self.deepspeed = None
317
        self.is_in_train = False
318

319
320
321
322
        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

323
        # set the correct log level depending on the node
324
        log_level = args.get_process_log_level()
325
326
        logging.set_verbosity(log_level)

327
328
329
        # force device and distributed setup init explicitly
        args._setup_devices

330
331
332
333
334
335
336
337
338
        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
Sylvain Gugger's avatar
Sylvain Gugger committed
339
340
341
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
                    " overwrite your model when calling the `train` method. This will become a fatal error in the next"
                    " release.",
342
343
344
                    FutureWarning,
                )
            self.model_init = model_init
345

346
347
348
349
350
351
352
353
        if model.__class__.__name__ in MODEL_MAPPING_NAMES:
            raise ValueError(
                f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
                "computes hidden states and does not accept any labels. You should choose a model with a head "
                "suitable for your task like any of the `AutoModelForXxx` listed at "
                "https://huggingface.co/docs/transformers/model_doc/auto."
            )

354
355
356
357
358
        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = False

359
360
361
362
363
364
365
        # Setup Sharded DDP training
        self.sharded_ddp = None
        if len(args.sharded_ddp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
366
367
368
369
            if len(args.fsdp) > 0:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
                )
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386

            if args.local_rank == -1:
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
                raise ImportError(
                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
                )
            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.SIMPLE
            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

387
388
389
390
391
392
393
394
395
        self.fsdp = None
        if len(args.fsdp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if args.local_rank == -1:
                raise ValueError("Using fsdp only works in distributed training.")

396
397
398
            # dep_version_check("torch>=1.12.0")
            # Would have to update setup.py with torch>=1.12.0
            # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
399
            # below is the current alternative.
400
            if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
401
                raise ValueError("FSDP requires PyTorch >= 1.12.0")
402
403
404
405
406
407
408

            from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy

            if FSDPOption.FULL_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.FULL_SHARD
            elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
                self.fsdp = ShardingStrategy.SHARD_GRAD_OP
409
410
            elif FSDPOption.NO_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.NO_SHARD
411

412
        # one place to sort out whether to place the model on device or not
413
414
415
416
        # postpone switching model to cuda when:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        #    and we only use deepspeed for training at the moment
417
        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
418
        # 4. Sharded DDP - same as MP
419
        # 5. FSDP - same as MP
420
        self.place_model_on_device = args.place_model_on_device
421
422
        if (
            self.is_model_parallel
423
            or args.deepspeed
424
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
425
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
426
            or (self.fsdp is not None)
427
        ):
428
429
            self.place_model_on_device = False

430
431
        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
Julien Chaumond's avatar
Julien Chaumond committed
432
433
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
434
        self.tokenizer = tokenizer
435

436
        if self.place_model_on_device:
Sylvain Gugger's avatar
Sylvain Gugger committed
437
            self._move_model_to_device(model, args.device)
Stas Bekman's avatar
Stas Bekman committed
438
439
440

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
        if self.is_model_parallel:
441
            self.args._n_gpu = 1
442
443
444
445
446

        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
        self.model_wrapped = model
        self.model = model

Julien Chaumond's avatar
Julien Chaumond committed
447
        self.compute_metrics = compute_metrics
448
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
449
        self.optimizer, self.lr_scheduler = optimizers
Sylvain Gugger's avatar
Sylvain Gugger committed
450
451
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
452
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
Sylvain Gugger's avatar
Sylvain Gugger committed
453
454
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
        if is_torch_tpu_available() and self.optimizer is not None:
            for param in self.model.parameters():
                model_device = param.device
                break
            for param_group in self.optimizer.param_groups:
                if len(param_group["params"]) > 0:
                    optimizer_device = param_group["params"][0].device
                    break
            if model_device != optimizer_device:
                raise ValueError(
                    "The model and the optimizer parameters are not on the same device, which probably means you"
                    " created an optimizer around your model **before** putting on the device and passing it to the"
                    " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                    " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
                )
470
        if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
471
472
473
            self.optimizer is not None or self.lr_scheduler is not None
        ):
            raise RuntimeError(
474
                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
475
476
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
477
478
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
479
480
481
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
482
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
Sylvain Gugger's avatar
Sylvain Gugger committed
483

484
485
486
        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

487
488
        # Create clone of distant repo and output directory if needed
        if self.args.push_to_hub:
489
            self.init_git_repo(at_init=True)
490
491
492
493
494
495
            # In case of pull, we need to make sure every process has the latest.
            if is_torch_tpu_available():
                xm.rendezvous("init git repo")
            elif args.local_rank != -1:
                dist.barrier()

496
        if self.args.should_save:
Julien Chaumond's avatar
Julien Chaumond committed
497
            os.makedirs(self.args.output_dir, exist_ok=True)
498

499
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
Sylvain Gugger's avatar
Sylvain Gugger committed
500
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
501

502
503
504
        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

505
        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
506
507
            raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")

508
509
510
511
512
513
514
        if (
            train_dataset is not None
            and isinstance(train_dataset, torch.utils.data.IterableDataset)
            and args.group_by_length
        ):
            raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")

515
        self._signature_columns = None
516

517
518
        # Mixed precision setup
        self.use_apex = False
519
520
        self.use_cuda_amp = False
        self.use_cpu_amp = False
521

522
523
524
525
526
        # Mixed precision setup for SageMaker Model Parallel
        if is_sagemaker_mp_enabled():
            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
            if args.bf16:
                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543

            if IS_SAGEMAKER_MP_POST_1_10:
                # When there's mismatch between SMP config and trainer argument, use SMP config as truth
                if args.fp16 != smp.state.cfg.fp16:
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
                        f"but FP16 provided in trainer argument is {args.fp16},"
                        f"setting to {smp.state.cfg.fp16}"
                    )
                    args.fp16 = smp.state.cfg.fp16
            else:
                # smp < 1.10 does not support fp16 in trainer.
                if hasattr(smp.state.cfg, "fp16"):
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
                    )
544

545
546
        if args.fp16 or args.bf16:
            if args.half_precision_backend == "auto":
547
548
549
550
551
552
553
                if args.device == torch.device("cpu"):
                    if args.fp16:
                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                    elif _is_native_cpu_amp_available:
                        args.half_precision_backend = "cpu_amp"
                    else:
                        raise ValueError("Tried to use cpu amp but native cpu amp is not available")
554
                else:
555
                    args.half_precision_backend = "cuda_amp"
556

557
            logger.info(f"Using {args.half_precision_backend} half precision backend")
558

559
        self.do_grad_scaling = False
560
561
        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
            # deepspeed and SageMaker Model Parallel manage their own half precision
562
563
            if args.half_precision_backend == "cuda_amp":
                self.use_cuda_amp = True
564
565
                self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
                self.do_grad_scaling = True
566
                if self.sharded_ddp is not None:
567
                    self.scaler = ShardedGradScaler()
568
569
                elif self.fsdp is not None:
                    if self.amp_dtype == torch.float16:
570
571
572
                        from torch.distributed.fsdp.sharded_grad_scaler import (
                            ShardedGradScaler as FSDPShardedGradScaler,
                        )
573

574
                        self.scaler = FSDPShardedGradScaler()
575
576
577
578
579
                    else:
                        self.do_grad_scaling = False
                        self.use_cuda_amp = False
                        self.amp_dtype = None

580
581
582
583
                elif is_torch_tpu_available():
                    from torch_xla.amp import GradScaler

                    self.scaler = GradScaler()
584
585
                else:
                    self.scaler = torch.cuda.amp.GradScaler()
586
587
588
            elif args.half_precision_backend == "cpu_amp":
                self.use_cpu_amp = True
                self.amp_dtype = torch.bfloat16
589
590
591
            else:
                if not is_apex_available():
                    raise ImportError(
Sylvain Gugger's avatar
Sylvain Gugger committed
592
593
                        "Using FP16 with APEX but APEX is not installed, please refer to"
                        " https://www.github.com/nvidia/apex."
594
595
596
                    )
                self.use_apex = True

597
598
599
600
601
602
603
604
605
606
607
608
        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
        if (
            is_sagemaker_mp_enabled()
            and self.use_cuda_amp
            and args.max_grad_norm is not None
            and args.max_grad_norm > 0
        ):
            raise ValueError(
                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
                "along 'max_grad_norm': 0 in your hyperparameters."
            )

Sylvain Gugger's avatar
Sylvain Gugger committed
609
610
611
612
613
614
        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

615
616
617
618
619
        self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
        )

Sylvain Gugger's avatar
Sylvain Gugger committed
620
        self.control = TrainerControl()
621
622
623
        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
624
        self.hp_search_backend = None
625
        self.use_tune_checkpoints = False
626
        default_label_names = find_labels(self.model.__class__)
627
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
Sylvain Gugger's avatar
Sylvain Gugger committed
628
629
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

630
631
632
        # Internal variables to keep track of the original batch size
        self._train_batch_size = args.train_batch_size

633
634
635
        # very last
        self._memory_tracker.stop_and_update_metrics()

636
637
638
639
640
641
642
643
644
645
646
647
        # torchdynamo
        if args.torchdynamo:
            if not is_torchdynamo_available():
                raise RuntimeError("Torchdynamo is not installed.")
            import torchdynamo
            from torchdynamo.optimizations import backends

            def get_ctx():
                # Normal
                if args.torchdynamo == "eager":
                    return torchdynamo.optimize("eager")
                elif args.torchdynamo == "nvfuser":
Yih-Dar's avatar
Yih-Dar committed
648
                    return torchdynamo.optimize("aot_nvfuser")
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
                # TensorRT
                if args.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
                    if not is_torch_tensorrt_fx_available():
                        raise RuntimeError("Torch-TensorRT FX path is not installed.")
                    if args.torchdynamo == "fx2trt-fp16":
                        return torchdynamo.optimize(backends.fx2trt_compiler_fp16)
                    elif args.torchdynamo == "fx2trt":
                        return torchdynamo.optimize(backends.fx2trt_compiler)
                else:
                    raise RuntimeError(f"Torchdynamo backend {args.torchdynamo} is not supported.")

            self.ctx_manager_torchdynamo = get_ctx()
        else:
            self.ctx_manager_torchdynamo = contextlib.nullcontext()

Sylvain Gugger's avatar
Sylvain Gugger committed
664
665
    def add_callback(self, callback):
        """
        Register one more [`~transformers.TrainerCallback`] with this trainer.

        The call is forwarded unchanged to the trainer's callback handler.

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               Either a [`~transformers.TrainerCallback`] class or an instance of one. When a class is given, a
               member of that class is instantiated.
        """
        self.callback_handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Detach a [`~transformers.TrainerCallback`] from this trainer and hand it back to the caller.

        No error is raised when the callback cannot be found; `None` is returned instead.

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               Either a [`~transformers.TrainerCallback`] class or an instance of one. When a class is given, the
               first member of that class found in the list of callbacks is popped.

        Returns:
            [`~transformers.TrainerCallback`]: The callback that was removed, if found.
        """
        return self.callback_handler.pop_callback(callback)

    def remove_callback(self, callback):
        """
        Drop a [`~transformers.TrainerCallback`] from this trainer without returning it.

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               Either a [`~transformers.TrainerCallback`] class or an instance of one. When a class is given, the
               first member of that class found in the list of callbacks is removed.
        """
        self.callback_handler.remove_callback(callback)
Julien Chaumond's avatar
Julien Chaumond committed
701

Sylvain Gugger's avatar
Sylvain Gugger committed
702
703
704
705
706
707
    def _move_model_to_device(self, model, device):
        """Move `model` to `device`, re-tying its weights afterwards when running in TPU parallel mode."""
        moved = model.to(device)
        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
        needs_retie = self.args.parallel_mode == ParallelMode.TPU and hasattr(moved, "tie_weights")
        if needs_retie:
            moved.tie_weights()

708
    def _set_signature_columns_if_needed(self):
709
710
711
712
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
713
714
            # Labels may be named label or label_ids, the default data collator handles that.
            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
715

716
717
718
719
    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return dataset
        self._set_signature_columns_if_needed()
720
        signature_columns = self._signature_columns
721
722

        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
723
        if len(ignored_columns) > 0:
724
            dset_description = "" if description is None else f"in the {description} set"
725
726
727
            logger.info(
                f"The following columns {dset_description} don't have a corresponding argument in "
                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
728
                f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
729
                " you can safely ignore this message."
730
            )
731

732
        columns = [k for k in signature_columns if k in dataset.column_names]
733

734
735
736
737
738
739
740
        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            dataset.set_format(
                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
            )
            return dataset
        else:
            return dataset.remove_columns(ignored_columns)
741

742
743
744
745
746
747
748
    def _get_collator_with_removed_columns(
        self, data_collator: Callable, description: Optional[str] = None
    ) -> Callable:
        """Wrap the data collator in a callable removing unused columns."""
        if not self.args.remove_unused_columns:
            return data_collator
        self._set_signature_columns_if_needed()
749
        signature_columns = self._signature_columns
750
751
752
753
754
755
756
757
758
759

        remove_columns_collator = RemoveColumnsCollator(
            data_collator=data_collator,
            signature_columns=signature_columns,
            logger=logger,
            description=description,
            model_name=self.model.__class__.__name__,
        )
        return remove_columns_collator

760
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
761
        if self.train_dataset is None or not has_length(self.train_dataset):
762
            return None
763

764
        generator = None
765
        if self.args.world_size <= 1:
766
            generator = torch.Generator()
767
768
769
770
771
772
773
774
775
776
            # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
            # `args.seed`) if data_seed isn't provided.
            # Further on in this method, we default to `args.seed` instead.
            if self.args.data_seed is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
            else:
                seed = self.args.data_seed
            generator.manual_seed(seed)

        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
777

778
779
        # Build the sampler.
        if self.args.group_by_length:
780
781
782
783
784
785
786
787
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
788
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
789
            if self.args.world_size <= 1:
790
                return LengthGroupedSampler(
791
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
792
                    dataset=self.train_dataset,
793
794
795
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
796
                )
797
798
            else:
                return DistributedLengthGroupedSampler(
799
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
800
                    dataset=self.train_dataset,
801
802
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
803
                    lengths=lengths,
804
                    model_input_name=model_input_name,
805
                    seed=seed,
806
807
808
                )

        else:
809
            if self.args.world_size <= 1:
810
                return RandomSampler(self.train_dataset, generator=generator)
Sylvain Gugger's avatar
Sylvain Gugger committed
811
812
813
814
            elif (
                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
            ):
815
816
817
818
819
820
                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    batch_size=self.args.per_device_train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
821
                    seed=seed,
822
                )
823
            else:
824
                return DistributedSampler(
825
826
827
                    self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
828
                    seed=seed,
829
                )
830
831
832

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Raises:
            ValueError: If the trainer has no `train_dataset`.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        # Unused columns are dropped on the dataset itself when it is a `datasets.Dataset`, and inside the collator
        # otherwise.
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        if isinstance(train_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                train_dataset = IterableDatasetShard(
                    train_dataset,
                    batch_size=self._train_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )

            return DataLoader(
                train_dataset,
                # Must be the same batch size the shard wrapper above and the map-style path below are built with,
                # otherwise the loader and the shard disagree on what a batch is.
                batch_size=self._train_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        train_sampler = self._get_train_sampler()

        return DataLoader(
            train_dataset,
            batch_size=self._train_batch_size,
            sampler=train_sampler,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            worker_init_fn=seed_worker,
        )

881
    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
        # Deprecated code
        if self.args.use_legacy_prediction_loop:
            if is_torch_tpu_available():
                return SequentialDistributedSampler(
                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
                )
            elif is_sagemaker_mp_enabled():
                return SequentialDistributedSampler(
                    eval_dataset,
                    num_replicas=smp.dp_size(),
                    rank=smp.dp_rank(),
                    batch_size=self.args.per_device_eval_batch_size,
                )
            elif self.args.local_rank != -1:
                return SequentialDistributedSampler(eval_dataset)
            else:
                return SequentialSampler(eval_dataset)

        if self.args.world_size <= 1:
            return SequentialSampler(eval_dataset)
        else:
            return ShardSampler(
Sylvain Gugger's avatar
Sylvain Gugger committed
904
905
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
906
907
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
Sylvain Gugger's avatar
Sylvain Gugger committed
908
            )
Lysandre Debut's avatar
Lysandre Debut committed
909

Julien Chaumond's avatar
Julien Chaumond committed
910
    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.

        Raises:
            ValueError: If neither `eval_dataset` nor `self.eval_dataset` is set.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.data_collator

        # Unused columns are dropped on the dataset itself when it is a `datasets.Dataset`, and inside the collator
        # otherwise.
        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                eval_dataset = IterableDatasetShard(
                    eval_dataset,
                    # Same batch size as the DataLoader below (and as `get_test_dataloader`), so the shard and the
                    # loader agree on what a batch is.
                    batch_size=self.args.eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                eval_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`):
                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        collate_fn = self.data_collator

        uses_hf_dataset = is_datasets_available() and isinstance(test_dataset, datasets.Dataset)
        if uses_hf_dataset:
            test_dataset = self._remove_unused_columns(test_dataset, description="test")
        else:
            collate_fn = self._get_collator_with_removed_columns(collate_fn, description="test")

        args = self.args

        if not isinstance(test_dataset, torch.utils.data.IterableDataset):
            # We use the same batch_size as for eval.
            return DataLoader(
                test_dataset,
                sampler=self._get_eval_sampler(test_dataset),
                batch_size=args.eval_batch_size,
                collate_fn=collate_fn,
                drop_last=args.dataloader_drop_last,
                num_workers=args.dataloader_num_workers,
                pin_memory=args.dataloader_pin_memory,
            )

        # Iterable datasets take no sampler; each process reads its own shard instead.
        if args.world_size > 1:
            test_dataset = IterableDatasetShard(
                test_dataset,
                batch_size=args.eval_batch_size,
                drop_last=args.dataloader_drop_last,
                num_processes=args.world_size,
                process_index=args.process_index,
            )
        return DataLoader(
            test_dataset,
            batch_size=args.eval_batch_size,
            collate_fn=collate_fn,
            num_workers=args.dataloader_num_workers,
            pin_memory=args.dataloader_pin_memory,
        )
Lysandre Debut's avatar
Lysandre Debut committed
1007

1008
    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
        `create_scheduler`) in a subclass.

        Args:
            num_training_steps (int): The number of training steps, forwarded to `create_scheduler`.
        """
        self.create_optimizer()
        # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
        needs_unwrap = IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16
        scheduler_optimizer = self.optimizer.optimizer if needs_unwrap else self.optimizer
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=scheduler_optimizer)
1023
1024
1025
1026
1027

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.

        Returns:
            The optimizer stored on `self.optimizer` (wrapped in `smp.DistributedOptimizer` when SageMaker model
            parallelism is enabled).
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            # A set keeps the per-parameter membership tests below constant-time instead of scanning a list for
            # every named parameter.
            decay_parameters = {
                name for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) if "bias" not in name
            }
            # Two groups: parameters that receive weight decay, and the rest (weight decay forced to 0.0).
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in opt_model.named_parameters() if n in decay_parameters],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in opt_model.named_parameters() if n not in decay_parameters],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    # Embedding weights are registered with a 32-bit optimizer-state override.
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer

1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
    @staticmethod
    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
        """
        Returns the optimizer class and optimizer parameters based on the training arguments.

        Args:
            args (`transformers.training_args.TrainingArguments`):
                The training arguments for the training session.

        Returns:
            A `(optimizer_cls, optimizer_kwargs)` tuple. The kwargs always carry `lr`; Adam-family optimizers also
            get `betas` and `eps`.

        Raises:
            ValueError: If `args.optim` is not supported, or if the package providing the requested optimizer cannot
                be imported.
        """
        optim = args.optim
        optimizer_kwargs = {"lr": args.learning_rate}
        adam_kwargs = {
            "betas": (args.adam_beta1, args.adam_beta2),
            "eps": args.adam_epsilon,
        }

        if optim == OptimizerNames.ADAFACTOR:
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif optim == OptimizerNames.ADAMW_HF:
            from .optimization import AdamW as optimizer_cls

            optimizer_kwargs.update(adam_kwargs)
        elif optim == OptimizerNames.ADAMW_TORCH:
            from torch.optim import AdamW as optimizer_cls

            optimizer_kwargs.update(adam_kwargs)
        elif optim == OptimizerNames.ADAMW_TORCH_XLA:
            try:
                from torch_xla.amp.syncfree import AdamW as optimizer_cls
            except ImportError:
                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
            optimizer_kwargs.update(adam_kwargs)
        elif optim == OptimizerNames.ADAMW_APEX_FUSED:
            try:
                from apex.optimizers import FusedAdam as optimizer_cls
            except ImportError:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
            optimizer_kwargs.update(adam_kwargs)
        elif optim == OptimizerNames.ADAMW_BNB:
            try:
                from bitsandbytes.optim import Adam8bit as optimizer_cls
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
            optimizer_kwargs.update(adam_kwargs)
        elif optim == OptimizerNames.SGD:
            optimizer_cls = torch.optim.SGD
        elif optim == OptimizerNames.ADAGRAD:
            optimizer_cls = torch.optim.Adagrad
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")

        return optimizer_cls, optimizer_kwargs

1132
    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        An already existing `self.lr_scheduler` is kept and returned as is.

        Args:
            num_training_steps (int): The number of training steps to do.
            optimizer (`torch.optim.Optimizer`, *optional*):
                The optimizer to schedule. Falls back to `self.optimizer` when not given.

        Returns:
            The learning rate scheduler stored on `self.lr_scheduler`.
        """
        if self.lr_scheduler is not None:
            return self.lr_scheduler

        scheduled_optimizer = optimizer if optimizer is not None else self.optimizer
        self.lr_scheduler = get_scheduler(
            self.args.lr_scheduler_type,
            optimizer=scheduled_optimizer,
            num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
            num_training_steps=num_training_steps,
        )
        return self.lr_scheduler
Julien Chaumond's avatar
Julien Chaumond committed
1148

1149
    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
        dataloader.dataset does not exist or has no length, estimates as best it can
        (number of batches times `args.per_device_train_batch_size`).
        """
        try:
            dataset = dataloader.dataset
            # Special case for IterableDatasetShard, we need to dig deeper
            sized = dataset.dataset if isinstance(dataset, IterableDatasetShard) else dataset
            return len(sized)
        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
            return len(dataloader) * self.args.per_device_train_batch_size
1162

1163
    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """
        Apply the hyperparameters of the current search trial to `self.args`.

        The trial is always stored on `self._trial`. Nothing else happens when no search backend is configured or
        when `trial` is `None`.

        Args:
            trial (`optuna.Trial` or `Dict[str, Any]`):
                The trial object (Optuna, SigOpt) or the hyperparameter dictionary (Ray Tune, W&B).
        """
        self._trial = trial

        backend = self.hp_search_backend
        if backend is None or trial is None:
            return

        # Turn the backend-specific trial into a flat {argument name: value} mapping.
        if backend == HPSearchBackend.OPTUNA:
            params = self.hp_space(trial)
        elif backend == HPSearchBackend.RAY:
            params = trial
            # Drops the "wandb" entry from the trial dict itself (in place).
            params.pop("wandb", None)
        elif backend == HPSearchBackend.SIGOPT:
            # String assignments are converted to int, everything else is kept untouched.
            params = {
                name: (int(assigned) if isinstance(assigned, str) else assigned)
                for name, assigned in trial.assignments.items()
            }
        elif backend == HPSearchBackend.WANDB:
            params = trial

        training_args = self.args
        for key, value in params.items():
            if hasattr(training_args, key):
                current = getattr(training_args, key, None)
                # Cast the sampled value to the type of the value currently stored, when there is one.
                setattr(training_args, key, value if current is None else type(current)(value))
            else:
                logger.warning(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
                    " `TrainingArguments`."
                )

        if backend == HPSearchBackend.OPTUNA:
            logger.info(f"Trial: {trial.params}")
        elif backend == HPSearchBackend.SIGOPT:
            logger.info(f"SigOpt Assignments: {trial.assignments}")
        elif backend == HPSearchBackend.WANDB:
            logger.info(f"W&B Sweep parameters: {trial}")

        if self.args.deepspeed:
            # Rebuild the deepspeed config to reflect the updated training parameters
            from transformers.deepspeed import HfTrainerDeepSpeedConfig

            deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
            self.args.hf_deepspeed_config = deepspeed_config
            self.args.hf_deepspeed_config.trainer_config_process(self.args)
1203

1204
    def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
        """
        Report the objective computed from `metrics` to the active hyperparameter search backend.

        Does nothing when no backend is configured or `trial` is `None`. Otherwise `self.objective` is refreshed
        from a copy of `metrics` before anything is reported.

        Args:
            trial (`optuna.Trial` or `Dict[str, Any]`): The current trial.
            step (`int`): The step passed to Optuna's `trial.report`.
            metrics (`Dict[str, float]`): The metrics the objective is computed from.

        Raises:
            `optuna.TrialPruned`: With the Optuna backend, when the trial reports it should be pruned.
        """
        backend = self.hp_search_backend
        if backend is None or trial is None:
            return

        self.objective = self.compute_objective(metrics.copy())

        if backend == HPSearchBackend.OPTUNA:
            import optuna

            trial.report(self.objective, step)
            if not trial.should_prune():
                return
            # Let callbacks see the end of training before the trial is aborted.
            self.callback_handler.on_train_end(self.args, self.state, self.control)
            raise optuna.TrialPruned()

        if backend == HPSearchBackend.RAY:
            from ray import tune

            if self.control.should_save:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

1222
    def _tune_save_checkpoint(self):
        """
        Save a checkpoint (model, trainer state, optimizer and scheduler) inside a Ray Tune checkpoint directory.

        Returns immediately when `self.use_tune_checkpoints` is falsy. The trainer state, optimizer state and
        scheduler state are only written when `self.args.should_save` is true.
        """
        from ray import tune

        if not self.use_tune_checkpoints:
            return

        global_step = self.state.global_step
        with tune.checkpoint_dir(step=global_step) as checkpoint_dir:
            target_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
            self.save_model(target_dir, _internal_call=True)
            if not self.args.should_save:
                return
            self.state.save_to_json(os.path.join(target_dir, TRAINER_STATE_NAME))
            for component, file_name in ((self.optimizer, OPTIMIZER_NAME), (self.lr_scheduler, SCHEDULER_NAME)):
                torch.save(component.state_dict(), os.path.join(target_dir, file_name))
1234

1235
    def call_model_init(self, trial=None):
        """
        Instantiate a model through `self.model_init`.

        `self.model_init` is called without arguments when it declares none, and with `trial` when it declares
        exactly one.

        Args:
            trial: Forwarded to `self.model_init` when it takes one argument.

        Returns:
            The model returned by `self.model_init`.

        Raises:
            `RuntimeError`: If `self.model_init` takes more than one argument or returns `None`.
        """
        argcount = number_of_arguments(self.model_init)
        if argcount not in (0, 1):
            raise RuntimeError("model_init should have 0 or 1 argument.")

        model = self.model_init() if argcount == 0 else self.model_init(trial)
        if model is None:
            raise RuntimeError("model_init should not return None.")
        return model

1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
    def torch_jit_model_eval(self, model, dataloader, training=False):
        """
        Try to replace `model` by a traced and frozen TorchScript version for evaluation.

        Args:
            model: The model to trace.
            dataloader: Dataloader whose first batch provides the shapes of the example inputs.
            training (`bool`, defaults to `False`): When `True`, `model` is returned untouched.

        Returns:
            The traced and frozen model on success; otherwise the model that was passed in (when `training` is
            true, when `dataloader` is `None`, or when tracing fails with a `RuntimeError`/`TypeError`).
        """
        if training:
            return model

        if dataloader is None:
            logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
            return model

        example_batch = next(iter(dataloader))
        # Trace with all-ones tensors shaped like each entry of the first batch, in batch key order.
        jit_inputs = tuple(torch.ones_like(example_batch[key]) for key in example_batch)

        try:
            traced = model.eval()
            with ContextManagers([self.autocast_smart_context_manager(), torch.no_grad()]):
                traced = torch.jit.trace(traced, jit_inputs, strict=False)
            traced = torch.jit.freeze(traced)
            # Run the real batch once; a failure here also falls back to the original model.
            traced(**example_batch)
        except (RuntimeError, TypeError) as e:
            logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
            return model

        return traced

1272
1273
1274
    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
        """
        Optimize `model` with Intel Extension for PyTorch (`ipex.optimize`, level "O1").

        Args:
            model: The model to optimize.
            training (`bool`, defaults to `False`):
                When `False`, the model is put in eval mode and optimized alone. When `True`, the model is put in
                train mode (if it is not already) and optimized in place together with `self.optimizer`, which is
                replaced by the optimizer returned by IPEX.
            dtype (`torch.dtype`, defaults to `torch.float32`): The dtype handed to `ipex.optimize`.

        Returns:
            The optimized model.

        Raises:
            `ImportError`: If `is_ipex_available()` is false.
        """
        if not is_ipex_available():
            raise ImportError(
                "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
                " to https://github.com/intel/intel-extension-for-pytorch."
            )

        import intel_extension_for_pytorch as ipex

        if training:
            if not model.training:
                model.train()
            optimized = ipex.optimize(model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1")
            model, self.optimizer = optimized
            return model

        model.eval()
        return ipex.optimize(model, dtype=dtype, level="O1")

1293
    def _wrap_model(self, model, training=True, dataloader=None):
        """
        Wrap `model` with the optimizations and parallelism wrappers selected by the trainer configuration.

        The steps are applied in this order: IPEX optimization, TorchScript tracing for evaluation, SageMaker
        model parallelism, DeepSpeed, apex AMP, `nn.DataParallel`, then (training only) one of fairscale sharded
        DDP, PyTorch FSDP, SageMaker data parallelism or `DistributedDataParallel`.

        Args:
            model: The model to wrap.
            training (`bool`, defaults to `True`):
                Whether the model is wrapped for training. For evaluation no distributed wrapper is applied.
            dataloader (*optional*):
                Forwarded to `torch_jit_model_eval` when `args.jit_mode_eval` is set.

        Returns:
            The wrapped model, or the model itself when no wrapping applies or it is already wrapped.

        Raises:
            `ValueError`: If FSDP auto-wrapping by transformer layer class is requested but the class cannot be
            found in the model.
        """
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if self.args.jit_mode_eval:
            model = self.torch_jit_model_eval(model, dataloader, training)

        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # already initialized its own DDP and AMP
        if self.deepspeed:
            return self.deepspeed

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = nn.DataParallel(model)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)
        # Distributed training using PyTorch FSDP
        elif self.fsdp is not None:
            # PyTorch FSDP!
            from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
            from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
            from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
            from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

            offload_requested = FSDPOption.OFFLOAD in self.args.fsdp
            cpu_offload = CPUOffload(offload_params=offload_requested)

            auto_wrap_policy = None
            if FSDPOption.AUTO_WRAP in self.args.fsdp:
                # A positive size threshold takes precedence over wrapping by layer class.
                if self.args.fsdp_min_num_params > 0:
                    auto_wrap_policy = functools.partial(
                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
                    )
                elif self.args.fsdp_transformer_layer_cls_to_wrap is not None:
                    transformer_cls_to_wrap = get_module_class_from_name(
                        model, self.args.fsdp_transformer_layer_cls_to_wrap
                    )
                    if transformer_cls_to_wrap is None:
                        raise ValueError("Could not find the transformer layer class to wrap in the model.")
                    auto_wrap_policy = functools.partial(
                        transformer_auto_wrap_policy,
                        # Transformer layer class to wrap
                        transformer_layer_cls={transformer_cls_to_wrap},
                    )

            mixed_precision_policy = None
            dtype = None
            if self.args.fp16:
                dtype = torch.float16
            elif self.args.bf16:
                dtype = torch.bfloat16
            if dtype is not None:
                mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)

            # isinstance (rather than an exact type comparison) also skips re-wrapping FSDP subclasses.
            if not isinstance(model, FSDP):
                # XXX: Breaking the self.model convention but I see no way around it for now.
                self.model = model = FSDP(
                    model,
                    sharding_strategy=self.fsdp,
                    cpu_offload=cpu_offload,
                    auto_wrap_policy=auto_wrap_policy,
                    mixed_precision=mixed_precision_policy,
                )
                if not offload_requested:
                    model.to(self.args.device)
        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.local_rank != -1:
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
            model = nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
                **kwargs,
            )

        return model

1421
1422
    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Main training entry point.

        Args:
            resume_from_checkpoint (`str` or `bool`, *optional*):
                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
            ignore_keys_for_eval (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions for evaluation during the training.
            kwargs:
                Additional keyword arguments used to hide deprecated arguments. Only the deprecated `model_path`
                is accepted.

        Returns:
            Whatever the inner training loop returns.

        Raises:
            `TypeError`: If an unexpected keyword argument is passed.
            `ValueError`: If `resume_from_checkpoint` is `True` and no checkpoint is found in `args.output_dir`.
        """
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        args = self.args

        self.is_in_train = True

        # do_train is not a reliable argument, as it might not be set and .train() still called, so
        # the following is a workaround:
        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
            self._move_model_to_device(self.model, args.device)

        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
                "instead.",
                FutureWarning,
            )
        if kwargs:
            raise TypeError(f"train() got unexpected keyword arguments: {', '.join(kwargs)}.")

        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)
        self._train_batch_size = self.args.train_batch_size

        # Model re-init
        model_reloaded = False
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            if self.args.full_determinism:
                enable_full_determinism(self.args.seed)
            else:
                set_seed(self.args.seed)
            self.model = self.call_model_init(trial)
            model_reloaded = True
            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Load potential model checkpoint
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

        # With SageMaker model parallelism the checkpoint is not loaded here.
        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled():
            self._load_from_checkpoint(resume_from_checkpoint)

        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
            if self.place_model_on_device:
                self._move_model_to_device(self.model, args.device)
            self.model_wrapped = self.model

        inner_training_loop = find_executable_batch_size(
            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
        )
        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        self._train_batch_size = batch_size
1511
        # Data loader and number of training steps
Julien Chaumond's avatar
Julien Chaumond committed
1512
        train_dataloader = self.get_train_dataloader()
1513
1514
1515
1516
1517

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
1518
        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
1519
1520
1521
1522
1523

        len_dataloader = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
1524
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
1525
            num_examples = self.num_examples(train_dataloader)
1526
1527
1528
1529
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
1530
                )
1531
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
1532
1533
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
1534
            else:
1535
1536
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
1537
1538
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
1539
            max_steps = args.max_steps
1540
1541
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
1542
            num_update_steps_per_epoch = max_steps
1543
            num_examples = total_train_batch_size * args.max_steps
1544
            num_train_samples = args.max_steps * total_train_batch_size
1545
1546
        else:
            raise ValueError(
Sylvain Gugger's avatar
Sylvain Gugger committed
1547
1548
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
1549
            )
Julien Chaumond's avatar
Julien Chaumond committed
1550

1551
        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
1552
1553
1554
1555
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
Sylvain Gugger's avatar
Sylvain Gugger committed
1556
1557
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torch.distributed.launch)."
1558
1559
1560
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
1561

1562
        delay_optimizer_creation = (
1563
1564
1565
1566
            self.sharded_ddp is not None
            and self.sharded_ddp != ShardedDDPOption.SIMPLE
            or is_sagemaker_mp_enabled()
            or self.fsdp is not None
1567
        )
1568
        if args.deepspeed:
1569
            deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
1570
1571
                self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
            )
1572
1573
1574
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
1575
1576
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler
1577
        elif not delay_optimizer_creation:
1578
1579
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

1580
        self.state = TrainerState()
1581
        self.state.is_hyper_param_search = trial is not None
Julien Chaumond's avatar
Julien Chaumond committed
1582

1583
1584
1585
1586
        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

1587
        model = self._wrap_model(self.model_wrapped)
Julien Chaumond's avatar
Julien Chaumond committed
1588

1589
1590
1591
        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint, model)

1592
1593
1594
1595
        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

1596
1597
1598
        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

1599
1600
1601
        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

1602
1603
        # important: at this point:
        # self.model         is the Transformers Model
1604
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
1605

Julien Chaumond's avatar
Julien Chaumond committed
1606
1607
        # Train!
        logger.info("***** Running training *****")
1608
1609
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
1610
        logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
1611
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
1612
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1613
        logger.info(f"  Total optimization steps = {max_steps}")
1614
1615
1616
        logger.info(
            f"  Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
        )
Julien Chaumond's avatar
Julien Chaumond committed
1617

1618
        self.state.epoch = 0
1619
        start_time = time.time()
Julien Chaumond's avatar
Julien Chaumond committed
1620
1621
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
1622
        steps_trained_progress_bar = None
1623

Julien Chaumond's avatar
Julien Chaumond committed
1624
        # Check if continuing training from a checkpoint
1625
        if resume_from_checkpoint is not None and os.path.isfile(
1626
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
1627
        ):
1628
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
1629
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
1630
            if not args.ignore_data_skip:
1631
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
1632
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
1633
1634
            else:
                steps_trained_in_current_epoch = 0
1635
1636

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
1637
1638
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
1639
            if not args.ignore_data_skip:
1640
1641
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
1642
1643
                    "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
                    "flag to your launch command, but you will resume the training on data already seen by your model."
1644
                )
1645
1646
1647
                if self.is_local_process_zero() and not args.disable_tqdm:
                    steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                    steps_trained_progress_bar.set_description("Skipping the first batches")
1648

Sylvain Gugger's avatar
Sylvain Gugger committed
1649
1650
1651
1652
1653
        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
1654
        self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
1655
1656
1657
1658
1659
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
1660
1661
1662
1663
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
Sylvain Gugger's avatar
Sylvain Gugger committed
1664
1665
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()
Julien Chaumond's avatar
Julien Chaumond committed
1666

1667
        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
1668
        tr_loss = torch.tensor(0.0).to(args.device)
1669
1670
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
1671
        self._globalstep_last_logged = self.state.global_step
Julien Chaumond's avatar
Julien Chaumond committed
1672
        model.zero_grad()
Sylvain Gugger's avatar
Sylvain Gugger committed
1673

1674
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1675

1676
        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
1677
        if not args.ignore_data_skip:
1678
            for epoch in range(epochs_trained):
1679
1680
1681
                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
                    train_dataloader.sampler, RandomSampler
                )
1682
                if is_torch_less_than_1_11 or not is_random_sampler:
1683
1684
1685
1686
1687
1688
1689
1690
                    # We just need to begin an iteration to create the randomization of the sampler.
                    # That was before PyTorch 1.11 however...
                    for _ in train_dataloader:
                        break
                else:
                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
                    # AT THE VERY END!
                    _ = list(train_dataloader.sampler)
1691

1692
        for epoch in range(epochs_trained, num_train_epochs):
1693
1694
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
1695
            elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
1696
                train_dataloader.dataset.set_epoch(epoch)
1697

1698
            if is_torch_tpu_available():
1699
                parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
1700
                epoch_iterator = parallel_loader
1701
            else:
1702
                epoch_iterator = train_dataloader
1703

1704
            # Reset the past mems state at the beginning of each epoch if necessary.
1705
            if args.past_index >= 0:
1706
1707
                self._past = None

1708
            steps_in_epoch = (
1709
1710
1711
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
1712
            )
1713
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1714

1715
1716
1717
            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

1718
            step = -1
Julien Chaumond's avatar
Julien Chaumond committed
1719
1720
1721
1722
1723
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
1724
1725
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
1726
1727
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
Julien Chaumond's avatar
Julien Chaumond committed
1728
                    continue
1729
1730
1731
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None
Julien Chaumond's avatar
Julien Chaumond committed
1732

1733
1734
                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1735

1736
                if (
1737
1738
1739
                    ((step + 1) % args.gradient_accumulation_steps != 0)
                    and args.local_rank != -1
                    and args._no_sync_in_gradient_accumulation
1740
                ):
1741
                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
1742
                    with model.no_sync():
1743
                        tr_loss_step = self.training_step(model, inputs)
1744
                else:
1745
1746
                    tr_loss_step = self.training_step(model, inputs)

1747
1748
1749
1750
1751
1752
1753
                if (
                    args.logging_nan_inf_filter
                    and not is_torch_tpu_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
1754
1755
1756
                else:
                    tr_loss += tr_loss_step

1757
                self.current_flos += float(self.floating_point_ops(inputs))
Julien Chaumond's avatar
Julien Chaumond committed
1758

1759
1760
1761
1762
                # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
                if self.deepspeed:
                    self.deepspeed.step()

1763
                if (step + 1) % args.gradient_accumulation_steps == 0 or (
Julien Chaumond's avatar
Julien Chaumond committed
1764
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
1765
                    steps_in_epoch <= args.gradient_accumulation_steps
1766
                    and (step + 1) == steps_in_epoch
Julien Chaumond's avatar
Julien Chaumond committed
1767
                ):
1768
                    # Gradient clipping
1769
                    if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
1770
1771
                        # deepspeed does its own clipping

1772
                        if self.do_grad_scaling:
1773
1774
1775
1776
                            # Reduce gradients first for XLA
                            if is_torch_tpu_available():
                                gradients = xm._fetch_gradients(self.optimizer)
                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
1777
1778
1779
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)

1780
1781
1782
                        if is_sagemaker_mp_enabled() and args.fp16:
                            self.optimizer.clip_master_grads(args.max_grad_norm)
                        elif hasattr(self.optimizer, "clip_grad_norm"):
1783
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
1784
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
1785
1786
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
1787
                            model.clip_grad_norm_(args.max_grad_norm)
1788
1789
                        else:
                            # Revert to normal clipping otherwise, handling Apex or full precision
1790
                            nn.utils.clip_grad_norm_(
1791
                                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
1792
                                args.max_grad_norm,
1793
1794
1795
                            )

                    # Optimizer step
1796
                    optimizer_was_run = True
Stas Bekman's avatar
Stas Bekman committed
1797
                    if self.deepspeed:
1798
                        pass  # called outside the loop
Stas Bekman's avatar
Stas Bekman committed
1799
                    elif is_torch_tpu_available():
1800
1801
1802
1803
1804
                        if self.do_grad_scaling:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            xm.optimizer_step(self.optimizer)
1805
                    elif self.do_grad_scaling:
1806
                        scale_before = self.scaler.get_scale()
1807
                        self.scaler.step(self.optimizer)
1808
                        self.scaler.update()
1809
1810
                        scale_after = self.scaler.get_scale()
                        optimizer_was_run = scale_before <= scale_after
Lysandre Debut's avatar
Lysandre Debut committed
1811
                    else:
1812
                        self.optimizer.step()
Lysandre Debut's avatar
Lysandre Debut committed
1813

1814
                    if optimizer_was_run and not self.deepspeed:
1815
1816
                        self.lr_scheduler.step()

Julien Chaumond's avatar
Julien Chaumond committed
1817
                    model.zero_grad()
1818
                    self.state.global_step += 1
1819
                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
1820
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)
Sylvain Gugger's avatar
Sylvain Gugger committed
1821

1822
                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
wulu473's avatar
wulu473 committed
1823
1824
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
Julien Chaumond's avatar
Julien Chaumond committed
1825

Sylvain Gugger's avatar
Sylvain Gugger committed
1826
                if self.control.should_epoch_stop or self.control.should_training_stop:
Julien Chaumond's avatar
Julien Chaumond committed
1827
                    break
1828
1829
            if step < 0:
                logger.warning(
Sylvain Gugger's avatar
Sylvain Gugger committed
1830
                    "There seems to be not a single sample in your epoch_iterator, stopping training at step"
1831
1832
1833
1834
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True
1835

1836
            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
1837
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
1838

1839
            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
1840
1841
1842
1843
1844
1845
1846
1847
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
Sylvain Gugger's avatar
Sylvain Gugger committed
1848
            if self.control.should_training_stop:
1849
                break
Julien Chaumond's avatar
Julien Chaumond committed
1850

1851
        if args.past_index and hasattr(self, "_past"):
1852
1853
            # Clean the state at the end of training
            delattr(self, "_past")
Julien Chaumond's avatar
Julien Chaumond committed
1854
1855

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
1856
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
1857
1858
1859
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
1860
            elif args.local_rank != -1:
1861
                dist.barrier()
1862
1863
            elif is_sagemaker_mp_enabled():
                smp.barrier()
1864

1865
            self._load_best_model()
1866

1867
1868
1869
1870
        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        train_loss = self._total_loss_scalar / self.state.global_step

1871
        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
1872
1873
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
1874
        metrics["train_loss"] = train_loss
Sylvain Gugger's avatar
Sylvain Gugger committed
1875

1876
        self.is_in_train = False
1877

1878
1879
        self._memory_tracker.stop_and_update_metrics(metrics)

1880
1881
        self.log(metrics)

raghavanone's avatar
raghavanone committed
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint.
        if self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
            for checkpoint in checkpoints_sorted:
                if checkpoint != self.state.best_model_checkpoint:
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint)

1892
1893
1894
        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        return TrainOutput(self.state.global_step, train_loss, metrics)
1895

raghavanone's avatar
raghavanone committed
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
    def _get_output_dir(self, trial):
        if self.hp_search_backend is not None and trial is not None:
            if self.hp_search_backend == HPSearchBackend.OPTUNA:
                run_id = trial.number
            elif self.hp_search_backend == HPSearchBackend.RAY:
                from ray import tune

                run_id = tune.get_trial_id()
            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
                run_id = trial.id
            elif self.hp_search_backend == HPSearchBackend.WANDB:
                import wandb

                run_id = wandb.run.id
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            run_dir = os.path.join(self.args.output_dir, run_name)
        else:
            run_dir = self.args.output_dir
        return run_dir

1916
1917
1918
1919
1920
    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        """
        Load the model weights stored in the checkpoint folder `resume_from_checkpoint` into `model`.

        Args:
            resume_from_checkpoint (`str`):
                Folder expected to contain either a single weights file (`WEIGHTS_NAME`) or the index of a sharded
                checkpoint (`WEIGHTS_INDEX_NAME`).
            model (*optional*):
                Model that receives the weights. Defaults to `self.model`.

        Raises:
            ValueError: If neither the weights file nor the sharded checkpoint index exists in the folder.
        """
        if model is None:
            model = self.model

        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)

        if not os.path.isfile(weights_file) and not os.path.isfile(weights_index_file):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

        logger.info(f"Loading model from {resume_from_checkpoint}.")

        # Warn when the checkpoint was written by a different version of the library.
        if os.path.isfile(config_file):
            config = PretrainedConfig.from_json_file(config_file)
            checkpoint_version = config.transformers_version
            if checkpoint_version is not None and checkpoint_version != __version__:
                logger.warning(
                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                    f"Transformers but your current version is {__version__}. This is not recommended and could "
                    "yield to errors or unwanted behaviors."
                )

        if self.args.deepspeed:
            # will be resumed in deepspeed_init
            pass
        elif os.path.isfile(weights_file):
            # If the model is on the GPU, it still works!
            if is_sagemaker_mp_enabled():
                if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
                    # If the 'user_content.pt' file exists, load with the new smp api.
                    # Checkpoint must have been saved with the new smp api.
                    smp.resume_from_checkpoint(
                        path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
                    )
                else:
                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                    # Checkpoint must have been saved with the old smp api.
                    if hasattr(self.args, "fp16") and self.args.fp16 is True:
                        logger.warning(
                            "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
                        )
                    state_dict = torch.load(weights_file, map_location="cpu")
                    # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                    state_dict["_smp_is_partial"] = False
                    load_result = model.load_state_dict(state_dict, strict=True)
                    # release memory
                    del state_dict
            else:
                # We load the model state dict on the CPU to avoid an OOM error.
                state_dict = torch.load(weights_file, map_location="cpu")
                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                # which takes *args instead of **kwargs
                load_result = model.load_state_dict(state_dict, False)
                # release memory
                del state_dict
                self._issue_warnings_after_load(load_result)
        else:
            # We load the sharded checkpoint
            load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled())
            if not is_sagemaker_mp_enabled():
                self._issue_warnings_after_load(load_result)
1977
1978
1979
1980

    def _load_best_model(self):
        """
        Reload the weights stored in `self.state.best_model_checkpoint`.

        Depending on the setup, the weights are restored by re-creating the DeepSpeed engine from the checkpoint, through
        the SageMaker model-parallel API, from a single weights file, or from a sharded checkpoint. When no weights can
        be found in the checkpoint folder, a warning is logged and nothing is loaded.
        """
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if os.path.exists(best_model_path):
            if self.deepspeed:

                if self.model_wrapped is not None:
                    # this removes the pre-hooks from the previous engine
                    self.model_wrapped.destroy()
                    self.model_wrapped = None

                # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
                deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
                    self,
                    num_training_steps=self.args.max_steps,
                    resume_from_checkpoint=self.state.best_model_checkpoint,
                )
                self.model = deepspeed_engine.module
                self.model_wrapped = deepspeed_engine
                self.deepspeed = deepspeed_engine
                self.optimizer = optimizer
                self.lr_scheduler = lr_scheduler
            else:
                if is_sagemaker_mp_enabled():
                    if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                        # If the 'user_content.pt' file exists, load with the new smp api.
                        # Checkpoint must have been saved with the new smp api.
                        smp.resume_from_checkpoint(
                            path=self.state.best_model_checkpoint,
                            tag=WEIGHTS_NAME,
                            partial=False,
                            load_optimizer=False,
                        )
                    else:
                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                        # Checkpoint must have been saved with the old smp api.
                        state_dict = torch.load(best_model_path, map_location="cpu")
                        state_dict["_smp_is_partial"] = False
                        load_result = model.load_state_dict(state_dict, strict=True)
                else:
                    # We load the model state dict on the CPU to avoid an OOM error.
                    state_dict = torch.load(best_model_path, map_location="cpu")
                    # If the model is on the GPU, it still works!
                    # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                    # which takes *args instead of **kwargs
                    load_result = model.load_state_dict(state_dict, False)
                # Key mismatches are only reported for the non-strict (non SageMaker MP) load.
                if not is_sagemaker_mp_enabled():
                    self._issue_warnings_after_load(load_result)
        elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
            # The best checkpoint was saved as a sharded checkpoint.
            load_result = load_sharded_checkpoint(
                model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
            )
            if not is_sagemaker_mp_enabled():
                self._issue_warnings_after_load(load_result)
        else:
            logger.warning(
                f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
                "on multiple nodes, you should activate `--save_on_each_node`."
            )

2039
    def _issue_warnings_after_load(self, load_result):
2040
2041

        if len(load_result.missing_keys) != 0:
2042
2043
2044
            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                self.model._keys_to_ignore_on_save
            ):
2045
2046
                self.model.tie_weights()
            else:
2047
                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
2048
        if len(load_result.unexpected_keys) != 0:
2049
2050
2051
            logger.warning(
                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
            )
2052

2053
    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
Sylvain Gugger's avatar
Sylvain Gugger committed
2054
        if self.control.should_log:
2055
2056
2057
            if is_torch_tpu_available():
                xm.mark_step()

Sylvain Gugger's avatar
Sylvain Gugger committed
2058
            logs: Dict[str, float] = {}
2059
2060
2061
2062

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

2063
2064
2065
            # reset tr_loss to zero
            tr_loss -= tr_loss

2066
            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
2067
            logs["learning_rate"] = self._get_learning_rate()
2068

2069
            self._total_loss_scalar += tr_loss_scalar
2070
            self._globalstep_last_logged = self.state.global_step
Teven's avatar
Teven committed
2071
            self.store_flos()
Sylvain Gugger's avatar
Sylvain Gugger committed
2072
2073
2074
2075
2076

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
2077
2078
2079
2080
2081
2082
2083
2084
2085
            if isinstance(self.eval_dataset, dict):
                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
                    metrics = self.evaluate(
                        eval_dataset=eval_dataset,
                        ignore_keys=ignore_keys_for_eval,
                        metric_key_prefix=f"eval_{eval_dataset_name}",
                    )
            else:
                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
2086
            self._report_to_hp_search(trial, self.state.global_step, metrics)
2087

Sylvain Gugger's avatar
Sylvain Gugger committed
2088
2089
2090
2091
        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

2092
2093
2094
2095
2096
    def _load_rng_state(self, checkpoint):
        # Load RNG states from `checkpoint`
        if checkpoint is None:
            return

2097
2098
2099
        if self.args.world_size > 1:
            process_index = self.args.process_index
            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
2100
            if not os.path.isfile(rng_file):
2101
                logger.info(
2102
                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
2103
2104
2105
2106
2107
                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
                )
                return
        else:
            rng_file = os.path.join(checkpoint, "rng_state.pth")
2108
            if not os.path.isfile(rng_file):
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
                logger.info(
                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
                    "fashion, reproducibility is not guaranteed."
                )
                return

        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        if torch.cuda.is_available():
            if self.args.local_rank != -1:
                torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
            else:
2123
2124
2125
                try:
                    torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
                except Exception as e:
2126
                    logger.info(
2127
2128
2129
                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
2130
2131
2132
        if is_torch_tpu_available():
            xm.set_rng_state(checkpoint_rng_state["xla"])

Sylvain Gugger's avatar
Sylvain Gugger committed
2133
    def _save_checkpoint(self, model, trial, metrics=None):
        """
        Write a full training checkpoint for the current global step.

        The checkpoint folder is `<run_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>`. It receives the model, the
        optimizer / scheduler (and gradient scaler) states, the Trainer state and the RNG states. The best metric and
        best checkpoint recorded in `self.state` are updated when `metrics` is given, and older checkpoints may be
        rotated out afterwards.

        Args:
            model:
                Not used directly: the model is saved through `self.save_model`.
            trial:
                Hyperparameter-search trial (or `None`), used to resolve the run directory.
            metrics (`Dict[str, float]`, *optional*):
                Evaluation metrics used to decide whether this checkpoint is the new best one.
        """
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir, _internal_call=True)
        if self.deepspeed:
            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
            # config `stage3_gather_16bit_weights_on_model_save` is True
            self.deepspeed.save_checkpoint(output_dir)

        # Save optimizer and scheduler
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            self.optimizer.consolidate_state_dict()

        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            # Re-issue the warnings once the recording context is closed, as the other branches do: while it is
            # still open, re-issued warnings are recorded again instead of being shown.
            reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
            if self.args.should_save:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling:
                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
        elif self.args.should_save and not self.deepspeed:
            # deepspeed.save_checkpoint above saves model/optim/sched
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.args.should_save:
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        # Save the RNG states (written by every process; one file per process when world_size > 1)
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
            if self.args.local_rank == -1:
                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                rng_states["cuda"] = torch.cuda.random.get_rng_state()

        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()

        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)

        if self.args.world_size <= 1:
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

2239
    def _load_optimizer_and_scheduler(self, checkpoint):
Sylvain Gugger's avatar
Sylvain Gugger committed
2240
        """If optimizer and scheduler states exist, load them."""
2241
        if checkpoint is None:
2242
2243
            return

2244
        if self.deepspeed:
2245
            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
2246
2247
            return

2248
2249
2250
2251
2252
2253
        checkpoint_file_exists = (
            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
            if is_sagemaker_mp_enabled()
            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
        )
        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
Sylvain Gugger's avatar
Sylvain Gugger committed
2254
2255
2256
            # Load in optimizer and scheduler states
            if is_torch_tpu_available():
                # On TPU we have to take some extra precautions to properly load the states on the right device.
2257
                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
Sylvain Gugger's avatar
Sylvain Gugger committed
2258
                with warnings.catch_warnings(record=True) as caught_warnings:
2259
                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
Sylvain Gugger's avatar
Sylvain Gugger committed
2260
2261
2262
2263
2264
2265
2266
2267
                reissue_pt_warnings(caught_warnings)

                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

                self.optimizer.load_state_dict(optimizer_state)
                self.lr_scheduler.load_state_dict(lr_scheduler_state)
            else:
Sylvain Gugger's avatar
Sylvain Gugger committed
2268
                map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
2269
                if is_sagemaker_mp_enabled():
2270
2271
2272
2273
                    if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
                        # Optimizer checkpoint was saved with smp >= 1.10
                        def opt_load_hook(mod, opt):
                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
2274

2275
2276
2277
2278
2279
2280
2281
2282
2283
                    else:
                        # Optimizer checkpoint was saved with smp < 1.10
                        def opt_load_hook(mod, opt):
                            if IS_SAGEMAKER_MP_POST_1_10:
                                opt.load_state_dict(
                                    smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
                                )
                            else:
                                opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
2284
2285
2286
2287
2288
2289

                    self.model_wrapped.register_post_step_hook(opt_load_hook)
                else:
                    self.optimizer.load_state_dict(
                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
                    )
Sylvain Gugger's avatar
Sylvain Gugger committed
2290
                with warnings.catch_warnings(record=True) as caught_warnings:
2291
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
Sylvain Gugger's avatar
Sylvain Gugger committed
2292
                reissue_pt_warnings(caught_warnings)
2293
                if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
2294
                    self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
Sylvain Gugger's avatar
Sylvain Gugger committed
2295

2296
2297
2298
2299
2300
2301
2302
    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
        **kwargs,
    ) -> BestRun:
        """
        Launch a hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined
        by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
        the sum of all metrics otherwise.

        <Tip warning={true}>

        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
        optimizer/scheduler.

        </Tip>

        Args:
            hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
                A function that defines the hyperparameter search space. Will default to
                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
                method. Will default to [`~trainer_utils.default_compute_objective`].
            n_trials (`int`, *optional*, defaults to 20):
                The number of trial runs to test.
            direction (`str`, *optional*, defaults to `"minimize"`):
                Whether to optimize greater or lower objects. Can be `"minimize"` or `"maximize"`, you should pick
                `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
            backend (`str` or [`~trainer_utils.HPSearchBackend`], *optional*):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
                on which one is installed. If all are installed, will default to optuna.
            hp_name (`Callable[["optuna.Trial"], str]]`, *optional*):
                A function that defines the trial/run name. Will default to None.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more
                information see:

                - the documentation of
                  [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
                - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
                - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)

        Returns:
            [`trainer_utils.BestRun`]: All the information about the best run.

        Raises:
            RuntimeError: If no backend can be found, if the selected backend is not installed, or if the trainer
                was built without a `model_init` function.
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray or sigopt should be installed. "
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`. "
                    "To install sigopt run `pip install sigopt`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_tune_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
            raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
        if backend == HPSearchBackend.WANDB and not is_wandb_available():
            raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
        # Marks the trainer as being inside a search for the duration of the run; reset to None below.
        self.hp_search_backend = backend
        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.hp_name = hp_name
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        # Dispatch table from backend enum to the integration function that runs the search.
        backend_dict = {
            HPSearchBackend.OPTUNA: run_hp_search_optuna,
            HPSearchBackend.RAY: run_hp_search_ray,
            HPSearchBackend.SIGOPT: run_hp_search_sigopt,
            HPSearchBackend.WANDB: run_hp_search_wandb,
        }
        best_run = backend_dict[backend](self, n_trials, direction, **kwargs)

        self.hp_search_backend = None
        return best_run

Sylvain Gugger's avatar
Sylvain Gugger committed
2391
    def log(self, logs: Dict[str, float]) -> None:
        """
        Record `logs` in the training state history and forward them to the callbacks.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log. When the current epoch is known, an `"epoch"` entry (rounded to two decimals) is
                added to this dictionary in place before it is handed to the callbacks.
        """
        current_epoch = self.state.epoch
        if current_epoch is not None:
            logs["epoch"] = round(current_epoch, 2)

        # The history entry is a copy carrying the global step; the callbacks receive `logs` itself.
        history_entry = dict(logs)
        history_entry["step"] = self.state.global_step
        self.state.log_history.append(history_entry)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
Julien Chaumond's avatar
Julien Chaumond committed
2407

2408
2409
    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
2410
        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
2411
        """
2412
2413
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            kwargs = dict(device=self.args.device)
            if self.deepspeed and data.dtype != torch.int64:
                # NLP models inputs are int64 and those get adjusted to the right dtype of the
                # embedding. Other models such as wav2vec2's inputs are already float and thus
                # may need special handling to match the dtypes of the model
                kwargs.update(dict(dtype=self.args.hf_deepspeed_config.dtype()))
            return data.to(**kwargs)
        return data

sgugger's avatar
Fix CI  
sgugger committed
2426
    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
2427
        """
2428
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
2429
2430
        handling potential state.
        """
2431
        inputs = self._prepare_input(inputs)
2432
2433
2434
2435
2436
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )
2437
2438
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past
2439

2440
2441
        return inputs

2442
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
    def compute_loss_context_manager(self):
        """
        Bundle the torchdynamo and autocast context managers into a single `ContextManagers` object.
        """
        managers = [
            self.torchdynamo_smart_context_manager(),
            self.autocast_smart_context_manager(),
        ]
        return ContextManagers(managers)

    def torchdynamo_smart_context_manager(self):
        """
        Return the context manager stored on the trainer for `torchdynamo`.
        """
        dynamo_ctx = self.ctx_manager_torchdynamo
        return dynamo_ctx
2458

2459
2460
    def autocast_smart_context_manager(self):
        """
        Build the context manager to use for `autocast`, with arguments chosen from the trainer's AMP settings.

        Without CUDA or CPU AMP, a no-op context manager is returned.
        """
        if not (self.use_cuda_amp or self.use_cpu_amp):
            # No mixed precision requested: hand back a context manager that does nothing.
            if sys.version_info >= (3, 7):
                return contextlib.nullcontext()
            return contextlib.suppress()

        if not is_torch_greater_or_equal_than_1_10:
            return torch.cuda.amp.autocast()

        if self.use_cpu_amp:
            return torch.cpu.amp.autocast(dtype=self.amp_dtype)
        return torch.cuda.amp.autocast(dtype=self.amp_dtype)

2478
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Run the forward and backward pass for a single batch.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            # SageMaker model parallelism runs forward and backward itself and returns per-microbatch losses.
            microbatch_loss = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return microbatch_loss.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            # mean() to average on multi-gpu parallel training
            loss = loss.mean()

        accumulation_steps = self.args.gradient_accumulation_steps
        if accumulation_steps > 1 and not self.deepspeed:
            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
            loss = loss / accumulation_steps

        # Exactly one backward path runs, picked by the active mixed-precision / engine setup.
        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            # loss gets scaled under gradient_accumulation_steps in deepspeed
            loss = self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()
Julien Chaumond's avatar
Julien Chaumond committed
2525

2526
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        When a label smoother is configured and `inputs` contains `"labels"`, the labels are popped from `inputs`
        and the loss is computed by the label smoother instead of the model.

        Subclass and override for custom behavior.
        """
        labels = None
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")

        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is None:
            outputs_is_dict = isinstance(outputs, dict)
            if outputs_is_dict and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if outputs_is_dict else outputs[0]
        elif unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            loss = self.label_smoother(outputs, labels, shift_labels=True)
        else:
            loss = self.label_smoother(outputs, labels)

        if return_outputs:
            return (loss, outputs)
        return loss
Sylvain Gugger's avatar
Sylvain Gugger committed
2557

2558
2559
    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
        machines) main process.
        """
        local_index = self.args.local_process_index
        return local_index == 0
Lysandre Debut's avatar
Lysandre Debut committed
2564

2565
2566
    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be `True` for one process).
        """
        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
        # process index.
        if is_sagemaker_mp_enabled():
            return smp.rank() == 0
        return self.args.process_index == 0
Julien Chaumond's avatar
Julien Chaumond committed
2576

Sylvain Gugger's avatar
Sylvain Gugger committed
2577
    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Will save the model, so you can reload it using `from_pretrained()`.

        Will only save from the main process.

        Args:
            output_dir (`str`, *optional*):
                Directory to write to. Falls back to `self.args.output_dir` when not given.
            _internal_call (`bool`, *optional*, defaults to `False`):
                When `False` and `self.args.push_to_hub` is set, the saved model is also pushed to the Hub.
        """
        if output_dir is None:
            output_dir = self.args.output_dir

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif is_sagemaker_mp_enabled():
            # Calling the state_dict needs to be done on the wrapped model and on all processes.
            os.makedirs(output_dir, exist_ok=True)
            wrapped_state = self.model_wrapped.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=wrapped_state)
            if IS_SAGEMAKER_MP_POST_1_10:
                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
                marker_path = os.path.join(output_dir, "user_content.pt")
                Path(marker_path).touch()
        elif (
            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
            or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
            or self.fsdp is not None
        ):
            # The state dict is requested on every process; only the saving process writes it.
            full_state = self.model.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=full_state)
        elif self.deepspeed:
            # this takes care of everything as long as we aren't under zero3
            if self.args.should_save:
                self._save(output_dir)

            if is_deepspeed_zero3_enabled():
                # It's too complicated to try to override different places where the weights dump gets
                # saved, so since under zero3 the file is bogus, simply delete it. The user should
                # either use the deepspeed checkpoint to resume or, to recover full weights, use
                # zero_to_fp32.py stored in the checkpoint.
                if self.args.should_save:
                    bogus_weights = os.path.join(output_dir, WEIGHTS_NAME)
                    if os.path.isfile(bogus_weights):
                        os.remove(bogus_weights)

                # now save the real model if stage3_gather_16bit_weights_on_model_save=True
                # if false it will not be saved.
                # This must be called on all ranks
                saved_16bit = self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME)
                if not saved_16bit:
                    logger.warning(
                        "deepspeed.save_16bit_model didn't save the model, since"
                        " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                        " zero_to_fp32.py to recover weights"
                    )
                    self.deepspeed.save_checkpoint(output_dir)
        elif self.args.should_save:
            self._save(output_dir)

        # Push to the Hub when `save_model` is called by the user.
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")

2642
2643
    def _save_tpu(self, output_dir: Optional[str] = None):
        """
        Save the model (and tokenizer, if any) to `output_dir` using the XLA save function.

        Falls back to `self.args.output_dir` when `output_dir` is not given.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        logger.info(f"Saving model checkpoint to {output_dir}")

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if isinstance(self.model, PreTrainedModel):
            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
        elif isinstance(unwrap_model(self.model), PreTrainedModel):
            # The wrapper is not a PreTrainedModel but the wrapped model is: save through the inner model,
            # using the state dict taken from `self.model`.
            unwrap_model(self.model).save_pretrained(
                output_dir,
                is_main_process=self.args.should_save,
                state_dict=self.model.state_dict(),
                save_function=xm.save,
            )
        else:
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            weights = self.model.state_dict()
            xm.save(weights, os.path.join(output_dir, WEIGHTS_NAME))

        if self.tokenizer is not None and self.args.should_save:
            self.tokenizer.save_pretrained(output_dir)
2669

2670
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """
        Write the model, the tokenizer (if any) and the training arguments to `output_dir`.

        Falls back to `self.args.output_dir` when `output_dir` is not given. When `state_dict` is not given and the
        model is not itself a `PreTrainedModel`, the state dict is taken from `self.model`.
        """
        # If we are executing this function, we are the process zero, so we don't check for that.
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if isinstance(self.model, PreTrainedModel):
            self.model.save_pretrained(output_dir, state_dict=state_dict)
        else:
            inner_model = unwrap_model(self.model)
            inner_is_pretrained = isinstance(inner_model, PreTrainedModel)
            if not inner_is_pretrained:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            if state_dict is None:
                state_dict = self.model.state_dict()
            if inner_is_pretrained:
                inner_model.save_pretrained(output_dir, state_dict=state_dict)
            else:
                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
2694

2695
    def store_flos(self):
        """
        Add the floating-point operations counted since the last call to `self.state.total_flos`, then reset the
        running counter.
        """
        # Storing the number of floating-point operations that went into the model
        if self.args.local_rank == -1:
            flos_to_add = self.current_flos
        else:
            # Distributed run: sum the counters gathered across processes.
            flos_to_add = (
                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
            )
        self.state.total_flos += flos_to_add
        self.current_flos = 0
Julien Chaumond's avatar
Julien Chaumond committed
2705

2706
2707
2708
    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        """
        List the checkpoint directories found in `output_dir`, ordered from oldest to newest.

        Args:
            output_dir (`str`, *optional*):
                Directory to scan. Falls back to `self.args.output_dir` when not given.
            checkpoint_prefix (`str`, *optional*):
                Prefix of the checkpoint folders; folders are matched as `{checkpoint_prefix}-*`.
            use_mtime (`bool`, *optional*, defaults to `False`):
                Order by modification time instead of by the number that follows the prefix.

        Returns:
            `List[str]`: The sorted checkpoint paths. When `self.state.best_model_checkpoint` is set and present in
            the list, it is moved towards the end so that rotation does not delete it.
        """
        if output_dir is None:
            # `_rotate_checkpoints` forwards `None` by default; `Path(None)` would raise a TypeError.
            output_dir = self.args.output_dir

        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]

        # Escape the prefix so that regex metacharacters in it are matched literally.
        step_pattern = re.compile(f".*{re.escape(checkpoint_prefix)}-([0-9]+)")
        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = step_pattern.match(path)
                if regex_match is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.group(1)), path))

        checkpoints_sorted = [path for _, path in sorted(ordering_and_checkpoint_path)]

        # Make sure we don't delete the best model.
        if self.state.best_model_checkpoint is not None:
            best_model_path = str(Path(self.state.best_model_checkpoint))
            # The best checkpoint may be missing from this directory; `list.index` would raise a ValueError
            # in that case.
            if best_model_path in checkpoints_sorted:
                best_model_index = checkpoints_sorted.index(best_model_path)
                for i in range(best_model_index, len(checkpoints_sorted) - 2):
                    checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
        return checkpoints_sorted

2730
    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
2731
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
Julien Chaumond's avatar
Julien Chaumond committed
2732
2733
2734
            return

        # Check if we should delete older checkpoint(s)
2735
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
Julien Chaumond's avatar
Julien Chaumond committed
2736
2737
2738
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

2739
        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
2740
2741
2742
2743
2744
2745
2746
2747
2748
2749
        # we don't do to allow resuming.
        save_total_limit = self.args.save_total_limit
        if (
            self.state.best_model_checkpoint is not None
            and self.args.save_total_limit == 1
            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
        ):
            save_total_limit = 2

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
Julien Chaumond's avatar
Julien Chaumond committed
2750
2751
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
2752
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
Julien Chaumond's avatar
Julien Chaumond committed
2753
2754
            shutil.rmtree(checkpoint)

2755
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default)

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        start_time = time.time()

        if self.args.use_legacy_prediction_loop:
            eval_loop = self.prediction_loop
        else:
            eval_loop = self.evaluation_loop

        # No point gathering the predictions if there are no metrics, otherwise we defer to
        # self.args.prediction_loss_only
        loss_only = True if self.compute_metrics is None else None
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            prediction_loss_only=loss_only,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        # Speed metrics are computed over the batch size summed across all processes.
        total_batch_size = self.args.eval_batch_size * self.args.world_size
        throughput = speed_metrics(
            metric_key_prefix,
            start_time,
            num_samples=output.num_samples,
            num_steps=math.ceil(output.num_samples / total_batch_size),
        )
        output.metrics.update(throughput)

        self.log(output.metrics)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics

2824
    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Run prediction and returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the
                `model.forward()` method are automatically removed. Has to implement the method `__len__`
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "test_bleu" if the prefix is "test" (default)

        <Tip>

        If your predictions or labels have different sequence length (for instance because you're doing dynamic padding
        in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
        one array. The padding index is -100.

        </Tip>

        Returns: *NamedTuple* A namedtuple with the following keys:

            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        test_dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()

        # The legacy loop is only kept for users who explicitly opted in through the training arguments.
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )

        # Speed metrics are computed against the global batch size (per-process batch size times world size).
        total_batch_size = self.args.eval_batch_size * self.args.world_size
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
Julien Chaumond's avatar
Julien Chaumond committed
2883

2884
    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.

        Args:
            dataloader (`DataLoader`):
                The dataloader whose batches are fed to `prediction_step`.
            description (`str`):
                Name of the run, only used in the "***** Running ... *****" log line.
            prediction_loss_only (`bool`, *optional*):
                Forwarded to `prediction_step`. When `None`, falls back to `self.args.prediction_loss_only`.
            ignore_keys (`List[str]`, *optional*):
                Forwarded to `prediction_step`.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                Prefix added (followed by `_`) to every metric key that does not already start with it.

        Returns:
            `EvalLoopOutput`: The accumulated predictions, label ids, metrics and the number of samples.
        """
        args = self.args

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train init deepspeed here
        if args.deepspeed and not self.deepspeed:

            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(
                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        inputs_host = None

        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        all_inputs = None

        # Will be useful when we have an iterable dataset so don't know its length.
        observed_num_examples = 0

        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            if is_torch_tpu_available():
                xm.mark_step()

            # Update containers on host
            if loss is not None:
                # The scalar loss is repeated once per sample so that the final mean is weighted per sample.
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if labels is not None:
                labels = self._pad_across_processes(labels)
                labels = self._nested_gather(labels)
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_decode = self._pad_across_processes(inputs_decode)
                inputs_decode = self._nested_gather(inputs_decode)
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            if logits is not None:
                logits = self._pad_across_processes(logits)
                logits = self._nested_gather(logits)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                if preds_host is not None:
                    logits = nested_numpify(preds_host)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
                if inputs_host is not None:
                    inputs_decode = nested_numpify(inputs_host)
                    all_inputs = (
                        inputs_decode
                        if all_inputs is None
                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                    )
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                    )

                # Set back to None to begin a new accumulation
                losses_host, preds_host, inputs_host, labels_host = None, None, None, None

        # `past_index` is an index, so 0 is a valid value: compare against 0 (as done when `_past` is initialised
        # above) instead of relying on truthiness, which would skip the cleanup for `past_index == 0`.
        if args.past_index >= 0 and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
        if inputs_host is not None:
            inputs_decode = nested_numpify(inputs_host)
            all_inputs = (
                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
            )
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples
        if num_samples == 0 and observed_num_examples > 0:
            num_samples = observed_num_examples

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
        # samples has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
        if all_inputs is not None:
            all_inputs = nested_truncate(all_inputs, num_samples)

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
3093

3094
    def _nested_gather(self, tensors, name=None):
        """
        Gather the value of `tensors` (tensor or list/tuple of nested tensors) across processes.

        The backend is picked in this order: TPU (`nested_xla_mesh_reduce`), SageMaker model parallelism
        (`smp_gather`), then distributed training when `self.args.local_rank != -1` (`distributed_concat`). When none
        of them applies, `tensors` is returned untouched. No conversion to numpy happens here.

        Args:
            tensors:
                The tensor(s) to gather. `None` is returned as `None`.
            name (`str`, *optional*):
                Name handed to `nested_xla_mesh_reduce`; only used on TPU, where it defaults to `"nested_gather"`.

        Returns:
            The gathered tensors, or `None` if `tensors` is `None`.
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            if name is None:
                name = "nested_gather"
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif self.args.local_rank != -1:
            tensors = distributed_concat(tensors)
        return tensors
3110

3111
3112
3113
3114
3115
3116
3117
3118
3119
3120
3121
3122
3123
3124
3125
3126
3127
3128
3129
3130
3131
3132
3133
3134
3135
3136
3137
3138
3139
3140
3141
3142
    # Copied from Accelerate.
    def _pad_across_processes(self, tensor, pad_index=-100):
        """
        Pad tensors (possibly nested in lists, tuples or dictionaries) along their second dimension so that every
        process holds the same size there and the tensors can safely be gathered.

        Args:
            tensor:
                A `torch.Tensor` or a nested list/tuple/dict of them. Containers are rebuilt with the same type.
            pad_index (`int`, *optional*, defaults to -100):
                The value written in the padded positions.

        Returns:
            The input itself when it has fewer than two dimensions or already has the largest second dimension,
            otherwise a new tensor padded on the right of dimension 1.

        Raises:
            TypeError: If a leaf is neither a tensor nor a list/tuple/dict.
        """
        # Containers: recurse into the elements and rebuild the same container type.
        if isinstance(tensor, dict):
            padded_items = {
                key: self._pad_across_processes(value, pad_index=pad_index) for key, value in tensor.items()
            }
            return type(tensor)(padded_items)
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(self._pad_across_processes(item, pad_index=pad_index) for item in tensor)
        if not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )

        # Nothing to pad when there is no second dimension.
        if tensor.dim() < 2:
            return tensor

        # Share the local shape with the other processes to find the largest second dimension.
        local_shape = torch.tensor(tensor.shape, device=tensor.device)[None]
        all_shapes = self._nested_gather(local_shape).cpu()
        widest = max(shape[1] for shape in all_shapes)
        if tensor.shape[1] == widest:
            return tensor

        # Allocate a tensor filled with `pad_index` at the target shape and copy the data in its left part.
        original_shape = tensor.shape
        target_shape = list(original_shape)
        target_shape[1] = widest
        padded = tensor.new_zeros(tuple(target_shape)) + pad_index
        padded[:, : original_shape[1]] = tensor
        return padded
3143

3144
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        # `all()` over an empty iterable is True, so an empty `label_names` must be handled explicitly: with no label
        # names there cannot be any labels in the inputs.
        has_labels = len(self.label_names) > 0 and all(inputs.get(k) is not None for k in self.label_names)
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels:
                    with self.compute_loss_context_manager():
                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                    loss = loss.mean().detach()

                    # The loss is returned separately, so it is dropped from the logits.
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        logits = outputs[1:]
                else:
                    loss = None
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)
3240
3241
3242

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point
        operations for every backward + forward pass. If using another model, either implement such a method in the
        model or subclass and override this method.

        Args:
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            `int`: The number of floating-point operations.
        """
        # Delegate to the model when it knows how to count its operations; otherwise report 0.
        if hasattr(self.model, "floating_point_ops"):
            return self.model.floating_point_ops(inputs)
        else:
            return 0
3258

3259
    def init_git_repo(self, at_init: bool = False):
        """
        Initializes a git repo in `self.args.hub_model_id`.

        Does nothing on processes other than the world process zero. On success, `self.repo` holds the `Repository`
        cloned in `self.args.output_dir` and `self.push_in_progress` is reset to `None`.

        Args:
            at_init (`bool`, *optional*, defaults to `False`):
                Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
                `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
                out.

        Raises:
            EnvironmentError: If the repository cannot be created and the output directory may not be wiped.
        """
        if not self.is_world_process_zero():
            return
        use_auth_token = True if self.args.hub_token is None else self.args.hub_token
        if self.args.hub_model_id is None:
            repo_name = Path(self.args.output_dir).absolute().name
        else:
            repo_name = self.args.hub_model_id
        if "/" not in repo_name:
            repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)

        try:
            self.repo = Repository(
                self.args.output_dir,
                clone_from=repo_name,
                use_auth_token=use_auth_token,
                private=self.args.hub_private_repo,
            )
        except EnvironmentError:
            if self.args.overwrite_output_dir and at_init:
                # Try again after wiping output_dir. The retry uses the same arguments as the first attempt so that
                # the requested repo visibility is not lost.
                shutil.rmtree(self.args.output_dir)
                self.repo = Repository(
                    self.args.output_dir,
                    clone_from=repo_name,
                    use_auth_token=use_auth_token,
                    private=self.args.hub_private_repo,
                )
            else:
                raise

        self.repo.git_pull()

        # By default, ignore the checkpoint folders
        if (
            not os.path.exists(os.path.join(self.args.output_dir, ".gitignore"))
            and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS
        ):
            with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
                writer.writelines(["checkpoint-*/"])

        self.push_in_progress = None

Sylvain Gugger's avatar
Sylvain Gugger committed
3310
3311
3312
3313
    def create_model_card(
        self,
        language: Optional[str] = None,
        license: Optional[str] = None,
        tags: Union[str, List[str], None] = None,
        model_name: Optional[str] = None,
        finetuned_from: Optional[str] = None,
        tasks: Union[str, List[str], None] = None,
        dataset_tags: Union[str, List[str], None] = None,
        dataset: Union[str, List[str], None] = None,
        dataset_args: Union[str, List[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        The card is written to `README.md` in `self.args.output_dir`, on the world process zero only.

        Args:
            language (`str`, *optional*):
                The language of the model (if applicable)
            license (`str`, *optional*):
                The license of the model. Will default to the license of the pretrained model used, if the original
                model given to the `Trainer` comes from a repo on the Hub.
            tags (`str` or `List[str]`, *optional*):
                Some tags to be included in the metadata of the model card.
            model_name (`str`, *optional*):
                The name of the model.
            finetuned_from (`str`, *optional*):
                The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
                of the original model given to the `Trainer` (if it comes from the Hub).
            tasks (`str` or `List[str]`, *optional*):
                One or several task identifiers, to be included in the metadata of the model card.
            dataset_tags (`str` or `List[str]`, *optional*):
                One or several dataset tags, to be included in the metadata of the model card.
            dataset (`str` or `List[str]`, *optional*):
                One or several dataset identifiers, to be included in the metadata of the model card.
            dataset_args (`str` or `List[str]`, *optional*):
                One or several dataset arguments, to be included in the metadata of the model card.
        """
        if not self.is_world_process_zero():
            return

        training_summary = TrainingSummary.from_trainer(
            self,
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
        )
        model_card = training_summary.to_model_card()
        # Write with an explicit encoding so the result does not depend on the platform's default locale.
        with open(os.path.join(self.args.output_dir, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card)

3366
3367
3368
3369
3370
3371
3372
3373
3374
3375
3376
3377
3378
3379
3380
3381
3382
3383
3384
3385
3386
3387
3388
3389
3390
3391
3392
3393
3394
3395
3396
3397
3398
3399
    def _push_from_checkpoint(self, checkpoint_folder):
        """
        Push the content of `self.args.output_dir` to the Hub after a checkpoint was saved in `checkpoint_folder`.

        Nothing is done on processes other than the world process zero, when the hub strategy is
        `HubStrategy.END`, or while a previous non-blocking push is still running. With `HubStrategy.CHECKPOINT`, the
        checkpoint folder is temporarily moved to `last-checkpoint` inside the output dir for the duration of the
        push, then moved back.

        Args:
            checkpoint_folder:
                Path of the checkpoint folder that was just saved.
        """
        # Only push from one node.
        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
            return
        # If we haven't finished the last push, we don't do this one.
        if self.push_in_progress is not None and not self.push_in_progress.is_done:
            return

        output_dir = self.args.output_dir
        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
        modeling_files = [CONFIG_NAME, WEIGHTS_NAME]
        for modeling_file in modeling_files:
            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Same for the training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
        # Only undo the move in `finally` if it actually happened, otherwise a failure before/during the move would
        # be masked by a second error raised while trying to move back a folder that is not there.
        checkpoint_moved = False
        try:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Temporarily move the checkpoint just saved for the push
                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
                # subfolder.
                if os.path.isdir(tmp_checkpoint):
                    shutil.rmtree(tmp_checkpoint)
                shutil.move(checkpoint_folder, tmp_checkpoint)
                checkpoint_moved = True

            if self.args.save_strategy == IntervalStrategy.STEPS:
                commit_message = f"Training in progress, step {self.state.global_step}"
            else:
                commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
            _, self.push_in_progress = self.repo.push_to_hub(
                commit_message=commit_message, blocking=False, auto_lfs_prune=True
            )
        finally:
            if checkpoint_moved:
                # Move back the checkpoint to its place
                shutil.move(tmp_checkpoint, checkpoint_folder)

    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.

        Parameters:
            commit_message (`str`, *optional*, defaults to `"End of training"`):
                Message to commit while pushing.
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the `git push` has finished.
            kwargs:
                Additional keyword arguments passed along to [`~Trainer.create_model_card`].

        Returns:
            The url of the commit of your model in the given repository if `blocking=True`, a tuple with the url of
            the commit and an object to track the progress of the commit if `blocking=False`. Nothing is pushed and
            `None` is returned on processes other than the world process zero.
        """
        # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
        # it might fail.
        if not hasattr(self, "repo"):
            self.init_git_repo()

        # The model name is only needed for the model card, which is only written by the saving process.
        if self.args.should_save:
            if self.args.hub_model_id is None:
                model_name = Path(self.args.output_dir).name
            else:
                model_name = self.args.hub_model_id.split("/")[-1]

        # Needs to be executed on all processes for TPU training, but will only save on the processed determined by
        # self.args.should_save.
        self.save_model(_internal_call=True)

        # Only push from one node.
        if not self.is_world_process_zero():
            return

        # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
            self.push_in_progress._process.kill()
            self.push_in_progress = None

        git_head_commit_url = self.repo.push_to_hub(
            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
        )
        # push separately the model card to be independent from the rest of the model
        if self.args.should_save:
            self.create_model_card(model_name=model_name, **kwargs)
            try:
                self.repo.push_to_hub(
                    commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
                )
            except EnvironmentError as exc:
                # Best effort: the model itself has already been pushed, so only report the failure.
                logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")

        return git_head_commit_url

    #
    # Deprecated code
    #

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.

        Args:
            dataloader (`DataLoader`):
                The dataloader to iterate over. It must implement a working `__len__`.
            description (`str`):
                Name of the run, only used in the log messages.
            prediction_loss_only (`bool`, *optional*):
                Whether to skip gathering predictions, labels and inputs. Falls back to
                `self.args.prediction_loss_only` when `None`.
            ignore_keys (`List[str]`, *optional*):
                Forwarded as is to `self.prediction_step`.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                Prefix added (followed by `_`) to every metric key that does not already start with it.

        Returns:
            `PredictionOutput`: the gathered predictions and label ids (both `None` when `prediction_loss_only` is
            set) together with the metrics dictionary.

        Raises:
            `ValueError`: if `dataloader` does not implement a working `__len__`.
        """
        args = self.args

        if not has_length(dataloader):
            raise ValueError("dataloader must implement a working __len__")

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train init deepspeed here
        if args.deepspeed and not self.deepspeed:
            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
            # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
            # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
            # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
            deepspeed_engine.optimizer.optimizer = None
            deepspeed_engine.lr_scheduler = None

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info(f"***** Running {description} *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Batch size = {batch_size}")
        # Accumulators for the current accumulation window; `None` means "nothing accumulated yet".
        losses_host: Optional[torch.Tensor] = None
        preds_host: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None
        labels_host: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None
        inputs_host: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None

        world_size = max(1, args.world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
            # a batch size to the sampler)
            make_multiple_of = None
            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
                make_multiple_of = dataloader.sampler.batch_size
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            # NOTE(review): the inputs passed to the metrics are always read from the "input_ids" key of the batch.
            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            if loss is not None:
                # One loss value per example so that losses can be gathered like the other per-example tensors.
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                if not prediction_loss_only:
                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
                    inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host, inputs_host = None, None, None, None

        # Same condition as the one that created `_past` above: a plain truthiness test would skip the cleanup for
        # the valid value `past_index == 0`.
        if args.past_index >= 0 and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
            inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
        inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def _gather_and_numpify(self, tensors, name):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif self.args.local_rank != -1:
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)