# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or fine-tune it on a new task.
"""

import collections
import inspect
import math
import os
import re
import shutil
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union


# Integrations must be imported before ML frameworks:
from .integrations import (  # isort: split
    default_hp_search_backend,
    hp_params,
    is_azureml_available,
    is_comet_available,
    is_mlflow_available,
    is_optuna_available,
    is_ray_available,
    is_tensorboard_available,
    is_wandb_available,
    run_hp_search_optuna,
    run_hp_search_ray,
)

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler

from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
from .modeling_utils import PreTrainedModel
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .optimization import AdamW, get_linear_schedule_with_warmup
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
    DistributedTensorGatherer,
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    get_tpu_sampler,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalPrediction,
    HPSearchBackend,
    PredictionOutput,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    set_seed,
)
from .training_args import TrainingArguments
from .utils import logging


_use_native_amp = False
_use_apex = False

DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
    from .file_utils import is_apex_available

    if is_apex_available():
        from apex import amp
    _use_apex = True
else:
    _use_native_amp = True
    from torch.cuda.amp import autocast

if version.parse(torch.__version__) < version.parse("1.2"):
    _use_ddp_no_sync = False
else:
    _use_ddp_no_sync = True

if is_datasets_available():
    import datasets

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_tensorboard_available():
    from .integrations import TensorBoardCallback

    DEFAULT_CALLBACKS.append(TensorBoardCallback)


if is_wandb_available():
    from .integrations import WandbCallback

    DEFAULT_CALLBACKS.append(WandbCallback)

if is_comet_available():
    from .integrations import CometCallback

    DEFAULT_CALLBACKS.append(CometCallback)

if is_mlflow_available():
    from .integrations import MLflowCallback

    DEFAULT_CALLBACKS.append(MLflowCallback)

if is_optuna_available():
    import optuna

if is_ray_available():
    from ray import tune

if is_azureml_available():
    from .integrations import AzureMLCallback

    DEFAULT_CALLBACKS.append(AzureMLCallback)

logger = logging.get_logger(__name__)


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.

    Args:
        model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`):
            The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.

            .. note::

                :class:`~transformers.Trainer` is optimized to work with the :class:`~transformers.PreTrainedModel`
                provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as
                they work the same way as the 🤗 Transformers models.
        args (:class:`~transformers.TrainingArguments`, `optional`):
            The arguments to tweak for training. Will default to a basic instance of
            :class:`~transformers.TrainingArguments` with the ``output_dir`` set to a directory named `tmp_trainer` in
            the current directory if not provided.
        data_collator (:obj:`DataCollator`, `optional`):
            The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
            Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
            :func:`~transformers.DataCollatorWithPadding` otherwise.
        train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
        tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to
            the maximum length when batching inputs, and it will be saved along the model to make it easier to rerun
            an interrupted training or reuse the fine-tuned model.
        model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
            A function that instantiates the model to be used. If provided, each call to
            :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.

            The function may have zero arguments, or a single one containing the optuna/Ray Tune trial object, to be
            able to choose different architectures according to hyperparameters (such as layer count, sizes of inner
            layers, dropout probabilities, etc.).
        compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take a
            :class:`~transformers.EvalPrediction` and return a dictionary mapping metric names to metric values.
        callbacks (List of :obj:`~transformers.TrainerCallback`, `optional`):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed in :doc:`here <callback>`.

            If you want to remove one of the default callbacks used, use the :meth:`Trainer.remove_callback` method.
        optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, `optional`): A tuple
            containing the optimizer and the scheduler to use. Will default to an instance of
            :class:`~transformers.AdamW` on your model and a scheduler given by
            :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
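
    Example (a minimal usage sketch; the ``model``, ``train_dataset`` and ``eval_dataset`` objects are assumed to
    have been created beforehand, e.g. from a pretrained checkpoint and a preprocessed dataset)::

        from transformers import Trainer, TrainingArguments

        training_args = TrainingArguments(output_dir="test_trainer", num_train_epochs=3)
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
        trainer.train()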
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, torch.nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    ):
        if args is None:
            logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.")
            args = TrainingArguments("tmp_trainer")
        self.args = args
        # Seed must be set before instantiating the model when using model_init.
        set_seed(self.args.seed)
        assert (
            model is not None or model_init is not None
        ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
        self.model_init = model_init
        self.hp_name = None
        if model is None and model_init is not None:
            model = self.call_model_init()

        # Model parallel
        if model is not None and not self.args.model_parallel:
            model = model.to(args.device)

        self.model = model
        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer

        self.compute_metrics = compute_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument."
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
        self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create output directory if needed
        if self.is_world_process_zero():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel):
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant way to do this in the future.
            self.model.config.xla_device = True
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")

        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

        # Enforce rules on using datasets with no __len__
        if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
            raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")

        if is_datasets_available():
            if isinstance(train_dataset, datasets.Dataset):
                self._remove_unused_columns(self.train_dataset, description="training")
            if isinstance(eval_dataset, datasets.Dataset):
                self._remove_unused_columns(self.eval_dataset, description="evaluation")

        self.state = TrainerState()
        self.control = TrainerControl()
        # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
        # state at each call to self.log.
        self._total_flos = None
        if self.args.fp16 and _use_native_amp:
            self.scaler = torch.cuda.amp.GradScaler()
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
        default_label_names = (
            ["start_positions", "end_positions"]
            if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
            else ["labels"]
        )
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

    def add_callback(self, callback):
        """
        Add a callback to the current list of :class:`~transformer.TrainerCallback`.

        Args:
           callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
               A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
               In the first case, will instantiate a member of that class.
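
        Example (a quick sketch using one of the callbacks defined in this module)::

            trainer.add_callback(PrinterCallback)    # pass the class and let the Trainer instantiate it
            trainer.add_callback(PrinterCallback())  # or pass an already instantiated callback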
        """
        self.callback_handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Remove a callback from the current list of :class:`~transformer.TrainerCallback` and returns it.

        If the callback is not found, returns :obj:`None` (and no error is raised).

        Args:
           callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
               A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
               In the first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            :class:`~transformer.TrainerCallback`: The callback removed, if found.
        """
        return self.callback_handler.pop_callback(callback)

    def remove_callback(self, callback):
        """
        Remove a callback from the current list of :class:`~transformer.TrainerCallback`.

        Args:
           callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
               A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
               In the first case, will remove the first member of that class found in the list of callbacks.
        """
        self.callback_handler.remove_callback(callback)

    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return
        # Inspect model forward signature to keep only the arguments it accepts.
        signature = inspect.signature(self.model.forward)
        signature_columns = list(signature.parameters.keys())
        # Labels may be named label or label_ids, the default data collator handles that.
        signature_columns += ["label", "label_ids"]
        columns = [k for k in signature_columns if k in dataset.column_names]
        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        dset_description = "" if description is None else f"in the {description} set "
        logger.info(
            f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
        )
        dataset.set_format(type=dataset.format["type"], columns=columns)

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
            self.train_dataset, collections.abc.Sized
        ):
            return None
        elif is_torch_tpu_available():
            return get_tpu_sampler(self.train_dataset)
        else:
            return (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler (adapted
        to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
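
        Example of a subclass swapping in a different sampler (a sketch, not the library's own behavior; the
        ``weights`` tensor is assumed to be computed elsewhere)::

            class WeightedTrainer(Trainer):
                def get_train_dataloader(self):
                    sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights))
                    return DataLoader(
                        self.train_dataset,
                        batch_size=self.args.train_batch_size,
                        sampler=sampler,
                        collate_fn=self.data_collator,
                        drop_last=self.args.dataloader_drop_last,
                    )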
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = self._get_train_sampler()

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
        )

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        if is_torch_tpu_available():
            return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            return SequentialDistributedSampler(eval_dataset)
        else:
            return SequentialSampler(eval_dataset)

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation :class:`~torch.utils.data.DataLoader`.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
                accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        elif eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")
        elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            self._remove_unused_columns(eval_dataset, description="evaluation")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test :class:`~torch.utils.data.DataLoader`.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (:obj:`torch.utils.data.dataset.Dataset`):
                The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
        """
        if not isinstance(test_dataset, collections.abc.Sized):
            raise ValueError("test_dataset must implement __len__")
        elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            self._remove_unused_columns(test_dataset, description="test")
        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method.
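
        Example of passing a custom optimizer/scheduler pair at init instead of overriding this method (a sketch;
        any ``torch.optim`` optimizer and scheduler can be substituted)::

            from torch.optim import SGD
            from torch.optim.lr_scheduler import LambdaLR

            optimizer = SGD(model.parameters(), lr=1e-3)
            scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
            trainer = Trainer(model=model, args=args, train_dataset=train_dataset, optimizers=(optimizer, scheduler))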
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            self.optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
            )
        if self.lr_scheduler is None:
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
            )

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.

        Will raise an exception if the underlying dataset does not implement :obj:`__len__`.
        """
        return len(dataloader.dataset)

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """ HP search setup code """
        self._trial = trial

        if self.hp_search_backend is None or trial is None:
            return

        params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
        for key, value in params.items():
            if not hasattr(self.args, key):
                raise AttributeError(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
                )
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info(f"Trial: {trial.params}")

    def _report_to_hp_search(
        self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
    ):
        if self.hp_search_backend is None or trial is None:
            return
        self.objective = self.compute_objective(metrics.copy())
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            trial.report(self.objective, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            if self.state.global_step % self.args.save_steps == 0:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

    def _tune_save_checkpoint(self):
        if not self.use_tune_checkpoints:
            return
        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
            self.args.output_dir = checkpoint_dir
            output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
            self.save_model(output_dir)
            if self.is_world_master():
                self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

    def call_model_init(self, trial=None):
        model_init_argcount = len(inspect.signature(self.model_init).parameters)
        if model_init_argcount == 0:
            model = self.model_init()
        elif model_init_argcount == 1:
            model = self.model_init(trial)
        else:
            raise RuntimeError("model_init should have 0 or 1 argument.")

        if model is None:
            raise RuntimeError("model_init should not return None.")

        return model

    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
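
        Example (a sketch; the checkpoint path below is hypothetical and must point to a folder written by a
        previous run)::

            trainer.train()                                         # train from scratch or from `model`
            trainer.train(model_path="output_dir/checkpoint-500")   # also reload optimizer/scheduler states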
        """
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            set_seed(self.args.seed)

            model = self.call_model_init(trial)

            if not self.args.model_parallel:
                self.model = model.to(self.args.device)

            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Keeping track of whether we can call len() on the dataset or not
        train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        if train_dataset_is_sized:
            num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            if self.args.max_steps > 0:
                max_steps = self.args.max_steps
                num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                    self.args.max_steps % num_update_steps_per_epoch > 0
                )
            else:
                max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(self.args.num_train_epochs)
        else:
            # see __init__. max_steps is set when the dataset has no __len__
            max_steps = self.args.max_steps
            num_train_epochs = 1
            num_update_steps_per_epoch = max_steps

        self.create_optimizer_and_scheduler(num_training_steps=max_steps)
        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(model_path)

        # Mixed precision training with apex (torch < 1.6)
        model = self.model
        if self.args.fp16 and _use_apex:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1 and not self.args.model_parallel:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=(
                    not getattr(model.config, "gradient_checkpointing", False)
                    if isinstance(model, PreTrainedModel)
                    else True
                ),
            )
        # find_unused_parameters breaks checkpointing as per
        # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )

        num_examples = (
            self.num_examples(train_dataloader)
            if train_dataset_is_sized
            else total_train_batch_size * self.args.max_steps
        )

        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
        logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps}")

        self.state.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        # Check if continuing training from a checkpoint
        if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
            self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            if not self.args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not self.args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
                    "batches in the first epoch."
                )

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
        self.state.trial_params = hp_params(trial) if trial is not None else None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(self.args.device)
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = 0
        self._total_flos = self.state.total_flos
        model.zero_grad()

        self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)

        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not self.args.ignore_data_skip:
            for epoch in range(epochs_trained):
                # We just need to begin an iteration to create the randomization of the sampler.
                for _ in train_dataloader:
                    break

        for epoch in range(epochs_trained, num_train_epochs):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            steps_in_epoch = len(epoch_iterator) if train_dataset_is_sized else self.args.max_steps
            self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                if (step + 1) % self.args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)

                if (
                    ((step + 1) % self.args.gradient_accumulation_steps != 0)
                    and self.args.local_rank != -1
                    and _use_ddp_no_sync
                ):
                    with model.no_sync():
                        tr_loss += self.training_step(model, inputs)
                else:
                    tr_loss += self.training_step(model, inputs)
                self._total_flos += self.floating_point_ops(inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    steps_in_epoch <= self.args.gradient_accumulation_steps
                    and (step + 1) == steps_in_epoch
                ):
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)
                    elif self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    self.lr_scheduler.step()
                    model.zero_grad()
                    self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
                    self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)

                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break

            self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

            if self.args.tpu_metrics_debug or self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            logger.info(
                f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
            )
            if isinstance(self.model, PreTrainedModel):
                self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
                if not self.args.model_parallel:
                    self.model = self.model.to(self.args.device)
            else:
                state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
                self.model.load_state_dict(state_dict)

        if self._total_flos is not None:
            self.store_flos()
            self.log({"total_flos": self.state.total_flos})

        self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()

        return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
        if self.control.should_log:
            logs: Dict[str, float] = {}
            tr_loss_scalar = tr_loss.item()
            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged)
            # backward compatibility for pytorch schedulers
            logs["learning_rate"] = (
                self.lr_scheduler.get_last_lr()[0]
                if version.parse(torch.__version__) >= version.parse("1.4")
                else self.lr_scheduler.get_lr()[0]
            )
            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            metrics = self.evaluate()
            self._report_to_hp_search(trial, epoch, metrics)

        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def _save_checkpoint(self, model, trial, metrics=None):
        # In all cases (even distributed/parallel), self.model is always a reference
        # to the model we want to save.
        if hasattr(model, "module"):
            assert model.module is self.model, f"Module {model.module} should be a reference to self.model"
        else:
            assert model is self.model, f"Model {model} should be a reference to self.model"
        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is not None and trial is not None:
            run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
        else:
            output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

            self.store_flos()
        self.save_model(output_dir)

        # Save optimizer and scheduler
        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                reissue_pt_warnings(caught_warnings)
        elif self.is_world_process_zero():
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
            reissue_pt_warnings(caught_warnings)

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.is_world_process_zero():
            self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))

        # Maybe delete some older checkpoints.
        if self.is_world_process_zero():
            self._rotate_checkpoints(use_mtime=True)

    def _load_optimizer_and_scheduler(self, model_path):
        """If optimizer and scheduler states exist, load them."""
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            if is_torch_tpu_available():
                # On TPU we have to take some extra precautions to properly load the states on the right device.
                optimizer_state = torch.load(os.path.join(model_path, "optimizer.pt"), map_location="cpu")
                with warnings.catch_warnings(record=True) as caught_warnings:
                    lr_scheduler_state = torch.load(os.path.join(model_path, "scheduler.pt"), map_location="cpu")
                reissue_pt_warnings(caught_warnings)

                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

                self.optimizer.load_state_dict(optimizer_state)
                self.lr_scheduler.load_state_dict(lr_scheduler_state)
            else:
                self.optimizer.load_state_dict(
                    torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
                )
                with warnings.catch_warnings(record=True) as caught_warnings:
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
                reissue_pt_warnings(caught_warnings)

    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
        **kwargs
    ) -> BestRun:
        """
        Launch a hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
        :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no metric is
        provided, and the sum of all metrics otherwise.

        .. warning::

            To use this method, you need to have provided a ``model_init`` when initializing your
            :class:`~transformers.Trainer`: we need to reinitialize the model at each new run. This is incompatible
            with the ``optimizers`` argument, so you need to subclass :class:`~transformers.Trainer` and override the
            method :meth:`~transformers.Trainer.create_optimizer_and_scheduler` for custom optimizer/scheduler.

        Args:
            hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
                A function that defines the hyperparameter search space. Will default to
                :func:`~transformers.trainer_utils.default_hp_space_optuna` or
                :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
            compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
                A function computing the objective to minimize or maximize from the metrics returned by the
                :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
            n_trials (:obj:`int`, `optional`, defaults to 20):
                The number of trial runs to test.
            direction (:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
                Whether to optimize greater or lower objective values. Can be :obj:`"minimize"` or :obj:`"maximize"`;
                pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
                several metrics.
            backend (:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
                one is installed. If both are installed, will default to optuna.
            hp_name (:obj:`Callable[["optuna.Trial"], str]`, `optional`):
                A function that defines the trial/run name. Will default to :obj:`None`.
            kwargs:
                Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
                more information see:

                - the documentation of `optuna.create_study
                  <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
                - the documentation of `tune.run
                  <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__

        Returns:
            :class:`transformers.trainer_utils.BestRun`: All the information about the best run.
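
        Example (a minimal, hedged sketch; ``model_init``, ``train_dataset`` and ``eval_dataset`` are assumed to be
        defined by the caller, and ``optuna`` to be installed)::

            from transformers import Trainer, TrainingArguments

            def my_hp_space(trial):
                # Hypothetical search space over standard TrainingArguments fields.
                return {
                    "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True),
                    "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 3),
                }

            trainer = Trainer(
                args=TrainingArguments(output_dir="hp_search_output"),
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                model_init=model_init,
            )
            best_run = trainer.hyperparameter_search(hp_space=my_hp_space, n_trials=10, direction="minimize")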
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray should be installed. "
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        self.hp_search_backend = backend
        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.hp_name = hp_name
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
        best_run = run_hp_search(self, n_trials, direction, **kwargs)

        self.hp_search_backend = None
        return best_run

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log :obj:`logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (:obj:`Dict[str, float]`):
                The values to log.
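
        Example (a hedged sketch of a subclass adding a derived value before delegating to the default behavior;
        ``MyTrainer`` and the extra key are hypothetical)::

            class MyTrainer(Trainer):
                def log(self, logs):
                    # Attach a custom value, then let the base implementation dispatch to the callbacks.
                    logs["learning_rate_x1000"] = logs.get("learning_rate", 0.0) * 1000
                    super().log(logs)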
        """
        if self.state.epoch is not None:
            logs["epoch"] = self.state.epoch
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.
        """

        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.args.fp16 and _use_native_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16 and _use_native_amp:
            self.scaler.scale(loss).backward()
        elif self.args.fp16 and _use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.detach()

    def compute_loss(self, model, inputs):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
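
        Example (a minimal sketch of a weighted cross-entropy override for sequence classification; ``MyTrainer`` and
        the class weights are hypothetical)::

            import torch
            from torch import nn

            class MyTrainer(Trainer):
                def compute_loss(self, model, inputs):
                    labels = inputs.pop("labels")
                    outputs = model(**inputs)
                    logits = outputs["logits"] if isinstance(outputs, dict) else outputs[0]
                    # Hypothetical per-class weights for a binary classification problem.
                    weight = torch.tensor([1.0, 2.0], device=logits.device)
                    loss_fct = nn.CrossEntropyLoss(weight=weight)
                    return loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))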
        """
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        return outputs["loss"] if isinstance(outputs, dict) else outputs[0]

    def is_local_process_zero(self) -> bool:
        """
Sylvain Gugger's avatar
Sylvain Gugger committed
1130
1131
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
        machines) main process.
1132
        """
1133
        if is_torch_tpu_available():
Lysandre Debut's avatar
Lysandre Debut committed
1134
1135
1136
1137
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be :obj:`True` for one process).
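
        Example (a hedged sketch; assumes a built ``trainer``, that ``metrics`` comes from a previous :obj:`evaluate`
        call, and that the file name is hypothetical)::

            if trainer.is_world_process_zero():
                # Only the main process writes the shared artifact.
                with open("eval_results.txt", "w") as writer:
                    writer.write(str(metrics))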
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        Will only save from the main (world master) process, unless running on TPU.
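
        Example (a minimal sketch; assumes a built ``trainer`` and a hypothetical ``"my-finetuned-model"`` output
        directory)::

            trainer.save_model("my-finetuned-model")
            # The saved checkpoint can later be reloaded with the matching model class, for instance:
            # model = AutoModelForSequenceClassification.from_pretrained("my-finetuned-model")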
        """

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_process_zero():
            self._save(output_dir)

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if not isinstance(self.model, PreTrainedModel):
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            state_dict = self.model.state_dict()
            xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir)
        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
            state_dict = self.model.state_dict()
            torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir)
        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def store_flos(self):
        # Storing the number of floating-point operations that went into the model
        if self._total_flos is not None:
            if self.args.local_rank != -1:
                self.state.total_flos = distributed_broadcast_scalars([self._total_flos]).sum().item()
            else:
                self.state.total_flos = self._total_flos

    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        # Make sure we don't delete the best model.
        if self.state.best_model_checkpoint is not None:
            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
            checkpoints_sorted[best_model_index], checkpoints_sorted[-1] = (
                checkpoints_sorted[-1],
                checkpoints_sorted[best_model_index],
            )
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(
        self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None
    ) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is a :obj:`datasets.Dataset`,
                columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
                :obj:`__len__` method.
            ignore_keys (:obj:`List[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
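
        Example (a minimal sketch; assumes the trainer was built with an ``eval_dataset`` and, optionally, a
        ``compute_metrics`` function)::

            metrics = trainer.evaluate()
            print(metrics["eval_loss"])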
        """
        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")

        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self.prediction_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
        )

        self.log(output.metrics)

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
        return output.metrics

    def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset to run the predictions on. If it is a :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`.
            ignore_keys (:obj:`List[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        .. note::

            If your predictions or labels have different sequence length (for instance because you're doing dynamic
            padding in a token classification task) the predictions will be padded (on the right) to allow for
            concatenation into one array. The padding index is -100.

        Returns: `NamedTuple` A namedtuple with the following keys:

            - predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
            - label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some).
            - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
              contained labels).
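
        Example (a minimal sketch for a classification model; ``test_dataset`` is assumed to be tokenized the same way
        as the training data)::

            import numpy as np

            output = trainer.predict(test_dataset)
            predicted_class_ids = np.argmax(output.predictions, axis=-1)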
        """
        if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized):
            raise ValueError("test_dataset must implement __len__")

        test_dataloader = self.get_test_dataloader(test_dataset)

        return self.prediction_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys)

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with and without labels.
        """
        if not isinstance(dataloader.dataset, collections.abc.Sized):
            raise ValueError("dataset must implement __len__")
        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1 and not self.args.model_parallel:
            model = torch.nn.DataParallel(model)
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", num_examples)
        logger.info("  Batch size = %d", batch_size)
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = 1
        if is_torch_tpu_available():
            world_size = xm.xrt_world_size()
        elif self.args.local_rank != -1:
            world_size = torch.distributed.get_world_size()
        world_size = max(1, world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples)

        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            if loss is not None:
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                if not prediction_loss_only:
                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host = None, None, None

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        if eval_loss is not None:
            metrics["eval_loss"] = eval_loss.mean().item()

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def _gather_and_numpify(self, tensors, name):
        """
        Gather the value of `tensors` (a tensor or a list/tuple of nested tensors) across processes if needed and
        convert the result to numpy arrays.
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
        elif self.args.local_rank != -1:
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.
            ignore_keys (:obj:`List[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """
        has_labels = all(inputs.get(k) is not None for k in self.label_names)
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        with torch.no_grad():
            if self.args.fp16 and _use_native_amp:
                with autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)
            if has_labels:
                if isinstance(outputs, dict):
                    loss = outputs["loss"].mean().detach()
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                else:
                    loss = outputs[0].mean().detach()
                    logits = outputs[1:]
            else:
                loss = None
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                else:
                    logits = outputs
            # TODO: this needs to be fixed and made cleaner later.
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        if has_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        return (loss, logits, labels)

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        For models that inherit from :class:`~transformers.PreTrainedModel`, uses the model's own
        :obj:`floating_point_ops` method to compute the number of floating point operations for every backward +
        forward pass. If using another model, either implement such a method in the model or subclass and override
        this method.

        Args:
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            :obj:`int`: The number of floating-point operations.
        """

        model = self._actual_model(self.model)

        if hasattr(model, "floating_point_ops"):
            return model.floating_point_ops(inputs)
        else:
            return 0

    @staticmethod
    def _actual_model(
        model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
    ) -> torch.nn.modules.Module:
        """

        Args:
            model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
                Model object used during training

        Returns:
            :obj:`torch.nn.modules.Module`: unwrapped module
        """
        if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            model = model.module
        return model