# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
"""

import collections
import inspect
import math
import os
import re
import shutil
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union


# Integrations must be imported before ML frameworks:
from .integrations import (  # isort: split
    default_hp_search_backend,
    hp_params,
    is_azureml_available,
    is_comet_available,
    is_fairscale_available,
    is_mlflow_available,
    is_optuna_available,
    is_ray_available,
    is_tensorboard_available,
    is_wandb_available,
    run_hp_search_optuna,
    run_hp_search_ray,
)

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler

from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
from .modeling_utils import PreTrainedModel
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .optimization import AdamW, get_linear_schedule_with_warmup
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
    DistributedTensorGatherer,
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    get_tpu_sampler,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalPrediction,
    HPSearchBackend,
    PredictionOutput,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    set_seed,
)
from .training_args import TrainingArguments
from .utils import logging


_is_native_amp_available = False

DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
    from .file_utils import is_apex_available

    if is_apex_available():
        from apex import amp
else:
    _is_native_amp_available = True
    from torch.cuda.amp import autocast

if is_datasets_available():
    import datasets

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_tensorboard_available():
    from .integrations import TensorBoardCallback

    DEFAULT_CALLBACKS.append(TensorBoardCallback)


if is_wandb_available():
    from .integrations import WandbCallback

    DEFAULT_CALLBACKS.append(WandbCallback)

if is_comet_available():
    from .integrations import CometCallback

    DEFAULT_CALLBACKS.append(CometCallback)

if is_mlflow_available():
    from .integrations import MLflowCallback

    DEFAULT_CALLBACKS.append(MLflowCallback)

if is_optuna_available():
    import optuna

if is_ray_available():
    from ray import tune

if is_azureml_available():
    from .integrations import AzureMLCallback

    DEFAULT_CALLBACKS.append(AzureMLCallback)

if is_fairscale_available():
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

logger = logging.get_logger(__name__)


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.

    Args:
        model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`):
            The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.

            .. note::

                :class:`~transformers.Trainer` is optimized to work with the :class:`~transformers.PreTrainedModel`
                provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as
                they work the same way as the 🤗 Transformers models.
        args (:class:`~transformers.TrainingArguments`, `optional`):
            The arguments to tweak for training. Will default to a basic instance of
            :class:`~transformers.TrainingArguments` with the ``output_dir`` set to a directory named `tmp_trainer` in
            the current directory if not provided.
        data_collator (:obj:`DataCollator`, `optional`):
            The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
            Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
            :func:`~transformers.DataCollatorWithPadding` otherwise.
        train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
        tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to
            the maximum length when batching inputs, and it will be saved along with the model to make it easier to
            rerun an interrupted training or reuse the fine-tuned model.
        model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
            A function that instantiates the model to be used. If provided, each call to
            :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.

            The function may have zero arguments, or a single one containing the optuna/Ray Tune trial object, to be
            able to choose different architectures according to hyperparameters (such as layer count, sizes of inner
            layers, dropout probabilities, etc.).
        compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take a
            :class:`~transformers.EvalPrediction` and return a dictionary mapping metric names to metric values.
        callbacks (List of :obj:`~transformers.TrainerCallback`, `optional`):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed in :doc:`here <callback>`.

            If you want to remove one of the default callbacks used, use the :meth:`Trainer.remove_callback` method.
        optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, `optional`): A tuple
            containing the optimizer and the scheduler to use. Will default to an instance of
            :class:`~transformers.AdamW` on your model and a scheduler given by
            :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
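
    Example (a minimal usage sketch; assumes ``model``, ``training_args``, ``train_dataset`` and ``eval_dataset``
    have already been created)::

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
        trainer.train()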
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, torch.nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    ):
        if args is None:
            logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.")
            args = TrainingArguments("tmp_trainer")
        self.args = args
        # Seed must be set before instantiating the model when using model_init.
        set_seed(self.args.seed)
        assert (
            model is not None or model_init is not None
        ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
        self.model_init = model_init
        self.hp_name = None
        if model is None and model_init is not None:
            model = self.call_model_init()

        # Model parallel
        if model is not None and not self.args.model_parallel:
            model = model.to(args.device)

        self.model = model
        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer

        self.compute_metrics = compute_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument."
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
        self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create output directory if needed
        if self.is_world_process_zero():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel):
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant way and not need to do this in the future.
            self.model.config.xla_device = True
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")

        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

        # Enforce rules on using datasets with no __len__
        if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
            raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")

        if is_datasets_available():
            if isinstance(train_dataset, datasets.Dataset):
                self._remove_unused_columns(self.train_dataset, description="training")
            if isinstance(eval_dataset, datasets.Dataset):
                self._remove_unused_columns(self.eval_dataset, description="evaluation")

        # Setup Sharded DDP training
        self.sharded_dpp = False
        if args.sharded_ddp:
            if args.local_rank == -1:
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            else:
                self.sharded_dpp = True

        # Mixed precision setup
        self.use_apex = False
        self.use_amp = False
        if args.fp16:
            if args.fp16_backend == "auto":
                backend = "amp" if _is_native_amp_available else "apex"
            else:
                backend = args.fp16_backend

            if backend == "amp":
                self.use_amp = True
                self.scaler = ShardedGradScaler() if self.sharded_dpp else torch.cuda.amp.GradScaler()
            else:
                if not is_apex_available():
                    raise ImportError(
                        "Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex."
                    )
                self.use_apex = True

        self.state = TrainerState()
        self.control = TrainerControl()
        # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
        # state at each call to self.log.
        self._total_flos = None
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
        default_label_names = (
            ["start_positions", "end_positions"]
            if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
            else ["labels"]
        )
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

    def add_callback(self, callback):
        """
        Add a callback to the current list of :class:`~transformers.TrainerCallback`.

        Args:
           callback (:obj:`type` or :class:`~transformers.TrainerCallback`):
               A :class:`~transformers.TrainerCallback` class or an instance of a :class:`~transformers.TrainerCallback`.
               In the first case, will instantiate a member of that class.
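
        Example (a sketch; ``MyCallback`` stands for any user-defined :class:`~transformers.TrainerCallback`
        subclass)::

            trainer.add_callback(MyCallback)    # pass the class, an instance is created for you
            trainer.add_callback(MyCallback())  # or pass an instance directly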
        """
        self.callback_handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Remove a callback from the current list of :class:`~transformers.TrainerCallback` and return it.

        If the callback is not found, returns :obj:`None` (and no error is raised).

        Args:
           callback (:obj:`type` or :class:`~transformers.TrainerCallback`):
               A :class:`~transformers.TrainerCallback` class or an instance of a :class:`~transformers.TrainerCallback`.
               In the first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            :class:`~transformers.TrainerCallback`: The callback removed, if found.
        """
        return self.callback_handler.pop_callback(callback)

    def remove_callback(self, callback):
        """
        Remove a callback from the current list of :class:`~transformers.TrainerCallback`.

        Args:
           callback (:obj:`type` or :class:`~transformers.TrainerCallback`):
               A :class:`~transformers.TrainerCallback` class or an instance of a :class:`~transformers.TrainerCallback`.
               In the first case, will remove the first member of that class found in the list of callbacks.
        """
        self.callback_handler.remove_callback(callback)

    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return
        # Inspect model forward signature to keep only the arguments it accepts.
        signature = inspect.signature(self.model.forward)
        signature_columns = list(signature.parameters.keys())
        # Labels may be named label or label_ids, the default data collator handles that.
        signature_columns += ["label", "label_ids"]
        columns = [k for k in signature_columns if k in dataset.column_names]
        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        dset_description = "" if description is None else f"in the {description} set "
        logger.info(
            f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
        )
        dataset.set_format(type=dataset.format["type"], columns=columns)

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
            self.train_dataset, collections.abc.Sized
        ):
            return None
        elif is_torch_tpu_available():
            return get_tpu_sampler(self.train_dataset)
        else:
            return (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler (adapted
        to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
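
        Example (a sketch of overriding this method in a subclass to plug in a custom sampler; ``MyTrainer`` and
        ``my_sampler`` are hypothetical names)::

            class MyTrainer(Trainer):
                def get_train_dataloader(self):
                    return DataLoader(
                        self.train_dataset,
                        batch_size=self.args.train_batch_size,
                        sampler=my_sampler,
                        collate_fn=self.data_collator,
                    )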
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = self._get_train_sampler()

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
        )

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        if is_torch_tpu_available():
            return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            return SequentialDistributedSampler(eval_dataset)
        else:
            return SequentialSampler(eval_dataset)

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation :class:`~torch.utils.data.DataLoader`.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
                accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        elif eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")
        elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            self._remove_unused_columns(eval_dataset, description="evaluation")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test :class:`~torch.utils.data.DataLoader`.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
        """
        if not isinstance(test_dataset, collections.abc.Sized):
            raise ValueError("test_dataset must implement __len__")
        elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            self._remove_unused_columns(test_dataset, description="test")
        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
        the Trainer's init through :obj:`optimizers`, or subclass and override this method.
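
        Example (a sketch of supplying a custom optimizer and scheduler at init time instead of overriding this
        method; assumes ``model`` and ``training_args`` already exist)::

            optimizer = AdamW(model.parameters(), lr=5e-5)
            scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=1000)
            trainer = Trainer(model=model, args=training_args, optimizers=(optimizer, scheduler))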
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            if self.sharded_dpp:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=AdamW,
                    lr=self.args.learning_rate,
                    betas=(self.args.adam_beta1, self.args.adam_beta2),
                    eps=self.args.adam_epsilon,
                )
            else:
                self.optimizer = AdamW(
                    optimizer_grouped_parameters,
                    lr=self.args.learning_rate,
                    betas=(self.args.adam_beta1, self.args.adam_beta2),
                    eps=self.args.adam_epsilon,
                )
        if self.lr_scheduler is None:
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
            )

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.

        Will raise an exception if the underlying dataset does not implement :obj:`__len__`.
        """
        return len(dataloader.dataset)

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """ HP search setup code """
        self._trial = trial

        if self.hp_search_backend is None or trial is None:
            return

        params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
        for key, value in params.items():
            if not hasattr(self.args, key):
                raise AttributeError(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
                )
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info(f"Trial: {trial.params}")

    def _report_to_hp_search(
        self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
    ):
        if self.hp_search_backend is None or trial is None:
            return
        self.objective = self.compute_objective(metrics.copy())
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            trial.report(self.objective, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            if self.state.global_step % self.args.save_steps == 0:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

    def _tune_save_checkpoint(self):
        if not self.use_tune_checkpoints:
            return
        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
            self.args.output_dir = checkpoint_dir
            output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
            self.save_model(output_dir)
            if self.is_world_process_zero():
                self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

    def call_model_init(self, trial=None):
        model_init_argcount = len(inspect.signature(self.model_init).parameters)
        if model_init_argcount == 0:
            model = self.model_init()
        elif model_init_argcount == 1:
            model = self.model_init(trial)
        else:
            raise RuntimeError("model_init should have 0 or 1 argument.")

        if model is None:
            raise RuntimeError("model_init should not return None.")

        return model

    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
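
        Example (a minimal sketch of resuming from a previously saved checkpoint folder; ``checkpoint_dir`` is a
        hypothetical path to such a checkpoint)::

            trainer.train(model_path=checkpoint_dir)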
        """
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            set_seed(self.args.seed)

            model = self.call_model_init(trial)

            if not self.args.model_parallel:
                self.model = model.to(self.args.device)

            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Keeping track of whether we can call len() on the dataset or not
        train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        if train_dataset_is_sized:
            num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            if self.args.max_steps > 0:
                max_steps = self.args.max_steps
                num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                    self.args.max_steps % num_update_steps_per_epoch > 0
                )
            else:
                max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(self.args.num_train_epochs)
        else:
            # see __init__. max_steps is set when the dataset has no __len__
            max_steps = self.args.max_steps
            num_train_epochs = 1
            num_update_steps_per_epoch = max_steps

        self.create_optimizer_and_scheduler(num_training_steps=max_steps)
        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(model_path)

        # Mixed precision training with apex (torch < 1.6)
        model = self.model
        if self.use_apex:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1 and not self.args.model_parallel:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_dpp:
            model = ShardedDDP(model, self.optimizer)
        elif self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=(
                    not getattr(model.config, "gradient_checkpointing", False)
                    if isinstance(model, PreTrainedModel)
                    else True
                ),
            )
            # find_unused_parameters breaks checkpointing as per
            # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )

        num_examples = (
            self.num_examples(train_dataloader)
            if train_dataset_is_sized
            else total_train_batch_size * self.args.max_steps
        )

        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
        logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps}")

        self.state.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        # Check if continuing training from a checkpoint
        if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
            self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            if not self.args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not self.args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
                    "batches in the first epoch."
                )

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
        self.state.trial_params = hp_params(trial) if trial is not None else None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(self.args.device)
        # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = 0
        self._total_flos = self.state.total_flos
        model.zero_grad()

        self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)

        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not self.args.ignore_data_skip:
            for epoch in range(epochs_trained):
                # We just need to begin an iteration to create the randomization of the sampler.
                for _ in train_dataloader:
                    break

        for epoch in range(epochs_trained, num_train_epochs):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            steps_in_epoch = len(epoch_iterator) if train_dataset_is_sized else self.args.max_steps
            self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                if (step + 1) % self.args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)

                if ((step + 1) % self.args.gradient_accumulation_steps != 0) and self.args.local_rank != -1:
                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                    with model.no_sync():
                        tr_loss += self.training_step(model, inputs)
                else:
                    tr_loss += self.training_step(model, inputs)
                self._total_flos += self.floating_point_ops(inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    steps_in_epoch <= self.args.gradient_accumulation_steps
                    and (step + 1) == steps_in_epoch
                ):
                    if self.use_amp:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                    elif self.use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)
                    elif self.use_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    self.lr_scheduler.step()
                    model.zero_grad()
                    self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
                    self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)

                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break

            self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

            if self.args.tpu_metrics_debug or self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            logger.info(
                f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
            )
            if isinstance(self.model, PreTrainedModel):
                self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
                if not self.args.model_parallel:
                    self.model = self.model.to(self.args.device)
            else:
                state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
                self.model.load_state_dict(state_dict)

        if self._total_flos is not None:
            self.store_flos()
            self.log({"total_flos": self.state.total_flos})

        self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()

        return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
        if self.control.should_log:
            logs: Dict[str, float] = {}
            tr_loss_scalar = tr_loss.item()
            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged)
            # backward compatibility for pytorch schedulers
            logs["learning_rate"] = (
                self.lr_scheduler.get_last_lr()[0]
                if version.parse(torch.__version__) >= version.parse("1.4")
                else self.lr_scheduler.get_lr()[0]
            )
            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            metrics = self.evaluate()
            self._report_to_hp_search(trial, epoch, metrics)

        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def _save_checkpoint(self, model, trial, metrics=None):
        # In all cases (even distributed/parallel), self.model is always a reference
        # to the model we want to save.
        if hasattr(model, "module"):
            assert model.module is self.model, f"Module {model.module} should be a reference to self.model"
        else:
            assert model is self.model, f"Model {model} should be a reference to self.model"
        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is not None and trial is not None:
            run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
        else:
            output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

            self.store_flos()
        self.save_model(output_dir)

        # Save optimizer and scheduler
        if self.sharded_dpp:
            self.optimizer.consolidate_state_dict()
        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                reissue_pt_warnings(caught_warnings)
        elif self.is_world_process_zero():
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
            reissue_pt_warnings(caught_warnings)

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.is_world_process_zero():
            self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))

        # Maybe delete some older checkpoints.
        if self.is_world_process_zero():
            self._rotate_checkpoints(use_mtime=True)

    def _load_optimizer_and_scheduler(self, model_path):
        """If optimizer and scheduler states exist, load them."""
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            if is_torch_tpu_available():
                # On TPU we have to take some extra precautions to properly load the states on the right device.
                optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")
                with warnings.catch_warnings(record=True) as caught_warnings:
                    lr_scheduler_state = torch.load(os.path.join(model_path, "scheduler.pt"), map_location="cpu")
                reissue_pt_warnings(caught_warnings)

                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

                self.optimizer.load_state_dict(optimizer_state)
                self.lr_scheduler.load_state_dict(lr_scheduler_state)
            else:
                self.optimizer.load_state_dict(
                    torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
                )
                with warnings.catch_warnings(record=True) as caught_warnings:
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
                reissue_pt_warnings(caught_warnings)

    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
        **kwargs
    ) -> BestRun:
        """
        Launch a hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
        :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no metric is
        provided and the sum of all metrics otherwise.

        .. warning::

            To use this method, you need to have provided a ``model_init`` when initializing your
            :class:`~transformers.Trainer`: we need to reinitialize the model at each new run. This is incompatible
            with the ``optimizers`` argument, so you need to subclass :class:`~transformers.Trainer` and override the
            method :meth:`~transformers.Trainer.create_optimizer_and_scheduler` for custom optimizer/scheduler.

        Args:
            hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
                A function that defines the hyperparameter search space. Will default to
                :func:`~transformers.trainer_utils.default_hp_space_optuna` or
                :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
            compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
                A function computing the objective to minimize or maximize from the metrics returned by the
                :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
            n_trials (:obj:`int`, `optional`, defaults to 20):
                The number of trial runs to test.
            direction (:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
                Whether to optimize greater or lower objective values. Can be :obj:`"minimize"` or :obj:`"maximize"`:
                pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
                several metrics.
            backend (:obj:`str` or :class:`~transformers.trainer_utils.HPSearchBackend`, `optional`):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
                one is installed. If both are installed, will default to optuna.
            hp_name (:obj:`Callable[["optuna.Trial"], str]`, `optional`):
                A function that defines the trial/run name. Will default to :obj:`None`.
            kwargs:
                Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
                more information see:

                - the documentation of `optuna.create_study
                  <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
                - the documentation of `tune.run
                  <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__

        Returns:
            :class:`transformers.trainer_utils.BestRun`: All the information about the best run.
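
        Example (a minimal sketch, assuming ``optuna`` is installed, ``trainer`` was created with a ``model_init``
        function, and ``my_hp_space`` is a hypothetical user-defined search space)::

            def my_hp_space(trial):
                # Hypothetical ranges; adapt them to your task.
                return {
                    "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
                    "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
                }

            best_run = trainer.hyperparameter_search(hp_space=my_hp_space, n_trials=10, direction="minimize")
            print(best_run.hyperparameters)
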
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray should be installed. "
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        self.hp_search_backend = backend
        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.hp_name = hp_name
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
        best_run = run_hp_search(self, n_trials, direction, **kwargs)

        self.hp_search_backend = None
        return best_run

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log :obj:`logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (:obj:`Dict[str, float]`):
                The values to log.
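
        Example (a small sketch; assumes ``trainer`` is an instantiated :class:`~transformers.Trainer`)::

            trainer.log({"my_custom_metric": 0.87})
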
        """
        if self.state.epoch is not None:
            logs["epoch"] = self.state.epoch

        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.
        """

        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.use_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.use_amp:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.detach()

    def compute_loss(self, model, inputs):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
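
        Example (a minimal sketch of an override, assuming a classification model whose first output is the logits
        and batches that contain a ``labels`` key; the class weights below are purely illustrative)::

            import torch
            from transformers import Trainer

            class WeightedLossTrainer(Trainer):
                def compute_loss(self, model, inputs):
                    labels = inputs.pop("labels")
                    outputs = model(**inputs)
                    logits = outputs[0]
                    loss_fct = torch.nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0], device=logits.device))
                    return loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
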
        """
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        return outputs["loss"] if isinstance(outputs, dict) else outputs[0]

    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local main process (e.g., the main process on one machine if training in a
        distributed fashion on several machines).
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be :obj:`True` for one process).
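
        Example (a small sketch; typically used to guard side effects such as writing files during distributed
        training)::

            if trainer.is_world_process_zero():
                print("Only printed once across all processes.")
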
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        Will only save from the world_master process (unless on TPUs).
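
        Example (a short sketch; ``./my_finetuned_model`` is a hypothetical output directory, and the class used to
        reload should match the model that was trained)::

            trainer.save_model("./my_finetuned_model")

            from transformers import AutoModelForSequenceClassification
            model = AutoModelForSequenceClassification.from_pretrained("./my_finetuned_model")
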
        """

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_process_zero():
            self._save(output_dir)

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if not isinstance(self.model, PreTrainedModel):
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            state_dict = self.model.state_dict()
            xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir)
        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            state_dict = self.model.state_dict()
            torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir)
        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def store_flos(self):
        # Storing the number of floating-point operations that went into the model
        if self._total_flos is not None:
            if self.args.local_rank != -1:
                self.state.total_flos = distributed_broadcast_scalars([self._total_flos]).sum().item()
            else:
                self.state.total_flos = self._total_flos

    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        # Make sure we don't delete the best model.
        if self.state.best_model_checkpoint is not None:
            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
            checkpoints_sorted[best_model_index], checkpoints_sorted[-1] = (
                checkpoints_sorted[-1],
                checkpoints_sorted[best_model_index],
            )
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is a :obj:`datasets.Dataset`,
                columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
                :obj:`__len__` method.
            ignore_keys (:obj:`List[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metric "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default).

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
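
        Example (a minimal sketch; assumes the :class:`~transformers.Trainer` was created with a labeled
        ``eval_dataset``)::

            metrics = trainer.evaluate()
            print(metrics["eval_loss"])
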
        """
        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")

        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self.prediction_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        self.log(output.metrics)

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
        return output.metrics

    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval"
    ) -> PredictionOutput:
        """
        Run prediction and returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset to run the predictions on. If it is a :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`.
            ignore_keys (:obj:`List[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metric "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default).

        .. note::

            If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
            padding in a token classification task) the predictions will be padded (on the right) to allow for
            concatenation into one array. The padding index is -100.

        Returns: `NamedTuple` A namedtuple with the following keys:

            - predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
            - label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some).
            - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
              contained labels).
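
        Example (a minimal sketch for a classification task; ``test_dataset`` is a hypothetical preprocessed
        dataset)::

            import numpy as np

            output = trainer.predict(test_dataset)
            preds = np.argmax(output.predictions, axis=-1)
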
        """
        if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized):
            raise ValueError("test_dataset must implement __len__")

        test_dataloader = self.get_test_dataloader(test_dataset)

        return self.prediction_loop(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if not isinstance(dataloader.dataset, collections.abc.Sized):
            raise ValueError("dataset must implement __len__")
        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1 and not self.args.model_parallel:
            model = torch.nn.DataParallel(model)
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", num_examples)
        logger.info("  Batch size = %d", batch_size)
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = 1
        if is_torch_tpu_available():
            world_size = xm.xrt_world_size()
        elif self.args.local_rank != -1:
            world_size = torch.distributed.get_world_size()
        world_size = max(1, world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples)

        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            if loss is not None:
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                if not prediction_loss_only:
                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host = None, None, None

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def _gather_and_numpify(self, tensors, name):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif self.args.local_rank != -1:
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.
            ignore_keys (:obj:`List[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """
        has_labels = all(inputs.get(k) is not None for k in self.label_names)
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        with torch.no_grad():
            if self.use_amp:
                with autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)
            if has_labels:
                if isinstance(outputs, dict):
                    loss = outputs["loss"].mean().detach()
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                else:
                    loss = outputs[0].mean().detach()
                    logits = outputs[1:]
            else:
                loss = None
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                else:
                    logits = outputs
            # TODO: this needs to be fixed and made cleaner later.
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        if has_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        return (loss, logits, labels)

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        For models that inherit from :class:`~transformers.PreTrainedModel`, uses that method to compute the number of
        floating point operations for every backward + forward pass. If using another model, either implement such a
        method in the model or subclass and override this method.

        Args:
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            :obj:`int`: The number of floating-point operations.
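
        Example (a small sketch; ``batch`` is a hypothetical dictionary of input tensors)::

            flos = trainer.floating_point_ops(batch)
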
        """

        model = self._actual_model(self.model)

        if hasattr(model, "floating_point_ops"):
            return model.floating_point_ops(inputs)
        else:
            return 0

    @staticmethod
    def _actual_model(
        model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
    ) -> torch.nn.modules.Module:
        """
        Unwrap :obj:`model` from its :obj:`DataParallel`/:obj:`DistributedDataParallel` container, if any.

        Args:
            model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
                Model object used during training

        Returns:
            :obj:`torch.nn.modules.Module`: unwrapped module
        """
        if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            model = model.module
        return model