# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
"""

import contextlib
import inspect
import math
import os
import random
import re
import shutil
import sys
import time
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

from tqdm.auto import tqdm


# Integrations must be imported before ML frameworks:
from .integrations import (  # isort: split
    default_hp_search_backend,
    get_reporting_integration_callbacks,
    hp_params,
    is_fairscale_available,
    is_optuna_available,
    is_ray_tune_available,
    is_sigopt_available,
    is_wandb_available,
    run_hp_search_optuna,
    run_hp_search_ray,
    run_hp_search_sigopt,
    run_hp_search_wandb,
)

import numpy as np
import torch
import torch.distributed as dist
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from huggingface_hub import Repository

from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .optimization import Adafactor, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
    DistributedLengthGroupedSampler,
    DistributedSamplerWithLoop,
    DistributedTensorGatherer,
    IterableDatasetShard,
    LabelSmoother,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    ShardSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    find_batch_size,
    get_parameter_names,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_truncate,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalLoopOutput,
    EvalPrediction,
    HPSearchBackend,
    HubStrategy,
    IntervalStrategy,
    PredictionOutput,
    ShardedDDPOption,
    TrainerMemoryTracker,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    denumpify_detensorize,
    get_last_checkpoint,
    has_length,
    number_of_arguments,
    set_seed,
    speed_metrics,
)
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import (
    CONFIG_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    find_labels,
    get_full_repo_name,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_torch_tpu_available,
    logging,
)


_is_torch_generator_available = False
_is_native_amp_available = False

DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

if is_apex_available():
    from apex import amp

if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_torch_generator_available = True
    _is_native_amp_available = True
    from torch.cuda.amp import autocast

if is_datasets_available():
    import datasets

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_fairscale_available():
    dep_version_check("fairscale")
    import fairscale
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.nn.wrap import auto_wrap
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler


if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat


if TYPE_CHECKING:
    import optuna

logger = logging.get_logger(__name__)


# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.

    Args:
        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.

            <Tip>

            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
            models.

            </Tip>

        args ([`TrainingArguments`], *optional*):
            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
        data_collator (`DataCollator`, *optional*):
            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
            default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
            [`DataCollatorWithPadding`] otherwise.
        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
            The dataset to use for training. If it is a `datasets.Dataset`, columns not accepted by the
            `model.forward()` method are automatically removed.

            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
            distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
            sets the seed of the RNGs used.
        eval_dataset (`torch.utils.data.Dataset`, *optional*):
            The dataset to use for evaluation. If it is a `datasets.Dataset`, columns not accepted by the
            `model.forward()` method are automatically removed.
        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to
            the maximum length when batching inputs, and it will be saved along the model to make it easier to rerun
            an interrupted training or reuse the fine-tuned model.
        model_init (`Callable[[], PreTrainedModel]`, *optional*):
            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
            from a new instance of the model as given by this function.

            The function may have zero arguments, or a single one containing the optuna/Ray Tune/SigOpt trial object,
            to be able to choose different architectures according to hyperparameters (such as layer count, sizes of
            inner layers, dropout probabilities etc).
        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
            The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`] and
            return a dictionary mapping metric names to metric values.
        callbacks (List of [`TrainerCallback`], *optional*):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed [here](callback).

            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
            containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications
            made by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.

    Important attributes:

        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
          subclass.
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
          to `False` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False`.
        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
          in `train`)
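
    Example (a minimal, illustrative sketch; the checkpoint name and the pre-tokenized `train_dataset`/`eval_dataset`
    are assumptions, not part of this module):

    ```python
    from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    args = TrainingArguments(output_dir="tmp_trainer", num_train_epochs=1)

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,  # a tokenized `datasets.Dataset` prepared beforehand
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()
    metrics = trainer.evaluate()
    ```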

    """

    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    ):
        if args is None:
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
        self.args = args
        # Seed must be set before instantiating the model when using model_init.
        set_seed(self.args.seed)
        self.hp_name = None
        self.deepspeed = None
        self.is_in_train = False

        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

        # set the correct log level depending on the node
        log_level = args.get_process_log_level()
        logging.set_verbosity(log_level)

        # force device and distributed setup init explicitly
        args._setup_devices

        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. "
                    "`model_init` will overwrite your model when calling the `train` method. This will become a fatal error in the next release.",
                    FutureWarning,
                )
            self.model_init = model_init

        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = False

        # Setup Sharded DDP training
        self.sharded_ddp = None
        if len(args.sharded_ddp) > 0:
            if args.deepspeed:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )

            if args.local_rank == -1:
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
                raise ImportError(
                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
                )
            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.SIMPLE
            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

        # one place to sort out whether to place the model on device or not
        # postpone switching model to cuda when:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        #    and we only use deepspeed for training at the moment
        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
        # 4. Sharded DDP - same as MP
        self.place_model_on_device = args.place_model_on_device
        if (
            self.is_model_parallel
            or args.deepspeed
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
        ):
            self.place_model_on_device = False

        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer

        if self.place_model_on_device:
            self._move_model_to_device(model, args.device)

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
        if self.is_model_parallel:
            self.args._n_gpu = 1

        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
        self.model_wrapped = model
        self.model = model

        self.compute_metrics = compute_metrics
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        if (self.sharded_ddp is not None or args.deepspeed) and (
            self.optimizer is not None or self.lr_scheduler is not None
        ):
            raise RuntimeError(
                "Passing `optimizers` is not allowed if Fairscale or Deepspeed is enabled. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create clone of distant repo and output directory if needed
        if self.args.push_to_hub:
            self.init_git_repo(at_init=True)
            # In case of pull, we need to make sure every process has the latest.
            if is_torch_tpu_available():
                xm.rendezvous("init git repo")
            elif args.local_rank != -1:
                dist.barrier()

        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)

        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")

        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
            raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")

        if (
            train_dataset is not None
            and isinstance(train_dataset, torch.utils.data.IterableDataset)
            and args.group_by_length
        ):
            raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")

        self._signature_columns = None

        # Mixed precision setup
        self.use_apex = False
        self.use_amp = False

        if args.fp16 or args.bf16:
            if args.half_precision_backend == "auto":
                if _is_native_amp_available:
                    args.half_precision_backend = "amp"
                else:
                    if args.bf16:
                        raise ValueError("Tried to use `bf16` but native amp is not available")
                    else:
                        args.half_precision_backend = "apex"
            logger.info(f"Using {args.half_precision_backend} half precision backend")

        self.do_grad_scaling = False
        if (args.fp16 or args.bf16) and not args.deepspeed:  # deepspeed manages its own half precision
            if args.half_precision_backend == "amp":
                self.use_amp = True
                self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
                self.do_grad_scaling = True
                if is_sagemaker_mp_enabled():
                    self.scaler = smp.amp.GradScaler()
                elif self.sharded_ddp is not None:
                    self.scaler = ShardedGradScaler()
                elif is_torch_tpu_available():
                    from torch_xla.amp import GradScaler

                    self.scaler = GradScaler()
                else:
                    self.scaler = torch.cuda.amp.GradScaler()
            else:
                if not is_apex_available():
                    raise ImportError(
                        "Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex."
                    )
                self.use_apex = True

        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
        if is_sagemaker_mp_enabled() and self.use_amp and args.max_grad_norm is not None and args.max_grad_norm > 0:
            raise ValueError(
                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
                "along 'max_grad_norm': 0 in your hyperparameters."
            )

        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

        self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
        )

        self.control = TrainerControl()
        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
        default_label_names = find_labels(self.model.__class__)
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

        # very last
        self._memory_tracker.stop_and_update_metrics()

    def add_callback(self, callback):
        """
        Add a callback to the current list of [`~transformers.TrainerCallback`].

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
               first case, will instantiate a member of that class.
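
        Example (an illustrative sketch; `MyCallback` is a hypothetical callback, not part of this module):

        ```python
        from transformers import TrainerCallback


        class MyCallback(TrainerCallback):
            def on_epoch_end(self, args, state, control, **kwargs):
                print(f"Finished epoch {state.epoch}")


        trainer.add_callback(MyCallback)  # pass the class, or an instance: trainer.add_callback(MyCallback())
        ```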
        """
        self.callback_handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it.

        If the callback is not found, returns `None` (and no error is raised).

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
               first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            [`~transformers.TrainerCallback`]: The callback removed, if found.
        """
        return self.callback_handler.pop_callback(callback)

    def remove_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`].

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
               first case, will remove the first member of that class found in the list of callbacks.
        """
        self.callback_handler.remove_callback(callback)

    def _move_model_to_device(self, model, device):
        model = model.to(device)
        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
            model.tie_weights()

    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return dataset
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
            # Labels may be named label or label_ids, the default data collator handles that.
            self._signature_columns += ["label", "label_ids"]

        ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
        if len(ignored_columns) > 0:
            dset_description = "" if description is None else f"in the {description} set "
            logger.info(
                f"The following columns {dset_description} don't have a corresponding argument in "
                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
                f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
                f" you can safely ignore this message."
            )

        columns = [k for k in self._signature_columns if k in dataset.column_names]

        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            dataset.set_format(
                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
            )
            return dataset
        else:
            return dataset.remove_columns(ignored_columns)

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        generator = None
        if self.args.world_size <= 1 and _is_torch_generator_available:
            generator = torch.Generator()
            # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
            # `args.seed`) if data_seed isn't provided.
            # Further on in this method, we default to `args.seed` instead.
            if self.args.data_seed is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
            else:
                seed = self.args.data_seed
            generator.manual_seed(seed)

        seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed

        # Build the sampler.
        if self.args.group_by_length:
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            if self.args.world_size <= 1:
                return LengthGroupedSampler(
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
                    dataset=self.train_dataset,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    generator=generator,
                )
            else:
                return DistributedLengthGroupedSampler(
                    self.args.train_batch_size * self.args.gradient_accumulation_steps,
                    dataset=self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    lengths=lengths,
                    model_input_name=model_input_name,
                    seed=seed,
                )

        else:
            if self.args.world_size <= 1:
                if _is_torch_generator_available:
                    return RandomSampler(self.train_dataset, generator=generator)
                return RandomSampler(self.train_dataset)
            elif (
                self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                and not self.args.dataloader_drop_last
            ):
                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    batch_size=self.args.per_device_train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=seed,
                )
            else:
                return DistributedSampler(
                    self.train_dataset,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    seed=seed,
                )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
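
        Example of such an override (an illustrative sketch, not part of the original documentation):

        ```python
        from torch.utils.data import DataLoader, SequentialSampler


        class MyTrainer(Trainer):
            def get_train_dataloader(self) -> DataLoader:
                # For instance, iterate over the training set in a fixed order instead of shuffling it.
                return DataLoader(
                    self.train_dataset,
                    batch_size=self.args.train_batch_size,
                    sampler=SequentialSampler(self.train_dataset),
                    collate_fn=self.data_collator,
                )
        ```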
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
675

676
677
678
679
        train_dataset = self.train_dataset
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")

680
        if isinstance(train_dataset, torch.utils.data.IterableDataset):
681
682
            if self.args.world_size > 1:
                train_dataset = IterableDatasetShard(
683
                    train_dataset,
684
685
686
687
688
                    batch_size=self.args.train_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
689

690
691
            return DataLoader(
                train_dataset,
692
                batch_size=self.args.per_device_train_batch_size,
693
694
695
696
697
                collate_fn=self.data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

698
699
700
        train_sampler = self._get_train_sampler()

        return DataLoader(
701
            train_dataset,
Julien Chaumond's avatar
Julien Chaumond committed
702
703
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
704
            collate_fn=self.data_collator,
Setu Shah's avatar
Setu Shah committed
705
            drop_last=self.args.dataloader_drop_last,
Chady Kamar's avatar
Chady Kamar committed
706
            num_workers=self.args.dataloader_num_workers,
707
            pin_memory=self.args.dataloader_pin_memory,
Julien Chaumond's avatar
Julien Chaumond committed
708
709
        )

710
    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
        # Deprecated code
        if self.args.use_legacy_prediction_loop:
            if is_torch_tpu_available():
                return SequentialDistributedSampler(
                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
                )
            elif is_sagemaker_mp_enabled():
                return SequentialDistributedSampler(
                    eval_dataset,
                    num_replicas=smp.dp_size(),
                    rank=smp.dp_rank(),
                    batch_size=self.args.per_device_eval_batch_size,
                )
            elif self.args.local_rank != -1:
                return SequentialDistributedSampler(eval_dataset)
            else:
                return SequentialSampler(eval_dataset)

        if self.args.world_size <= 1:
            return SequentialSampler(eval_dataset)
        else:
            return ShardSampler(
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
            )

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a `datasets.Dataset`, columns not accepted by
                the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")

        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                eval_dataset = IterableDatasetShard(
                    eval_dataset,
                    batch_size=self.args.per_device_eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                eval_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=self.data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`, *optional*):
                The test dataset to use. If it is a `datasets.Dataset`, columns not accepted by the `model.forward()`
                method are automatically removed. It must implement `__len__`.
        """
        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description="test")

        if isinstance(test_dataset, torch.utils.data.IterableDataset):
            if self.args.world_size > 1:
                test_dataset = IterableDatasetShard(
                    test_dataset,
                    batch_size=self.args.eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    num_processes=self.args.world_size,
                    process_index=self.args.process_index,
                )
            return DataLoader(
                test_dataset,
                batch_size=self.args.eval_batch_size,
                collate_fn=self.data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
        `create_scheduler`) in a subclass.
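
        Example of passing a custom optimizer and scheduler through `optimizers` instead (an illustrative sketch; the
        hyper-parameter values are arbitrary):

        ```python
        from torch.optim import SGD
        from torch.optim.lr_scheduler import LambdaLR

        optimizer = SGD(model.parameters(), lr=1e-3)
        scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 0.95**step)
        trainer = Trainer(model=model, args=args, train_dataset=train_dataset, optimizers=(optimizer, scheduler))
        ```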
        """
        self.create_optimizer()
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, [nn.LayerNorm])
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in opt_model.named_parameters() if n in decay_parameters],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in opt_model.named_parameters() if n not in decay_parameters],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer

    @staticmethod
    def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
        """
        Returns the optimizer class and optimizer parameters based on the training arguments.

        Args:
            args (`transformers.training_args.TrainingArguments`):
                The training arguments for the training session.
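
        Example (an illustrative sketch; `model` is assumed to be any `torch.nn.Module`):

        ```python
        args = TrainingArguments(output_dir="tmp_trainer", optim="adamw_torch", learning_rate=5e-5)
        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
        optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
        ```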

        """
        optimizer_kwargs = {"lr": args.learning_rate}
        adam_kwargs = {
            "betas": (args.adam_beta1, args.adam_beta2),
            "eps": args.adam_epsilon,
        }
        if args.optim == OptimizerNames.ADAFACTOR:
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif args.optim == OptimizerNames.ADAMW_HF:
            from .optimization import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim == OptimizerNames.ADAMW_TORCH:
            from torch.optim import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
            try:
                from torch_xla.amp.syncfree import AdamW

                optimizer_cls = AdamW
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
            try:
                from apex.optimizers import FusedAdam

                optimizer_cls = FusedAdam
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
        elif args.optim == OptimizerNames.ADAMW_BNB:
            try:
                from bitsandbytes.optim import Adam8bit

                optimizer_cls = Adam8bit
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
        """
        if self.lr_scheduler is None:
            self.lr_scheduler = get_scheduler(
                self.args.lr_scheduler_type,
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )
        return self.lr_scheduler

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get the number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
        dataloader.dataset does not exist or has no length, estimates as best it can.
        """
        try:
            return len(dataloader.dataset)
        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
            return len(dataloader) * self.args.per_device_train_batch_size

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """HP search setup code"""
        self._trial = trial

        if self.hp_search_backend is None or trial is None:
            return
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            params = self.hp_space(trial)
        elif self.hp_search_backend == HPSearchBackend.RAY:
            params = trial
            params.pop("wandb", None)
        elif self.hp_search_backend == HPSearchBackend.SIGOPT:
            params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
        elif self.hp_search_backend == HPSearchBackend.WANDB:
            params = trial

        for key, value in params.items():
            if not hasattr(self.args, key):
                logger.warning(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
                )
                continue
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info(f"Trial: {trial.params}")
        if self.hp_search_backend == HPSearchBackend.SIGOPT:
            logger.info(f"SigOpt Assignments: {trial.assignments}")
        if self.hp_search_backend == HPSearchBackend.WANDB:
            logger.info(f"W&B Sweep parameters: {trial}")
        if self.args.deepspeed:
            # Rebuild the deepspeed config to reflect the updated training parameters
            from transformers.deepspeed import HfTrainerDeepSpeedConfig

            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
            self.args.hf_deepspeed_config.trainer_config_process(self.args)

    def _report_to_hp_search(
        self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
    ):
        if self.hp_search_backend is None or trial is None:
            return
        self.objective = self.compute_objective(metrics.copy())
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            import optuna

            trial.report(self.objective, epoch)
            if trial.should_prune():
                self.callback_handler.on_train_end(self.args, self.state, self.control)
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            from ray import tune

            if self.control.should_save:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

    def _tune_save_checkpoint(self):
        from ray import tune

        if not self.use_tune_checkpoints:
            return
        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
            self.save_model(output_dir, _internal_call=True)
            if self.args.should_save:
                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

    def call_model_init(self, trial=None):
        model_init_argcount = number_of_arguments(self.model_init)
        if model_init_argcount == 0:
            model = self.model_init()
        elif model_init_argcount == 1:
            model = self.model_init(trial)
        else:
            raise RuntimeError("model_init should have 0 or 1 argument.")

        if model is None:
            raise RuntimeError("model_init should not return None.")

        return model

    def _wrap_model(self, model, training=True):
        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # DeepSpeed has already initialized its own DDP and AMP, so just return the engine.
        if self.deepspeed:
            return self.deepspeed

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = nn.DataParallel(model)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)

        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.local_rank != -1:
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
            model = nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
                **kwargs,
            )

        return model

    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Main training entry point.

        Args:
            resume_from_checkpoint (`str` or `bool`, *optional*):
                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
            ignore_keys_for_eval (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions for evaluation during the training.
            kwargs:
                Additional keyword arguments used to hide deprecated arguments
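
        Example (a minimal usage sketch; assumes `trainer` is an already-instantiated [`Trainer`]):

        ```python
        >>> trainer.train(resume_from_checkpoint=True)  # resume from the last checkpoint saved in `args.output_dir`
        ```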
        """
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        args = self.args

        self.is_in_train = True

        # do_train is not a reliable argument, as it might not be set and .train() still called, so
        # the following is a workaround:
        if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
            self._move_model_to_device(self.model, args.device)

        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
                "instead.",
                FutureWarning,
            )
        if len(kwargs) > 0:
            raise TypeError(f"train() received unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        model_reloaded = False
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            set_seed(args.seed)
            self.model = self.call_model_init(trial)
            model_reloaded = True
            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Load potential model checkpoint
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

        if resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint)

        # If model was re-initialized, put it on the right device and update self.model_wrapped
        if model_reloaded:
            if self.place_model_on_device:
                self._move_model_to_device(self.model, args.device)
            self.model_wrapped = self.model

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
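        # Effective batch size per optimizer update: args.train_batch_size (already scaled by n_gpu under DataParallel) x gradient accumulation steps x number of processes.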
        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size

        len_dataloader = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            num_examples = self.num_examples(train_dataloader)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
                )
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
            else:
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            max_steps = args.max_steps
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
        else:
            raise ValueError(
                f"args.max_steps must be set to a positive value if dataloader does not have a length, was {args.max_steps}"
            )

        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP (torch.distributed.launch)."
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

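        # With Sharded DDP (other than the simple mode) or SageMaker model parallelism, the optimizer must be created after the model is wrapped.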
        delay_optimizer_creation = (
            self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE or is_sagemaker_mp_enabled()
        )
        if args.deepspeed:
            deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
                self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler
        elif not delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None

        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        model = self._wrap_model(self.model_wrapped)

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

        if delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

        # important: at this point:
        # self.model         is the Transformers Model
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.

        # Train!
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        steps_trained_progress_bar = None

        # Check if continuing training from a checkpoint
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
                    "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
                    "flag to your launch command, but you will resume the training on data already seen by your model."
                )
                if self.is_local_process_zero() and not args.disable_tqdm:
                    steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                    steps_trained_progress_bar.set_description("Skipping the first batches")

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
        # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()

        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not args.ignore_data_skip:
            for epoch in range(epochs_trained):
                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
                    train_dataloader.sampler, RandomSampler
                )
                if version.parse(torch.__version__) < version.parse("1.11") or not is_random_sampler:
                    # We just need to begin an iteration to create the randomization of the sampler.
                    # That was before PyTorch 1.11 however...
                    for _ in train_dataloader:
                        break
                else:
                    # Otherwise we need to iterate through the whole sampler, because a random operation is
                    # added at the very end of its iteration.
                    _ = list(train_dataloader.sampler)

        for epoch in range(epochs_trained, num_train_epochs):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)
            elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
                train_dataloader.dataset.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if args.past_index >= 0:
                self._past = None

            steps_in_epoch = (
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

            step = -1
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
                    continue
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None

                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                if (
                    ((step + 1) % args.gradient_accumulation_steps != 0)
                    and args.local_rank != -1
                    and args._no_sync_in_gradient_accumulation
                ):
                    # Avoid unnecessary DDP gradient synchronization: no optimizer step is taken on this micro-batch.
                    with model.no_sync():
                        tr_loss_step = self.training_step(model, inputs)
                else:
                    tr_loss_step = self.training_step(model, inputs)

                if (
                    args.logging_nan_inf_filter
                    and not is_torch_tpu_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                else:
                    tr_loss += tr_loss_step

                self.current_flos += float(self.floating_point_ops(inputs))

                # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
                if self.deepspeed:
                    self.deepspeed.step()

                if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    steps_in_epoch <= args.gradient_accumulation_steps
                    and (step + 1) == steps_in_epoch
                ):
                    # Gradient clipping
                    if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
                        # deepspeed does its own clipping

                        if self.do_grad_scaling:
                            # Reduce gradients first for XLA
                            if is_torch_tpu_available():
                                gradients = xm._fetch_gradients(self.optimizer)
                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
                            # AMP: gradients need unscaling
                            self.scaler.unscale_(self.optimizer)

                        if hasattr(self.optimizer, "clip_grad_norm"):
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
                            model.clip_grad_norm_(args.max_grad_norm)
                        else:
                            # Revert to normal clipping otherwise, handling Apex or full precision
                            nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
                                args.max_grad_norm,
                            )

                    # Optimizer step
                    optimizer_was_run = True
                    if self.deepspeed:
                        pass  # called outside the loop
                    elif is_torch_tpu_available():
                        if self.do_grad_scaling:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            xm.optimizer_step(self.optimizer)
                    elif self.do_grad_scaling:
                        scale_before = self.scaler.get_scale()
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
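                        # If the loss scale shrank, the GradScaler skipped the optimizer step (inf/NaN gradients), so the LR scheduler should not be stepped either.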
                        scale_after = self.scaler.get_scale()
                        optimizer_was_run = scale_before <= scale_after
                    else:
                        self.optimizer.step()

                    if optimizer_was_run and not self.deepspeed:
                        self.lr_scheduler.step()

                    model.zero_grad()
                    self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break
            if step < 0:
                logger.warning(
                    f"There seems to be no sample in your epoch_iterator, stopping training at step"
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" max_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True

            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
            elif args.local_rank != -1:
                dist.barrier()

            self._load_best_model()

        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        train_loss = self._total_loss_scalar / self.state.global_step

        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
        metrics["train_loss"] = train_loss

        self.is_in_train = False

        self._memory_tracker.stop_and_update_metrics(metrics)

        self.log(metrics)

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        return TrainOutput(self.state.global_step, train_loss, metrics)

    def _load_from_checkpoint(self, resume_from_checkpoint):
        if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
            os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
        ):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

        logger.info(f"Loading model from {resume_from_checkpoint}.")

        if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
            config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
            checkpoint_version = config.transformers_version
            if checkpoint_version is not None and checkpoint_version != __version__:
                logger.warning(
                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                    f"Transformers but your current version is {__version__}. This is not recommended and could "
                    "lead to errors or unwanted behaviors."
                )

        if self.args.deepspeed:
            # will be resumed in deepspeed_init
            pass
        elif os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
            # We load the model state dict on the CPU to avoid an OOM error.
            state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
            # If the model is on the GPU, it still works!
            load_result = self.model.load_state_dict(state_dict, strict=False)
            self._issue_warnings_after_load(load_result)

            # release memory
            del state_dict
        else:
            # We load the sharded checkpoint
            load_result = load_sharded_checkpoint(self.model, resume_from_checkpoint, strict=False)
            self._issue_warnings_after_load(load_result)

    def _load_best_model(self):
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")

        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
        if os.path.exists(best_model_path):
            if self.deepspeed:
                # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
                deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
                self.model = deepspeed_engine.module
                self.model_wrapped = deepspeed_engine
                self.deepspeed = deepspeed_engine
                self.optimizer = optimizer
                self.lr_scheduler = lr_scheduler
                self.deepspeed.load_checkpoint(
                    self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
                )
            else:
                # We load the model state dict on the CPU to avoid an OOM error.
                state_dict = torch.load(best_model_path, map_location="cpu")
                # If the model is on the GPU, it still works!
                load_result = self.model.load_state_dict(state_dict, strict=False)
                self._issue_warnings_after_load(load_result)
        elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
            # Best model is a sharded checkpoint
            load_result = load_sharded_checkpoint(self.model, self.state.best_model_checkpoint, strict=False)
            self._issue_warnings_after_load(load_result)
        else:
            logger.warning(
                f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
                "on multiple nodes, you should activate `--save_on_each_node`."
            )

    def _issue_warnings_after_load(self, load_result):

        if len(load_result.missing_keys) != 0:
            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                self.model._keys_to_ignore_on_save
            ):
                self.model.tie_weights()
            else:
                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
        if len(load_result.unexpected_keys) != 0:
            logger.warning(
                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
            )

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
        if self.control.should_log:
            if is_torch_tpu_available():
                xm.mark_step()

            logs: Dict[str, float] = {}

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            logs["learning_rate"] = self._get_learning_rate()

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
            self._report_to_hp_search(trial, epoch, metrics)

        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def _load_rng_state(self, checkpoint):
        # Load RNG states from `checkpoint`
        if checkpoint is None:
            return

        local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
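        # In distributed training each process saved its own RNG state to rng_state_{rank}.pth; otherwise a single rng_state.pth file was written.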
        if local_rank != -1:
            rng_file = os.path.join(checkpoint, f"rng_state_{local_rank}.pth")
            if not os.path.isfile(rng_file):
                logger.info(
                    f"Didn't find an RNG file for process {local_rank}, if you are resuming a training that "
                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
                )
                return
        else:
            rng_file = os.path.join(checkpoint, "rng_state.pth")
            if not os.path.isfile(rng_file):
                logger.info(
                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
                    "fashion, reproducibility is not guaranteed."
                )
                return

        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        if torch.cuda.is_available():
            if self.args.local_rank != -1:
                torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
            else:
                try:
                    torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
                except Exception as e:
                    logger.info(
                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
        if is_torch_tpu_available():
            xm.set_rng_state(checkpoint_rng_state["xla"])

    def _save_checkpoint(self, model, trial, metrics=None):
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is not None and trial is not None:
            if self.hp_search_backend == HPSearchBackend.OPTUNA:
                run_id = trial.number
            elif self.hp_search_backend == HPSearchBackend.RAY:
                from ray import tune

                run_id = tune.get_trial_id()
            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
                run_id = trial.id
            elif self.hp_search_backend == HPSearchBackend.WANDB:
                import wandb

                run_id = wandb.run.id
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            run_dir = os.path.join(self.args.output_dir, run_name)
        else:
            run_dir = self.args.output_dir
            self.store_flos()

        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir, _internal_call=True)
        if self.deepspeed:
            # Under ZeRO-3 the model file itself is not saved (the weights are partitioned across processes) unless the deepspeed
            # config `stage3_gather_16bit_weights_on_model_save` is True
            self.deepspeed.save_checkpoint(output_dir)

        # Save optimizer and scheduler
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            self.optimizer.consolidate_state_dict()

        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            if smp.rdp_rank() == 0:
                # Consolidate the state dict on all processes of rdp_rank 0
                opt_state_dict = self.optimizer.state_dict()
                # Save it and the scheduler on the main process
                if self.args.should_save:
                    torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
                    with warnings.catch_warnings(record=True) as caught_warnings:
                        torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                    reissue_pt_warnings(caught_warnings)
                    if self.do_grad_scaling:
                        torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
        elif self.args.should_save and not self.deepspeed:
            # deepspeed.save_checkpoint above saves model/optim/sched
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.args.should_save:
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        # Save the RNG states (one file per process in distributed training, a single rng_state.pth otherwise)
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
            if self.args.local_rank == -1:
                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                rng_states["cuda"] = torch.cuda.random.get_rng_state()

        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()

        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)
        local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
        if local_rank == -1:
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

    def _load_optimizer_and_scheduler(self, checkpoint):
        """If optimizer and scheduler states exist, load them."""
        if checkpoint is None:
            return

        if self.deepspeed:
            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
            return

        if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
            os.path.join(checkpoint, SCHEDULER_NAME)
        ):
            # Load in optimizer and scheduler states
            if is_torch_tpu_available():
                # On TPU we have to take some extra precautions to properly load the states on the right device.
                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                with warnings.catch_warnings(record=True) as caught_warnings:
                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
                reissue_pt_warnings(caught_warnings)

                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

                self.optimizer.load_state_dict(optimizer_state)
                self.lr_scheduler.load_state_dict(lr_scheduler_state)
            else:
                map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
                self.optimizer.load_state_dict(
                    torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
                )
                with warnings.catch_warnings(record=True) as caught_warnings:
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
                    self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))

    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
        **kwargs,
    ) -> BestRun:
        """
        Launch a hyperparameter search using `optuna`, `Ray Tune`, `SigOpt` or `W&B`. The optimized quantity is determined
        by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
        the sum of all metrics otherwise.

        <Tip warning={true}>

        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
        optimizer/scheduler.

        </Tip>

        Args:
            hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
                A function that defines the hyperparameter search space. Will default to
                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
            compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
                method. Will default to [`~trainer_utils.default_compute_objective`].
            n_trials (`int`, *optional*, defaults to 20):
                The number of trial runs to test.
            direction (`str`, *optional*, defaults to `"minimize"`):
                Whether to optimize for a greater or lower objective value. Can be `"minimize"` or `"maximize"`; pick
                `"minimize"` when optimizing the validation loss and `"maximize"` when optimizing one or several metrics.
            backend (`str` or [`~trainer_utils.HPSearchBackend`], *optional*):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
                on which one is installed. If all are installed, will default to optuna.
            kwargs:
                Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more
                information see:

                - the documentation of
                  [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
                - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
                - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)

        Returns:
            [`trainer_utils.BestRun`]: All the information about the best run.
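
        Example (a minimal sketch; `my_hp_space` is a hypothetical search space and `trainer` an already-built
        [`Trainer`] with a `model_init`; assumes `optuna` is installed):

        ```python
        >>> def my_hp_space(trial):
        ...     return {"learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)}

        >>> best_run = trainer.hyperparameter_search(hp_space=my_hp_space, n_trials=10, direction="minimize")
        ```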
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray or sigopt should be installed. "
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`. "
                    "To install sigopt run `pip install sigopt`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_tune_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
            raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
        if backend == HPSearchBackend.WANDB and not is_wandb_available():
            raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
        self.hp_search_backend = backend
        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.hp_name = hp_name
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        backend_dict = {
            HPSearchBackend.OPTUNA: run_hp_search_optuna,
            HPSearchBackend.RAY: run_hp_search_ray,
            HPSearchBackend.SIGOPT: run_hp_search_sigopt,
            HPSearchBackend.WANDB: run_hp_search_wandb,
        }
        best_run = backend_dict[backend](self, n_trials, direction, **kwargs)

        self.hp_search_backend = None
        return best_run

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
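
        Example (a minimal sketch; assumes an already-built `trainer`):

        ```python
        >>> trainer.log({"my_custom_metric": 0.5})
        ```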
        """
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)

        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)

    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
        """
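        # Recursively walk dicts, lists and tuples, moving every tensor to the right device (and casting non-integer dtypes when running under DeepSpeed).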
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            kwargs = dict(device=self.args.device)
            if self.deepspeed and data.dtype != torch.int64:
                # NLP models inputs are int64 and those get adjusted to the right dtype of the
                # embedding. Other models such as wav2vec2's inputs are already float and thus
                # may need special handling to match the dtypes of the model
                kwargs.update(dict(dtype=self.args.hf_deepspeed_config.dtype()))
            return data.to(**kwargs)
        return data

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
2003
2004
        handling potential state.
        """
2005
        inputs = self._prepare_input(inputs)
2006
2007
2008
2009
2010
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )
2011
2012
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past
2013

2014
2015
        return inputs

2016
2017
    def autocast_smart_context_manager(self):
        """
        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
        arguments, depending on the situation.
        """
        if self.use_amp:
            if version.parse(torch.__version__) >= version.parse("1.10"):
                ctx_manager = autocast(dtype=self.amp_dtype)
            else:
                ctx_manager = autocast()
        else:
            ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()

        return ctx_manager

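    # Sketch (illustrative): a subclass running its own forward pass can reuse this helper so
    # mixed-precision settings stay consistent with the configured `TrainingArguments`:
    #
    #   with self.autocast_smart_context_manager():
    #       outputs = model(**inputs)
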
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            scaler = self.scaler if self.do_grad_scaling else None
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.autocast_smart_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
            loss = loss / self.args.gradient_accumulation_steps

        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            # loss gets scaled under gradient_accumulation_steps in deepspeed
            loss = self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()

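    # Note on the scaling above (illustrative arithmetic): with gradient_accumulation_steps=4,
    # each micro-batch contributes loss / 4, so after four backward passes the accumulated
    # gradient matches the gradient of the mean loss over the four micro-batches, e.g.
    # (2.0 + 1.0 + 3.0 + 2.0) / 4 = 2.0 rather than the sum 8.0.
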
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = self.label_smoother(outputs, labels)
        else:
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

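    # Override sketch (illustrative): a weighted cross-entropy loss for a 3-class classifier;
    # the class weights and the `num_labels` access are assumptions, not part of this module.
    #
    #   class WeightedLossTrainer(Trainer):
    #       def compute_loss(self, model, inputs, return_outputs=False):
    #           labels = inputs.pop("labels")
    #           outputs = model(**inputs)
    #           logits = outputs.get("logits")
    #           loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=logits.device))
    #           loss = loss_fct(logits.view(-1, model.config.num_labels), labels.view(-1))
    #           return (loss, outputs) if return_outputs else loss
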
    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
        machines) main process.
        """
        return self.args.local_process_index == 0

    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be `True` for one process).
        """
        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
        # process index.
        if is_sagemaker_mp_enabled():
            return smp.rank() == 0
        else:
            return self.args.process_index == 0

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Will save the model, so you can reload it using `from_pretrained()`.

        Will only save from the main process.
        """

        if output_dir is None:
            output_dir = self.args.output_dir

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif is_sagemaker_mp_enabled():
            # Calling the state_dict needs to be done on the wrapped model and on all processes.
            state_dict = self.model_wrapped.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
        elif (
            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
        ):
            state_dict = self.model.state_dict()

            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
        elif self.deepspeed:

            # this takes care of everything as long as we aren't under zero3
            if self.args.should_save:
                self._save(output_dir)

            if is_deepspeed_zero3_enabled():
                # It's too complicated to try to override different places where the weights dump gets
                # saved, so since under zero3 the file is bogus, simply delete it. The user should
                # either use the deepspeed checkpoint to resume, or recover the full weights with the
                # zero_to_fp32.py script stored in the checkpoint.
                if self.args.should_save:
                    file = os.path.join(output_dir, WEIGHTS_NAME)
                    if os.path.isfile(file):
                        # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
                        os.remove(file)

                # now save the real model if stage3_gather_16bit_weights_on_model_save=True
                # if false it will not be saved.
                # This must be called on all ranks
                if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME):
                    logger.warning(
                        "deepspeed.save_16bit_model didn't save the model, since stage3_gather_16bit_weights_on_model_save=false. "
                        "Saving the full checkpoint instead, use zero_to_fp32.py to recover weights"
                    )
                    self.deepspeed.save_checkpoint(output_dir)

        elif self.args.should_save:
            self._save(output_dir)

        # Push to the Hub when `save_model` is called by the user.
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")

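    # Usage sketch (illustrative): after training, the saved checkpoint can be reloaded with the
    # `from_pretrained` API; the auto class below assumes a sequence classification model.
    #
    #   trainer.save_model()
    #   model = AutoModelForSequenceClassification.from_pretrained(training_args.output_dir)
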
    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info(f"Saving model checkpoint to {output_dir}")

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if not isinstance(self.model, PreTrainedModel):
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                unwrap_model(self.model).save_pretrained(
                    output_dir,
                    is_main_process=self.args.should_save,
                    state_dict=self.model.state_dict(),
                    save_function=xm.save,
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                state_dict = self.model.state_dict()
                xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
        if self.tokenizer is not None and self.args.should_save:
            self.tokenizer.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                if state_dict is None:
                    state_dict = self.model.state_dict()
                unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if state_dict is None:
                    state_dict = self.model.state_dict()
                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir, state_dict=state_dict)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def store_flos(self):
        # Storing the number of floating-point operations that went into the model
        if self.args.local_rank != -1:
            self.state.total_flos += (
                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
            )
            self.current_flos = 0
        else:
            self.state.total_flos += self.current_flos
            self.current_flos = 0

    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match is not None and regex_match.groups() is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        # Make sure we don't delete the best model.
        if self.state.best_model_checkpoint is not None:
            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
            for i in range(best_model_index, len(checkpoints_sorted) - 2):
                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
        # we don't do to allow resuming.
        save_total_limit = self.args.save_total_limit
        if (
            self.state.best_model_checkpoint is not None
            and self.args.save_total_limit == 1
            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
        ):
            save_total_limit = 2

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
            shutil.rmtree(checkpoint)

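    # Illustrative walk-through: with save_total_limit=2 and checkpoints saved at steps 500,
    # 1000 and 1500, `checkpoint-500` is deleted; if load_best_model_at_end=True and
    # save_total_limit=1, the limit is bumped to 2 so both the best and the latest checkpoint
    # survive for resuming.
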
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is a `datasets.Dataset`, columns not
                accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default).

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        start_time = time.time()

        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.log(output.metrics)

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics

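    # Usage sketch (illustrative): returned keys carry the `metric_key_prefix`; "eval_accuracy"
    # assumes a user-supplied `compute_metrics` function that reports an "accuracy" key.
    #
    #   metrics = trainer.evaluate()
    #   print(metrics["eval_loss"], metrics.get("eval_accuracy"))
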
    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the
                `model.forward()` method are automatically removed. Has to implement the method `__len__`.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "test_bleu" if the prefix is "test" (default).

        <Tip>

        If your predictions or labels have different sequence lengths (for instance because you're doing dynamic padding
        in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
        one array. The padding index is -100.

        </Tip>

        Returns: *NamedTuple* A namedtuple with the following keys:

            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        test_dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()

        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )
        total_batch_size = self.args.eval_batch_size * self.args.world_size
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)

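    # Usage sketch (illustrative): turning raw logits into class predictions for a classification
    # test set; `test_dataset` is assumed to be tokenized/preprocessed already.
    #
    #   output = trainer.predict(test_dataset)
    #   preds = np.argmax(output.predictions, axis=-1)
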
    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with and without labels.
        """
        args = self.args

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        # if eval is called w/o train init deepspeed here
        if args.deepspeed and not self.deepspeed:

            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(
                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine

        model = self._wrap_model(self.model, training=False)

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        if args.past_index >= 0:
            self._past = None

        # Initialize containers
        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        preds_host = None
        labels_host = None
        inputs_host = None

        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        all_inputs = None
        # Will be useful when we have an iterable dataset so don't know its length.

        observed_num_examples = 0
        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None

            if is_torch_tpu_available():
                xm.mark_step()

            # Update containers on host
            if loss is not None:
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if labels is not None:
                labels = self._pad_across_processes(labels)
                labels = self._nested_gather(labels)
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_decode = self._pad_across_processes(inputs_decode)
                inputs_decode = self._nested_gather(inputs_decode)
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            if logits is not None:
                logits = self._pad_across_processes(logits)
                logits = self._nested_gather(logits)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                if losses_host is not None:
                    losses = nested_numpify(losses_host)
                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                if preds_host is not None:
                    logits = nested_numpify(preds_host)
                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
                if inputs_host is not None:
                    inputs_decode = nested_numpify(inputs_host)
                    all_inputs = (
                        inputs_decode
                        if all_inputs is None
                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                    )
                if labels_host is not None:
                    labels = nested_numpify(labels_host)
                    all_labels = (
                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                    )

                # Set back to None to begin a new accumulation
                losses_host, preds_host, inputs_host, labels_host = None, None, None, None

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
        if inputs_host is not None:
            inputs_decode = nested_numpify(inputs_host)
            all_inputs = (
                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
            )
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # methods. Therefore we need to make sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
        # samples has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
        if all_inputs is not None:
            all_inputs = nested_truncate(all_inputs, num_samples)

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)

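    # Wiring sketch (illustrative): `preprocess_logits_for_metrics` used above can shrink what gets
    # accumulated across the loop, e.g. keeping only the argmax instead of full vocabulary logits;
    # both functions below are user-supplied assumptions, not part of this module.
    #
    #   def preprocess_logits_for_metrics(logits, labels):
    #       return logits.argmax(dim=-1)
    #
    #   def compute_metrics(eval_pred):
    #       predictions, labels = eval_pred.predictions, eval_pred.label_ids
    #       return {"accuracy": (predictions == labels).mean()}
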
    def _nested_gather(self, tensors, name=None):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            if name is None:
                name = "nested_gather"
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif self.args.local_rank != -1:
            tensors = distributed_concat(tensors)
        return tensors

    # Copied from Accelerate.
    def _pad_across_processes(self, tensor, pad_index=-100):
        """
        Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
        they can safely be gathered.
        """
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )

        if len(tensor.shape) < 2:
            return tensor
        # Gather all sizes
        size = torch.tensor(tensor.shape, device=tensor.device)[None]
        sizes = self._nested_gather(size).cpu()

        max_size = max(s[1] for s in sizes)
        if tensor.shape[1] == max_size:
            return tensor

        # Then pad to the maximum size
        old_size = tensor.shape
        new_size = list(old_size)
        new_size[1] = max_size
        new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
        new_tensor[:, : old_size[1]] = tensor
        return new_tensor

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        has_labels = all(inputs.get(k) is not None for k in self.label_names)
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels:
                    with self.autocast_smart_context_manager():
                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                    loss = loss.mean().detach()

                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        logits = outputs[1:]
                else:
                    loss = None
                    with self.autocast_smart_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point
        operations for every backward + forward pass. If using another model, either implement such a method in the
        model or subclass and override this method.

        Args:
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            `int`: The number of floating-point operations.
        """
        if hasattr(self.model, "floating_point_ops"):
            return self.model.floating_point_ops(inputs)
        else:
            return 0

    def init_git_repo(self, at_init: bool = False):
        """
        Initializes a git repo in `self.args.hub_model_id`.

        Args:
            at_init (`bool`, *optional*, defaults to `False`):
                Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
                `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
                out.
        """
        if not self.is_world_process_zero():
            return
        use_auth_token = True if self.args.hub_token is None else self.args.hub_token
        if self.args.hub_model_id is None:
            repo_name = Path(self.args.output_dir).absolute().name
        else:
            repo_name = self.args.hub_model_id
        if "/" not in repo_name:
            repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)

        try:
            self.repo = Repository(
                self.args.output_dir,
                clone_from=repo_name,
                use_auth_token=use_auth_token,
                private=self.args.hub_private_repo,
            )
        except EnvironmentError:
            if self.args.overwrite_output_dir and at_init:
                # Try again after wiping output_dir
                shutil.rmtree(self.args.output_dir)
                self.repo = Repository(
                    self.args.output_dir,
                    clone_from=repo_name,
                    use_auth_token=use_auth_token,
                )
            else:
                raise

        self.repo.git_pull()

        # By default, ignore the checkpoint folders
        if (
            not os.path.exists(os.path.join(self.args.output_dir, ".gitignore"))
            and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS
        ):
            with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
                writer.writelines(["checkpoint-*/"])

        self.push_in_progress = None

    def create_model_card(
        self,
        language: Optional[str] = None,
        license: Optional[str] = None,
        tags: Optional[str] = None,
        model_name: Optional[str] = None,
        finetuned_from: Optional[str] = None,
        tasks: Optional[str] = None,
        dataset_tags: Optional[Union[str, List[str]]] = None,
        dataset: Optional[Union[str, List[str]]] = None,
        dataset_args: Optional[Union[str, List[str]]] = None,
    ):
        if not self.is_world_process_zero():
            return

        training_summary = TrainingSummary.from_trainer(
            self,
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
        )
        model_card = training_summary.to_model_card()
        with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
            f.write(model_card)

    def _push_from_checkpoint(self, checkpoint_folder):
        # Only push from one node.
        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
            return
        # If we haven't finished the last push, we don't do this one.
        if self.push_in_progress is not None and not self.push_in_progress.is_done:
            return

        output_dir = self.args.output_dir
        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
        modeling_files = [CONFIG_NAME, WEIGHTS_NAME]
        for modeling_file in modeling_files:
            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Same for the training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        try:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Temporarily move the checkpoint just saved for the push
                tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
                # subfolder.
                if os.path.isdir(tmp_checkpoint):
                    shutil.rmtree(tmp_checkpoint)
                shutil.move(checkpoint_folder, tmp_checkpoint)

            if self.args.save_strategy == IntervalStrategy.STEPS:
                commit_message = f"Training in progress, step {self.state.global_step}"
            else:
                commit_message = f"Training in progress, epoch {int(self.state.epoch)}"

            _, self.push_in_progress = self.repo.push_to_hub(
                commit_message=commit_message, blocking=False, auto_lfs_prune=True
            )
        finally:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Move back the checkpoint to its place
                shutil.move(tmp_checkpoint, checkpoint_folder)

    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.

        Parameters:
            commit_message (`str`, *optional*, defaults to `"End of training"`):
                Message to commit while pushing.
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the `git push` has finished.
            kwargs:
                Additional keyword arguments passed along to [`~Trainer.create_model_card`].

        Returns:
            The url of the commit of your model in the given repository if `blocking=True`, or a tuple with the url
            of the commit and an object to track the progress of the commit if `blocking=False`.
        """
        # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
        # it might fail.
        if not hasattr(self, "repo"):
            self.init_git_repo()

        if self.args.should_save:
            if self.args.hub_model_id is None:
                model_name = Path(self.args.output_dir).name
            else:
                model_name = self.args.hub_model_id.split("/")[-1]

        # Needs to be executed on all processes for TPU training, but will only save on the process determined by
        # self.args.should_save.
        self.save_model(_internal_call=True)

        # Only push from one node.
        if not self.is_world_process_zero():
            return

        # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
            self.push_in_progress._process.kill()
            self.push_in_progress = None

        git_head_commit_url = self.repo.push_to_hub(
            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
        )
        # push separately the model card to be independent from the rest of the model
        if self.args.should_save:
            self.create_model_card(model_name=model_name, **kwargs)
            try:
                self.repo.push_to_hub(
                    commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
                )
            except EnvironmentError as exc:
                logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")

        return git_head_commit_url
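
    # Usage sketch (illustrative): pushing the final model with a custom commit message once
    # `args.hub_model_id` (or the output directory name) identifies the target repository.
    #
    #   url = trainer.push_to_hub(commit_message="Training complete")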

    #
    # Deprecated code
    #

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> PredictionOutput:
        """
2987
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
2988
2989
2990

        Works both with or without labels.
        """
2991
2992
        args = self.args

2993
2994
2995
        if not has_length(dataloader):
            raise ValueError("dataloader must implement a working __len__")

2996
        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
2997
2998

        # if eval is called w/o train init deepspeed here
2999
        if args.deepspeed and not self.deepspeed:
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine
            # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
            # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
            # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
            deepspeed_engine.optimizer.optimizer = None
            deepspeed_engine.lr_scheduler = None

        model = self._wrap_model(self.model, training=False)

3014
3015
3016
3017
3018
3019
3020
        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)
3021
3022
3023
3024
3025
3026
3027
3028
3029

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info(f"***** Running {description} *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Batch size = {batch_size}")
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
        inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = max(1, args.world_size)

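        # The per-step loss is repeated ``batch_size`` times in the loop below, which is why the loss gatherer
        # uses ``make_multiple_of=batch_size``.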
        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            # The actual number of eval samples can be greater than num_examples in distributed settings (when we
            # pass a batch size to the sampler)
            make_multiple_of = None
            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
                make_multiple_of = dataloader.sampler.batch_size
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)

        model.eval()

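        # On TPU, wrap the dataloader so that each batch is automatically loaded onto the current device.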
        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

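        # ``past_index`` is used by models that return a past state (e.g. Transformer-XL or XLNet); reset the
        # memory of past state before the loop.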
        if args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

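        # Main prediction loop: one forward pass per batch, accumulating losses, logits and labels on the host.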
        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
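            # Keep the raw ``input_ids`` of the batch when the metrics function asks for access to the inputs.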
            inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None

            if loss is not None:
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                if not prediction_loss_only:
                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
                    inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host, inputs_host = None, None, None, None

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
            inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
        inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None

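        # Only compute user-provided metrics once predictions and labels have been fully gathered.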
        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def _gather_and_numpify(self, tensors, name):
        """
        Gather the value of `tensors` (a tensor or a list/tuple of nested tensors) across processes and convert
        the result to numpy arrays.
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        elif self.args.local_rank != -1:
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)