import inspect
import logging
import math
import os
import re
import shutil
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange

from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import is_nlp_available, is_torch_tpu_available
from .integrations import (
    default_hp_search_backend,
    is_comet_available,
    is_optuna_available,
    is_ray_available,
    is_tensorboard_available,
    is_wandb_available,
)
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalPrediction,
    HPSearchBackend,
    PredictionOutput,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    set_seed,
)
from .training_args import TrainingArguments


_use_native_amp = False
_use_apex = False

# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
    from transformers.file_utils import is_apex_available

    if is_apex_available():
        from apex import amp
    _use_apex = True
else:
    _use_native_amp = True
    from torch.cuda.amp import autocast

if is_nlp_available():
    import nlp

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_tensorboard_available():
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        from tensorboardX import SummaryWriter

if is_wandb_available():
    import wandb

if is_comet_available():
    import comet_ml

if is_optuna_available():
    import optuna

if is_ray_available():
    from ray import tune

logger = logging.getLogger(__name__)


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager to make all processes in distributed training wait for each local_master to do something.

    Args:
        local_rank (:obj:`int`): The rank of the local process.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()


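# Illustrative sketch (not part of the upstream API): typical use of the context manager above so
# that only the local main process does slow, cacheable work (e.g. preprocessing a dataset) while
# the other ranks wait at the barrier and then read from the cache. `build_dataset` is a
# hypothetical callable supplied by the caller.
def _example_torch_distributed_zero_first(local_rank: int, build_dataset: Callable[[], Dataset]) -> Dataset:
    with torch_distributed_zero_first(local_rank):
        # Rank 0 (or a non-distributed run) enters first and fills the cache; the remaining
        # ranks only proceed once the barrier is released.
        dataset = build_dataset()
    return dataset

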
class SequentialDistributedSampler(Sampler):
    """
    Distributed Sampler that subsamples indices sequentially,
    making it easier to collate all results at the end.

    Even though we only use this sampler for eval and predict (no training),
    which means that the model params won't have to be synced (i.e. it will not hang
    waiting for synchronization even if the number of forward passes varies), we still add extra
    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert (
            len(indices) == self.total_size
        ), f"Indices length {len(indices)} and total size {self.total_size} mismatched"

        # subsample
        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
        assert (
            len(indices) == self.num_samples
        ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"

        return iter(indices)

    def __len__(self):
        return self.num_samples


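# Illustrative sketch (not part of the upstream API): the index arithmetic used by
# `SequentialDistributedSampler` above, without needing an initialized process group. With 10
# samples and 4 replicas it pads the index list to 12 entries and hands each rank a contiguous
# chunk of 3, which keeps gather-style collation trivial at the end of an evaluation loop.
def _example_sequential_distributed_split(num_samples: int, num_replicas: int) -> List[List[int]]:
    per_rank = int(math.ceil(num_samples * 1.0 / num_replicas))
    indices = list(range(num_samples))
    # add extra samples (repeating the first ones) to make the total evenly divisible
    indices += indices[: per_rank * num_replicas - len(indices)]
    return [indices[rank * per_rank : (rank + 1) * per_rank] for rank in range(num_replicas)]

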
def get_tpu_sampler(dataset: Dataset):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for 🤗 Transformers.

    Args:
        model (:class:`~transformers.PreTrainedModel`, `optional`):
            The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
        args (:class:`~transformers.TrainingArguments`, `optional`):
            The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
            with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
        data_collator (:obj:`DataCollator`, `optional`):
            The function to use to form a batch from a list of elements of :obj:`train_dataset` or
            :obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is
            provided, an instance of :func:`~transformers.DataCollatorWithPadding` otherwise.
        train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for training. If it is an :obj:`nlp.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for evaluation. If it is an :obj:`nlp.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
        tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
            interrupted training or reuse the fine-tuned model.
        model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
            A function that instantiates the model to be used. If provided, each call to
            :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
        compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take a
            :class:`~transformers.EvalPrediction` and return a dictionary mapping metric names to metric values.
        tb_writer (:obj:`SummaryWriter`, `optional`):
            Object to write to TensorBoard.
        optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, `optional`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of
            :class:`~transformers.AdamW` on your model and a scheduler given by
            :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
        kwargs:
            Deprecated keyword arguments.
    """

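    # Illustrative sketch (not part of the upstream docs): minimal usage, assuming `model`,
    # `training_args`, `train_ds` and `eval_ds` are already built elsewhere.
    #
    #   trainer = Trainer(
    #       model=model,
    #       args=training_args,
    #       train_dataset=train_ds,
    #       eval_dataset=eval_ds,
    #   )
    #   trainer.train()
    #   metrics = trainer.evaluate()
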
    def __init__(
        self,
        model: PreTrainedModel = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        **kwargs,
    ):
        if args is None:
            logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.")
            args = TrainingArguments("tmp_trainer")
        self.args = args
        # Seed must be set before instantiating the model when using model_init.
        set_seed(self.args.seed)
        assert (
            model is not None or model_init is not None
        ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
        if model is None and model_init is not None:
            model = model_init()
        self.model = model.to(args.device) if model is not None else None
        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer
        self.model_init = model_init
        self.compute_metrics = compute_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        self.tb_writer = tb_writer
        if "prediction_loss_only" in kwargs:
            warnings.warn(
                "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
                FutureWarning,
            )
            self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
        if is_wandb_available():
            self.setup_wandb()
        elif os.environ.get("WANDB_DISABLED") != "true":
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login`; see https://docs.wandb.com/huggingface."
            )
        if is_comet_available():
            self.setup_comet()
        elif os.environ.get("COMET_MODE") != "DISABLED":
            logger.info(
                "To use comet_ml logging, run `pip/conda install comet_ml` "
                "see https://www.comet.ml/docs/python-sdk/huggingface/"
            )
        # Create output directory if needed
        if self.is_world_process_zero():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available():
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant way and not need to do this in the future.
            self.model.config.xla_device = True
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            self.data_collator = self.data_collator.collate_batch
            warnings.warn(
                (
                    "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
                    + "with a `collate_batch` are deprecated and won't be supported in a future version."
                ),
                FutureWarning,
            )

        if is_nlp_available():
            if isinstance(train_dataset, nlp.Dataset):
                self._remove_unused_columns(self.train_dataset, description="training")
            if isinstance(eval_dataset, nlp.Dataset):
                self._remove_unused_columns(self.eval_dataset, description="evaluation")

        self.global_step = None
        self.epoch = None
        if self.args.fp16 and _use_native_amp:
            self.scaler = torch.cuda.amp.GradScaler()
        self.hp_search_backend = None

    def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return
        # Inspect model forward signature to keep only the arguments it accepts.
        signature = inspect.signature(self.model.forward)
        signature_columns = list(signature.parameters.keys())
        # Labels may be named label or label_ids, the default data collator handles that.
        signature_columns += ["label", "label_ids"]
        columns = [k for k in signature_columns if k in dataset.column_names]
        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        dset_description = "" if description is None else f"in the {description} set "
        logger.info(
            f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
        )
        dataset.set_format(type=dataset.format["type"], columns=columns)

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return get_tpu_sampler(self.train_dataset)
        else:
            return (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
        (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = self._get_train_sampler()

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

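    # Illustrative sketch (not part of the upstream API): plugging in a custom sampler by
    # overriding `_get_train_sampler`, as the docstring above suggests. `sample_weights` is a
    # hypothetical 1-D tensor of per-example weights supplied by the user.
    #
    #   class WeightedTrainer(Trainer):
    #       def _get_train_sampler(self):
    #           return torch.utils.data.WeightedRandomSampler(sample_weights, len(sample_weights))
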
    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            return SequentialDistributedSampler(eval_dataset)
        else:
            return SequentialSampler(eval_dataset)

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
        sampler (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`nlp.Dataset`, columns not
                accepted by the ``model.forward()`` method are automatically removed.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        elif eval_dataset is not None and is_nlp_available() and isinstance(eval_dataset, nlp.Dataset):
            self._remove_unused_columns(eval_dataset, description="evaluation")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
        sampler (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (:obj:`torch.utils.data.dataset.Dataset`):
                The test dataset to use. If it is an :obj:`nlp.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed.
        """
        if is_nlp_available() and isinstance(test_dataset, nlp.Dataset):
            self._remove_unused_columns(test_dataset, description="test")
        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method.
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            self.optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
            )
        if self.lr_scheduler is None:
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
            )

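    # Illustrative sketch (not part of the upstream docs): supplying a custom optimizer/scheduler
    # pair instead of the defaults built above, via the `optimizers` argument of `Trainer.__init__`.
    # `model`, `training_args` and `num_steps` are assumed to exist in the caller's scope.
    #
    #   optimizer = AdamW(model.parameters(), lr=training_args.learning_rate)
    #   scheduler = get_linear_schedule_with_warmup(
    #       optimizer, num_warmup_steps=training_args.warmup_steps, num_training_steps=num_steps
    #   )
    #   trainer = Trainer(model=model, args=training_args, optimizers=(optimizer, scheduler))
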
    def setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
                or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
        """
        if hasattr(self, "_setup_wandb"):
            warnings.warn(
                "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
                FutureWarning,
            )
            return self._setup_wandb()

        if self.is_world_process_zero():
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
            wandb.init(
                project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
            )
            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                wandb.watch(
                    self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
                )

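    # Illustrative sketch (not part of the upstream docs): configuring the integration above from
    # the environment before the Trainer is built.
    #
    #   os.environ["WANDB_PROJECT"] = "my-project"   # hypothetical project name
    #   os.environ["WANDB_WATCH"] = "false"          # skip gradient logging
    #   os.environ["WANDB_DISABLED"] = "true"        # or disable the integration entirely
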
    def setup_comet(self):
        """
        Setup the optional Comet.ml integration.

        Environment:
            COMET_MODE:
                (Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
            COMET_PROJECT_NAME:
                (Optional): str - Comet.ml project name for experiments
            COMET_OFFLINE_DIRECTORY:
                (Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"

        For a number of configurable items in the environment,
        see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
        """
        if self.is_world_master():
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            experiment = None
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**args)
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**args)
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            if experiment is not None:
                experiment._set_model_graph(self.model, framework="transformers")
                experiment._log_parameters(self.args, prefix="args/", framework="transformers")
                experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")

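    # Illustrative sketch (not part of the upstream docs): running the Comet.ml integration above
    # in offline mode, so experiments are written to disk and uploaded later with `comet upload`.
    #
    #   os.environ["COMET_MODE"] = "OFFLINE"
    #   os.environ["COMET_OFFLINE_DIRECTORY"] = "./comet-logs"   # hypothetical folder
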
    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
        """
        return len(dataloader.dataset)

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """ HP search setup code """
        if self.hp_search_backend is None or trial is None:
            return
        params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
        for key, value in params.items():
            if not hasattr(self.args, key):
                raise AttributeError(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
                )
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info("Trial: %s", trial.params)

    def _report_to_hp_search(
        self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
    ):
        if self.hp_search_backend is None or trial is None:
            return
        self.objective = self.compute_objective(metrics)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            trial.report(self.objective, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            tune.report(objective=self.objective, **metrics)

    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            set_seed(self.args.seed)
            model = self.model_init()
            self.model = model.to(self.args.device)

            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (
                self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
            )
        else:
            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs
            self.args.max_steps = t_total

        self.create_optimizer_and_scheduler(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            self.optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16 and _use_apex:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
        train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
        for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    epoch_pbar.update(1)
                    continue

                tr_loss += self.training_step(model, inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)
                    elif self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    self.lr_scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            self.lr_scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else self.lr_scheduler.get_lr()[0]
                        )
                        logging_loss = tr_loss

                        self.log(logs)

                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                        metrics = self.evaluate()
                        self._report_to_hp_search(trial, epoch, metrics)

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert (
                                model.module is self.model
                            ), f"Module {model.module} should be a reference to self.model"
                        else:
                            assert model is self.model, f"Model {model} should be a reference to self.model"
                        # Save model checkpoint
                        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
                        if self.hp_search_backend is not None and trial is not None:
                            run_id = (
                                trial.number
                                if self.hp_search_backend == HPSearchBackend.OPTUNA
                                else tune.get_trial_id()
                            )
                            checkpoint_folder += f"-run-{run_id}"
                        output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

                        self.save_model(output_dir)

                        if self.is_world_process_zero():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_process_zero():
                            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

                epoch_pbar.update(1)
                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                    break
            epoch_pbar.close()
            train_pbar.update(1)
            if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                break
            if self.args.tpu_metrics_debug or self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )

        train_pbar.close()
        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return TrainOutput(self.global_step, tr_loss / self.global_step)

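    # Illustrative sketch (not part of the upstream docs): resuming training from a checkpoint
    # directory so the loop above restores the optimizer/scheduler states and the global step.
    # The path is hypothetical; it only needs to follow the `checkpoint-<global_step>` naming
    # produced with `PREFIX_CHECKPOINT_DIR`.
    #
    #   trainer.train(model_path="output/checkpoint-500")
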
    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        **kwargs
    ) -> BestRun:
        """
        Launch a hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
        :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
        the sum of all metrics otherwise.

        .. warning::

            To use this method, you need to have provided a ``model_init`` when initializing your
            :class:`~transformers.Trainer`: we need to reinitialize the model at each new run. This is incompatible
            with the ``optimizers`` argument, so you need to subclass :class:`~transformers.Trainer` and override the
            method :meth:`~transformers.Trainer.create_optimizer_and_scheduler` for custom optimizer/scheduler.

        Args:
            hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
                A function that defines the hyperparameter search space. Will default to
                :func:`~transformers.trainer_utils.default_hp_space_optuna` or
                :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
            compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
                A function computing the objective to minimize or maximize from the metrics returned by the
                :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
            n_trials (:obj:`int`, `optional`, defaults to 20):
                The number of trial runs to test.
            direction (:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
                Whether to optimize for a greater or lower objective. Can be :obj:`"minimize"` or :obj:`"maximize"`, you should
                pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
                several metrics.
            backend (:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
                one is installed. If both are installed, will default to optuna.
            kwargs:
                Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
                more information see:

                - the documentation of `optuna.create_study <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
                - the documentation of `tune.run <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__

        Returns:
            :class:`transformers.trainer_utils.BestRun`: All the information about the best run.
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray should be installed. "
                    "To install optuna run `pip install optuna`. "
                    "To install ray run `pip install ray[tune]`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        self.hp_search_backend = backend

        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        def _objective(trial):
            self.objective = None
            self.train(trial=trial)
            # If there hasn't been any evaluation during the training loop.
            if getattr(self, "objective", None) is None:
                metrics = self.evaluate()
                self.objective = self.compute_objective(metrics)
                if self.hp_search_backend == HPSearchBackend.RAY:
                    tune.report(objective=self.objective)
            return self.objective

        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            timeout = kwargs.pop("timeout", None)
            n_jobs = kwargs.pop("n_jobs", 1)
            study = optuna.create_study(direction=direction, **kwargs)
            study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
            best_trial = study.best_trial
            best_run = BestRun(str(best_trial.number), best_trial.value, best_trial.params)
        elif self.hp_search_backend == HPSearchBackend.RAY:
            # The TensorBoard writer does not pickle so we have to remove it (if it exists) while doing the ray hp
            # search.
            _tb_writer = self.tb_writer
            self.tb_writer = None
            # Setup default `resources_per_trial` and `progress_reporter`.
            if "resources_per_trial" not in kwargs and self.args.n_gpu > 0:
                kwargs["resources_per_trial"] = {"gpu": self.args.n_gpu}
            if "progress_reporter" not in kwargs:
                from ray.tune import CLIReporter

                kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
            analysis = tune.run(_objective, config=self.hp_space(None), num_samples=n_trials, **kwargs)
            best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
            best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
            self.tb_writer = _tb_writer

        self.hp_search_backend = None
        return best_run

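    # Illustrative sketch (not part of the upstream docs): an optuna-style search space and a call
    # to the method above. `model_init`, `training_args` and `train_ds` are assumed to exist, and
    # the hyperparameter names must match fields of `TrainingArguments`.
    #
    #   def my_hp_space(trial):
    #       return {
    #           "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-4, log=True),
    #           "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
    #       }
    #
    #   trainer = Trainer(model_init=model_init, args=training_args, train_dataset=train_ds)
    #   best_run = trainer.hyperparameter_search(hp_space=my_hp_space, n_trials=10, direction="minimize")
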
    def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        """
        Log :obj:`logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (:obj:`Dict[str, float]`):
                The values to log.
            iterator (:obj:`tqdm`, `optional`):
                A potential tqdm progress bar to write the logs on.
        """
        if hasattr(self, "_log"):
            warnings.warn(
                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
                FutureWarning,
            )
            return self._log(logs, iterator=iterator)

        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_process_zero():
                wandb.log(logs, step=self.global_step)
        if is_comet_available():
            if self.is_world_process_zero():
                experiment = comet_ml.config.get_global_experiment()
                if experiment is not None:
                    experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
        output = {**logs, **{"step": self.global_step}}
        if iterator is not None:
            iterator.write(output)
        else:
            print(output)

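    # Illustrative sketch (not part of the upstream docs): `log` can also be called directly with
    # extra scalars, which routes them to TensorBoard / W&B / Comet alongside the training logs.
    #
    #   trainer.log({"custom/throughput": 123.4})
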
    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> float:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`float`: The training loss on this batch.
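
        Example (an illustrative sketch of overriding this method in a subclass; ``MyTrainer`` and the extra
        bookkeeping list are hypothetical, not part of the library)::

            from transformers import Trainer

            class MyTrainer(Trainer):
                def training_step(self, model, inputs):
                    # Delegate to the default implementation, then record the per-step loss ourselves.
                    loss = super().training_step(model, inputs)
                    if not hasattr(self, "step_losses"):
                        self.step_losses = []
                    self.step_losses.append(loss)
                    return loss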
        """
        if hasattr(self, "_training_step"):
            warnings.warn(
                "The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
                FutureWarning,
            )
            return self._training_step(model, inputs, self.optimizer)

        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.args.fp16 and _use_native_amp:
            with autocast():
                outputs = model(**inputs)
                loss = outputs[0]
        else:
            outputs = model(**inputs)
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs[0]

        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16 and _use_native_amp:
            self.scaler.scale(loss).backward()
        elif self.args.fp16 and _use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()

    def is_local_master(self) -> bool:
        """
        Whether or not this process is the local main process (e.g., the main process on its machine when
        training in a distributed fashion on several machines).

        .. warning::

            This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
        """
        warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning)
        return self.is_local_process_zero()

    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local main process (e.g., the main process on its machine when
        training in a distributed fashion on several machines).
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

    def is_world_master(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).

        .. warning::

            This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
        """
        warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning)
        return self.is_world_process_zero()

    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).
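
        Example (a minimal sketch of guarding a side effect so it runs only once per job; ``trainer`` stands for
        an already-constructed :class:`~transformers.Trainer`)::

            if trainer.is_world_process_zero():
                print("Only printed by the main process.")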
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        Will only save from the main process (world process zero), unless running on TPUs.
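
        Example (a minimal sketch; the output path and model class in the comment are placeholders)::

            trainer.save_model("./my_model")
            # The checkpoint can then be reloaded with the matching model class, e.g.:
            # model = AutoModelForSequenceClassification.from_pretrained("./my_model")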
        """

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_process_zero():
            self._save(output_dir)

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")

        xm.rendezvous("saving_checkpoint")
        self.model.save_pretrained(output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")
        self.model.save_pretrained(output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`nlp.Dataset`,
                columns not accepted by the ``model.forward()`` method are automatically removed.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
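
        Example (a minimal sketch; assumes ``trainer`` was built with a labeled :obj:`eval_dataset` so that an
        ``eval_loss`` is computed, and optionally a :obj:`compute_metrics` function)::

            metrics = trainer.evaluate()
            print(metrics["eval_loss"])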
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self.prediction_loop(eval_dataloader, description="Evaluation")

        self.log(output.metrics)

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset to run the predictions on. If it is an :obj:`nlp.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed.

        Returns:
            `NamedTuple`:
            predictions (:obj:`np.ndarray`):
                The predictions on :obj:`test_dataset`.
            label_ids (:obj:`np.ndarray`, `optional`):
                The labels (if the dataset contained some).
            metrics (:obj:`Dict[str, float]`, `optional`):
                The potential dictionary of metrics (if the dataset contained labels).
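
        Example (a minimal sketch; ``trainer`` and ``test_dataset`` stand for an already-constructed
        :class:`~transformers.Trainer` and a dataset compatible with the model)::

            output = trainer.predict(test_dataset)
            predictions = output.predictions  # np.ndarray with the raw model outputs
            if output.metrics:
                print(output.metrics)         # only populated when the dataset has labels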
        """
        test_dataloader = self.get_test_dataloader(test_dataset)

        return self.prediction_loop(test_dataloader, description="Prediction")

    def prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if hasattr(self, "_prediction_loop"):
            warnings.warn(
                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
                FutureWarning,
            )
            return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)

        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: Optional[torch.Tensor] = None
        label_ids: Optional[torch.Tensor] = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
        samples_count = 0
        for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
            batch_size = inputs[list(inputs.keys())[0]].shape[0]
            samples_count += batch_size
            if loss is not None:
                eval_losses.append(loss * batch_size)
            if logits is not None:
                preds = logits if preds is None else torch.cat((preds, logits), dim=0)
            if labels is not None:
                label_ids = labels if label_ids is None else torch.cat((label_ids, labels), dim=0)

        if self.args.past_index >= 0 and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.sum(eval_losses) / samples_count

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
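        """
        Gather :obj:`tensor` from every process in distributed training, concatenate the results and truncate
        them to :obj:`num_total_examples` (the extra entries are the padding added by the distributed sampler).
        """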
        assert self.args.local_rank != -1

        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        output = concat[:num_total_examples]
        return output

    def prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            A tuple with the loss, logits and labels (each being optional).
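
        Example (an illustrative sketch of a subclass override for a classification setting where the
        :obj:`compute_metrics` function expects class ids rather than raw logits; ``MyTrainer`` is hypothetical)::

            from transformers import Trainer

            class MyTrainer(Trainer):
                def prediction_step(self, model, inputs, prediction_loss_only):
                    loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only)
                    # Keep only the predicted class ids to reduce the memory accumulated over the loop.
                    if logits is not None:
                        logits = logits.argmax(dim=-1)
                    return loss, logits, labels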
        """
        has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            outputs = model(**inputs)
            if has_labels:
                loss, logits = outputs[:2]
                loss = loss.mean().item()
            else:
                loss = None
                logits = outputs[0]
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        labels = inputs.get("labels")
        if labels is not None:
            labels = labels.detach()
        return (loss, logits.detach(), labels)