import inspect
import json
import math
import os
import re
import shutil
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange

from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import is_datasets_available, is_torch_tpu_available
from .integrations import (
    default_hp_search_backend,
    is_comet_available,
    is_optuna_available,
    is_ray_available,
    is_tensorboard_available,
    is_wandb_available,
    run_hp_search_optuna,
    run_hp_search_ray,
)
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalPrediction,
    HPSearchBackend,
    PredictionOutput,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    distributed_broadcast_scalars,
    distributed_concat,
    set_seed,
)
from .training_args import TrainingArguments
from .utils import logging


_use_native_amp = False
_use_apex = False

# Check if PyTorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
    from transformers.file_utils import is_apex_available

    if is_apex_available():
        from apex import amp
    _use_apex = True
else:
    _use_native_amp = True
    from torch.cuda.amp import autocast

if is_datasets_available():
    import datasets

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_tensorboard_available():
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        from tensorboardX import SummaryWriter

if is_wandb_available():
    import wandb

if is_comet_available():
    import comet_ml

if is_optuna_available():
    import optuna

if is_ray_available():
    from ray import tune

logger = logging.get_logger(__name__)


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager to make all processes in distributed training wait for the local master to do something.

    Args:
        local_rank (:obj:`int`): The rank of the local process.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()
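
# NOTE (illustrative sketch, not part of the original module): typical usage of
# `torch_distributed_zero_first`, assuming a hypothetical `load_and_cache_examples` helper.
# The local master does the expensive caching while the other ranks wait at the barrier,
# then every rank reads the cached result:
#
#   with torch_distributed_zero_first(training_args.local_rank):
#       train_dataset = load_and_cache_examples(training_args)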


class SequentialDistributedSampler(Sampler):
    """
    Distributed Sampler that subsamples indices sequentially,
    making it easier to collate all results at the end.

    Even though we only use this sampler for eval and predict (no training),
    which means that the model params won't have to be synced (i.e. they will not hang
    waiting for synchronization even if the number of forward passes varies across
    processes), we still add extra samples to the sampler to make it evenly divisible
    (like in `DistributedSampler`) to make it easy to `gather` or `reduce` resulting
    tensors at the end of the loop.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert (
            len(indices) == self.total_size
        ), f"Indices length {len(indices)} and total size {self.total_size} mismatched"

        # subsample
        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
        assert (
            len(indices) == self.num_samples
        ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"

        return iter(indices)

    def __len__(self):
        return self.num_samples
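
    # NOTE (illustrative sketch, not part of the original class): because the indices are padded
    # so that every rank gets exactly `num_samples` items, per-rank prediction tensors can be
    # gathered with `distributed_concat` (imported above) and truncated back to the true dataset
    # length. `eval_dataset` and `per_rank_preds` are hypothetical names:
    #
    #   sampler = SequentialDistributedSampler(eval_dataset)
    #   ...
    #   all_preds = distributed_concat(per_rank_preds, num_total_examples=len(eval_dataset))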


def get_tpu_sampler(dataset: Dataset):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for 🤗 Transformers.

    Args:
        model (:class:`~transformers.PreTrainedModel`, `optional`):
            The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
        args (:class:`~transformers.TrainingArguments`, `optional`):
            The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
            with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
        data_collator (:obj:`DataCollator`, `optional`):
            The function to use to form a batch from a list of elements of :obj:`train_dataset` or
            :obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is
            provided, an instance of :func:`~transformers.DataCollatorWithPadding` otherwise.
        train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
        tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
            interrupted training or reuse the fine-tuned model.
        model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
            A function that instantiates the model to be used. If provided, each call to
            :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
        compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take a
            :class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
        tb_writer (:obj:`SummaryWriter`, `optional`):
            Object to write to TensorBoard.
        optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of
            :class:`~transformers.AdamW` on your model and a scheduler given by
            :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
        kwargs:
            Deprecated keyword arguments.
    """

    def __init__(
        self,
        model: PreTrainedModel = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        **kwargs,
    ):
        if args is None:
            logger.info("No `TrainingArguments` passed, using `output_dir=tmp_trainer`.")
            args = TrainingArguments("tmp_trainer")
        self.args = args
        # Seed must be set before instantiating the model when using model_init.
        set_seed(self.args.seed)
        assert (
            model is not None or model_init is not None
        ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
        if model is None and model_init is not None:
            model = model_init()
        self.model = model.to(args.device) if model is not None else None
        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer
        self.model_init = model_init
        self.compute_metrics = compute_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument."
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        self.tb_writer = tb_writer
        self.log_history = []
        if "prediction_loss_only" in kwargs:
            warnings.warn(
                "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
                FutureWarning,
            )
            self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create output directory if needed
        if self.is_world_process_zero():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available():
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant way and not need to do this in the future.
            self.model.config.xla_device = True
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            self.data_collator = self.data_collator.collate_batch
            warnings.warn(
                (
                    "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
                    + "with a `collate_batch` are deprecated and won't be supported in a future version."
                ),
                FutureWarning,
            )

        if is_datasets_available():
            if isinstance(train_dataset, datasets.Dataset):
                self._remove_unused_columns(self.train_dataset, description="training")
            if isinstance(eval_dataset, datasets.Dataset):
                self._remove_unused_columns(self.eval_dataset, description="evaluation")

        self.global_step = None
        self.epoch = None
        self.total_flos = None
        if self.args.fp16 and _use_native_amp:
            self.scaler = torch.cuda.amp.GradScaler()
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
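
    # NOTE (illustrative sketch, not part of the original class): the most common way to drive
    # `Trainer`, assuming hypothetical `model`, `training_args`, `train_ds` and `eval_ds` objects
    # built elsewhere:
    #
    #   trainer = Trainer(model=model, args=training_args, train_dataset=train_ds, eval_dataset=eval_ds)
    #   trainer.train()
    #   metrics = trainer.evaluate()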

    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return
        # Inspect model forward signature to keep only the arguments it accepts.
        signature = inspect.signature(self.model.forward)
        signature_columns = list(signature.parameters.keys())
        # Labels may be named label or label_ids, the default data collator handles that.
        signature_columns += ["label", "label_ids"]
        columns = [k for k in signature_columns if k in dataset.column_names]
        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        dset_description = "" if description is None else f"in the {description} set "
        logger.info(
            f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
        )
        dataset.set_format(type=dataset.format["type"], columns=columns)

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return get_tpu_sampler(self.train_dataset)
        else:
            return (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
        (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = self._get_train_sampler()

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )
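
    # NOTE (illustrative sketch, not part of the original class): "subclass and override", as the
    # docstring above suggests, e.g. to plug in a weighted sampler; `example_weights` is a
    # hypothetical tensor of per-example weights:
    #
    #   class WeightedTrainer(Trainer):
    #       def _get_train_sampler(self):
    #           return torch.utils.data.WeightedRandomSampler(example_weights, len(example_weights))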

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            return SequentialDistributedSampler(eval_dataset)
        else:
            return SequentialSampler(eval_dataset)

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
        sampler (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
                accepted by the ``model.forward()`` method are automatically removed.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        elif eval_dataset is not None and is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            self._remove_unused_columns(eval_dataset, description="evaluation")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
        sampler (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (:obj:`torch.utils.data.dataset.Dataset`):
                The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed.
        """
        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            self._remove_unused_columns(test_dataset, description="test")
        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            self.optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
            )
        if self.lr_scheduler is None:
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
            )
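
    # NOTE (illustrative sketch, not part of the original class): instead of overriding this method,
    # a custom optimizer/scheduler pair can be passed at construction time; `model`, `args` and
    # `num_steps` are hypothetical objects defined elsewhere:
    #
    #   optimizer = AdamW(model.parameters(), lr=args.learning_rate)
    #   scheduler = get_linear_schedule_with_warmup(optimizer, args.warmup_steps, num_steps)
    #   trainer = Trainer(model=model, args=args, optimizers=(optimizer, scheduler))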

    def setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
                or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
        """
        if hasattr(self, "_setup_wandb"):
            warnings.warn(
                "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
                FutureWarning,
            )
            return self._setup_wandb()

        if self.is_world_process_zero():
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            try:
                combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
            except AttributeError:
                # in case the model has no config
                combined_dict = {**self.args.to_sanitized_dict()}
            wandb.init(
                project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
            )
            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                wandb.watch(
                    self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
                )

    def setup_comet(self):
        """
        Setup the optional Comet.ml integration.

        Environment:
            COMET_MODE:
                (Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
            COMET_PROJECT_NAME:
                (Optional): str - Comet.ml project name for experiments
            COMET_OFFLINE_DIRECTORY:
                (Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"

        For a number of configurable items in the environment,
        see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
        """
        if self.is_world_master():
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            experiment = None
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**args)
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**args)
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            if experiment is not None:
                experiment._set_model_graph(self.model, framework="transformers")
                experiment._log_parameters(self.args, prefix="args/", framework="transformers")
                experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
        """
        return len(dataloader.dataset)

    def _setup_loggers(self):
        if self._loggers_initialized:
            return
        if is_wandb_available():
            self.setup_wandb()
        elif os.environ.get("WANDB_DISABLED") != "true":
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
        if is_comet_available():
            self.setup_comet()
        elif os.environ.get("COMET_MODE") != "DISABLED":
            logger.info(
                "To use comet_ml logging, run `pip/conda install comet_ml` "
                "see https://www.comet.ml/docs/python-sdk/huggingface/"
            )
        self._loggers_initialized = True

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """ HP search setup code """
        if self.hp_search_backend is None or trial is None:
            return
        params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
        for key, value in params.items():
            if not hasattr(self.args, key):
                raise AttributeError(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
                )
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info("Trial: %s", trial.params)

    def _report_to_hp_search(
        self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
    ):
        if self.hp_search_backend is None or trial is None:
            return
        self.objective = self.compute_objective(metrics)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            trial.report(self.objective, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            if self.global_step % self.args.save_steps == 0:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

    def _tune_save_checkpoint(self):
        if not self.use_tune_checkpoints:
            return
        with tune.checkpoint_dir(step=self.global_step) as checkpoint_dir:
            self.args.output_dir = checkpoint_dir
            output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
            self.save_model(output_dir)
            if self.is_world_master():
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            set_seed(self.args.seed)
            model = self.model_init()
            self.model = model.to(self.args.device)

            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (
                self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
            )
        else:
            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs
            self.args.max_steps = t_total

        self.create_optimizer_and_scheduler(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            self.optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16 and _use_apex:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        self.total_flos = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
                self.total_flos = getattr(model.config, "total_flos", 0)

                epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Continuing training from %d non-embedding floating-point operations", self.total_flos)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                self.total_flos = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = torch.tensor(0.0).to(self.args.device)
        logging_loss_scalar = 0.0
        model.zero_grad()
        disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
        train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
        for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    epoch_pbar.update(1)
                    continue

                tr_loss += self.training_step(model, inputs)
                self.total_flos += self.floating_point_ops(inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)
                    elif self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    self.lr_scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs: Dict[str, float] = {}
                        tr_loss_scalar = tr_loss.item()
                        logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            self.lr_scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else self.lr_scheduler.get_lr()[0]
                        )
                        logging_loss_scalar = tr_loss_scalar

                        self.log(logs)

                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                        metrics = self.evaluate()
                        self._report_to_hp_search(trial, epoch, metrics)

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert (
                                model.module is self.model
                            ), f"Module {model.module} should be a reference to self.model"
                        else:
                            assert model is self.model, f"Model {model} should be a reference to self.model"
                        # Save model checkpoint
                        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
                        if self.hp_search_backend is not None and trial is not None:
                            run_id = (
                                trial.number
                                if self.hp_search_backend == HPSearchBackend.OPTUNA
                                else tune.get_trial_id()
                            )
                            checkpoint_folder += f"-run-{run_id}"
                        output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

                        self.save_model(output_dir)

                        if self.is_world_process_zero():
                            self._rotate_checkpoints(use_mtime=True)

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_process_zero():
                            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

                epoch_pbar.update(1)
                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                    break
            epoch_pbar.close()
            train_pbar.update(1)
            if self.args.tpu_metrics_debug or self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                break

        train_pbar.close()
        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return TrainOutput(self.global_step, tr_loss.item() / self.global_step)

    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        **kwargs
    ) -> BestRun:
        """
        Launch a hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
        :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
        the sum of all metrics otherwise.

        .. warning::

            To use this method, you need to have provided a ``model_init`` when initializing your
            :class:`~transformers.Trainer`: we need to reinitialize the model at each new run. This is incompatible
            with the ``optimizers`` argument, so you need to subclass :class:`~transformers.Trainer` and override the
            method :meth:`~transformers.Trainer.create_optimizer_and_scheduler` for custom optimizer/scheduler.

        Args:
            hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
                A function that defines the hyperparameter search space. Will default to
                :func:`~transformers.trainer_utils.default_hp_space_optuna` or
                :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
            compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
                A function computing the objective to minimize or maximize from the metrics returned by the
                :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
            n_trials (:obj:`int`, `optional`, defaults to 20):
                The number of trial runs to test.
            direction (:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
                Whether to maximize or minimize the objective. Can be :obj:`"minimize"` or :obj:`"maximize"`, you should
                pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
                several metrics.
            backend (:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
                one is installed. If both are installed, will default to optuna.
            kwargs:
                Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
                more information see:

                - the documentation of `optuna.create_study <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
                - the documentation of `tune.run <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__

        Returns:
            :class:`transformers.trainer_utils.BestRun`: All the information about the best run.
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray should be installed. "
                    "To install optuna run `pip install optuna`."
                    "To install ray run `pip install ray[tune]`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        self.hp_search_backend = backend

        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
        best_run = run_hp_search(self, n_trials, direction, **kwargs)

        self.hp_search_backend = None
        return best_run
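
    # NOTE (illustrative sketch, not part of the original class): a minimal search call, assuming a
    # `model_init` function was passed to the Trainer and optuna is installed; `hp_space` below is a
    # hypothetical user-defined search space:
    #
    #   def hp_space(trial):
    #       return {"learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True)}
    #
    #   best_run = trainer.hyperparameter_search(hp_space=hp_space, n_trials=10, direction="minimize")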

    def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        """
        Log :obj:`logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (:obj:`Dict[str, float]`):
                The values to log.
            iterator (:obj:`tqdm`, `optional`):
                A potential tqdm progress bar to write the logs on.
        """
        # Set up loggers like W&B or Comet ML
        self._setup_loggers()

        if hasattr(self, "_log"):
            warnings.warn(
                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
                FutureWarning,
            )
            return self._log(logs, iterator=iterator)

        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.total_flos is not None:
            if self.args.local_rank != -1:
                total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
            else:
                total_flos = self.total_flos
            if total_flos > 0:
                logs["total_flos"] = self.total_flos
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_process_zero():
                wandb.log(logs, step=self.global_step)
        if is_comet_available():
            if self.is_world_process_zero():
                experiment = comet_ml.config.get_global_experiment()
                if experiment is not None:
                    experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
        output = {**logs, **{"step": self.global_step}}
        if self.is_world_process_zero():
            self.log_history.append(output)
        if iterator is not None:
            iterator.write(str(output))
        else:
            print(output)

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.
        """
        if hasattr(self, "_training_step"):
            warnings.warn(
                "The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
                FutureWarning,
            )
            return self._training_step(model, inputs, self.optimizer)

        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.args.fp16 and _use_native_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16 and _use_native_amp:
            self.scaler.scale(loss).backward()
        elif self.args.fp16 and _use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.detach()
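
    # Illustrative sketch (hypothetical subclass): `training_step` can be wrapped via `super()` to
    # add bookkeeping (here, an exponential moving average of the loss) without reimplementing the
    # AMP/apex branches above. `_ema_loss` is an attribute invented for the example.
    #
    #     class MyTrainer(Trainer):
    #         def training_step(self, model, inputs):
    #             loss = super().training_step(model, inputs)
    #             self._ema_loss = 0.9 * getattr(self, "_ema_loss", loss.item()) + 0.1 * loss.item()
    #             return loss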

    def compute_loss(self, model, inputs):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        outputs = model(**inputs)
        # Save past state if it exists
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        return outputs[0]
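
    # Illustrative sketch (hypothetical subclass): a custom objective can be plugged in by overriding
    # `compute_loss`, e.g. a class-weighted cross-entropy for a sequence classification model. The
    # `class_weights` tensor and `model.config.num_labels` are assumptions about the setup.
    #
    #     class MyTrainer(Trainer):
    #         def compute_loss(self, model, inputs):
    #             labels = inputs.pop("labels")
    #             logits = model(**inputs)[0]
    #             loss_fct = nn.CrossEntropyLoss(weight=self.class_weights.to(logits.device))
    #             return loss_fct(logits.view(-1, model.config.num_labels), labels.view(-1))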

    def is_local_master(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
        several machines) main process.

        .. warning::

            This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
        """
        warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning)
        return self.is_local_process_zero()

    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
        several machines) main process.
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

    def is_world_master(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).

        .. warning::

            This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
        """
        warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning)
        return self.is_world_process_zero()

    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
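
    # Illustrative usage sketch: guard side effects (printing, writing extra files) so they run only
    # once in distributed training. `trainer` is a hypothetical Trainer instance.
    #
    #     if trainer.is_world_process_zero():
    #         print("only the global main process reaches this branch")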

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        Will only save from the main process (unless running on TPUs).
        """

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_process_zero():
            self._save(output_dir)
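
    # Illustrative usage sketch: a checkpoint written by `save_model` can be reloaded with
    # `from_pretrained`. The path and `AutoModelForSequenceClassification` are assumptions about
    # the task, not requirements of this method.
    #
    #     trainer.save_model("./my_finetuned_model")
    #     model = AutoModelForSequenceClassification.from_pretrained("./my_finetuned_model")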

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
            with open(os.path.join(output_dir, "log_history.json"), "w") as f:
                json.dump(self.log_history, f, indent=2, ensure_ascii=False)

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")

        xm.rendezvous("saving_checkpoint")
        self._store_flos()
        self.model.save_pretrained(output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")
        self._store_flos()
        self.model.save_pretrained(output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        with open(os.path.join(output_dir, "log_history.json"), "w") as f:
            json.dump(self.log_history, f, indent=2, ensure_ascii=False)

    def _store_flos(self):
        # Storing the number of floating-point operations that went into the model
        if self.total_flos is not None:
            if self.args.local_rank != -1:
                total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
            else:
                total_flos = self.total_flos
            if total_flos > 0:
                self.model.config.total_flos = total_flos

    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is a :obj:`datasets.Dataset`,
                columns not accepted by the ``model.forward()`` method are automatically removed.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self.prediction_loop(eval_dataloader, description="Evaluation")

        self.log(output.metrics)

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics
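
    # Illustrative usage sketch: `evaluate` returns a plain dict whose keys are prefixed with
    # `eval_` by `prediction_loop`. `eval_accuracy` assumes a `compute_metrics` function that
    # reports an `accuracy` entry.
    #
    #     metrics = trainer.evaluate()
    #     print(metrics["eval_loss"], metrics.get("eval_accuracy"))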

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset to run the predictions on. If it is a :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed.

        Returns:
            `NamedTuple`:
            predictions (:obj:`np.ndarray`):
                The predictions on :obj:`test_dataset`.
            label_ids (:obj:`np.ndarray`, `optional`):
                The labels (if the dataset contained some).
            metrics (:obj:`Dict[str, float]`, `optional`):
                The potential dictionary of metrics (if the dataset contained labels).
        """
        test_dataloader = self.get_test_dataloader(test_dataset)

        return self.prediction_loop(test_dataloader, description="Prediction")
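
    # Illustrative usage sketch: `predict` returns a `PredictionOutput` named tuple. For a
    # classification model (an assumption for the example), class predictions can be recovered
    # with an argmax over the returned logits.
    #
    #     output = trainer.predict(test_dataset)
    #     predicted_classes = np.argmax(output.predictions, axis=-1)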

    def prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if hasattr(self, "_prediction_loop"):
            warnings.warn(
                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
                FutureWarning,
            )
            return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)

        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        assert not getattr(
            self.model.config, "output_attentions", False
        ), "The prediction loop does not work with `output_attentions=True`."
        assert not getattr(
            self.model.config, "output_hidden_states", False
        ), "The prediction loop does not work with `output_hidden_states=True`."

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: torch.Tensor = None
        label_ids: torch.Tensor = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
        for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
            batch_size = inputs[list(inputs.keys())[0]].shape[0]
            if loss is not None:
                eval_losses.extend([loss] * batch_size)
            if logits is not None:
                preds = logits if preds is None else tuple(torch.cat((p, l), dim=0) for p, l in zip(preds, logits))
            if labels is not None:
                label_ids = labels if label_ids is None else torch.cat((label_ids, labels), dim=0)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = tuple(distributed_concat(p, num_total_examples=self.num_examples(dataloader)) for p in preds)
            if label_ids is not None:
                label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = tuple(xm.mesh_reduce(f"eval_preds_{i}", p, torch.cat) for i, p in enumerate(preds))
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
            if eval_losses is not None:
                eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = tuple(p.cpu().numpy() for p in preds)
            if len(preds) == 1:
                preds = preds[0]
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            if self.args.local_rank != -1:
                metrics["eval_loss"] = (
                    distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
                    .mean()
                    .item()
                )
            else:
                metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            A tuple with the loss, logits and labels (each being optional).
        """
        has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            outputs = model(**inputs)
            if has_labels:
                # The .mean() is to reduce in case of distributed training
                loss = outputs[0].mean().item()
                logits = outputs[1:]
            else:
                loss = None
                # Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
                logits = outputs[:]
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        labels = inputs.get("labels")
        if labels is not None:
            labels = labels.detach()
        return (loss, tuple(l.detach() for l in logits), labels)
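
    # Illustrative sketch (hypothetical subclass): `prediction_step` can be wrapped to post-process
    # the returned logits, e.g. keeping only the first tensor when a model emits auxiliary outputs.
    #
    #     class MyTrainer(Trainer):
    #         def prediction_step(self, model, inputs, prediction_loss_only):
    #             loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only)
    #             if logits is not None:
    #                 logits = logits[:1]  # `logits` is a tuple; keep only its first element
    #             return loss, logits, labels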

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        For models that inherit from :class:`~transformers.PreTrainedModel`, uses the model's own
        :obj:`floating_point_ops` method to compute the number of floating point operations for every backward +
        forward pass. If using another model, either implement such a method in the model or subclass and override
        this method.

        Args:
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            :obj:`int`: The number of floating-point operations.
        """

        if isinstance(self.model, torch.nn.DataParallel) or isinstance(
            self.model, torch.nn.parallel.DistributedDataParallel
        ):
            model = self.model.module
        else:
            model = self.model

        if hasattr(model, "floating_point_ops"):
            return model.floating_point_ops(inputs)

        else:
            return 0
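
    # Illustrative sketch: a model that is not a `PreTrainedModel` can still report its cost by
    # exposing a `floating_point_ops(inputs)` method, which the code above picks up. The
    # 6 * parameters * tokens estimate is a rough assumption for the example, not a rule.
    #
    #     class MyModel(nn.Module):
    #         def floating_point_ops(self, inputs):
    #             num_tokens = inputs["input_ids"].numel()
    #             num_params = sum(p.numel() for p in self.parameters())
    #             return 6 * num_params * num_tokens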