import inspect
import math
import os
import re
import shutil
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange

from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import is_nlp_available, is_torch_tpu_available
from .integrations import (
    default_hp_search_backend,
    is_comet_available,
    is_optuna_available,
    is_ray_available,
    is_tensorboard_available,
    is_wandb_available,
    run_hp_search_optuna,
    run_hp_search_ray,
)
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalPrediction,
    HPSearchBackend,
    PredictionOutput,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    set_seed,
)
from .training_args import TrainingArguments
from .utils import logging


_use_native_amp = False
_use_apex = False

# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
    from transformers.file_utils import is_apex_available

    if is_apex_available():
        from apex import amp
    _use_apex = True
else:
    _use_native_amp = True
    from torch.cuda.amp import autocast

if is_nlp_available():
    import nlp

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_tensorboard_available():
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        from tensorboardX import SummaryWriter

if is_wandb_available():
    import wandb

if is_comet_available():
    import comet_ml

if is_optuna_available():
    import optuna

if is_ray_available():
    from ray import tune

logger = logging.get_logger(__name__)


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager to make all processes in distributed training wait for the local master to do something.

    Args:
        local_rank (:obj:`int`): The rank of the local process.
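
    Example (a minimal sketch; ``load_and_cache_dataset`` is a hypothetical helper supplied by the user)::

        with torch_distributed_zero_first(args.local_rank):
            # only the local main process downloads/caches, the other processes wait and reuse the cache
            dataset = load_and_cache_dataset()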
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()


class SequentialDistributedSampler(Sampler):
    """
    Distributed Sampler that subsamples indices sequentially,
    making it easier to collate all results at the end.

    Even though we only use this sampler for eval and predict (no training),
    which means that the model params won't have to be synced (i.e. will not hang
    for synchronization even if varied number of forward passes), we still add extra
    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert (
            len(indices) == self.total_size
        ), f"Indices length {len(indices)} and total size {self.total_size} mismatched"

        # subsample
        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
        assert (
            len(indices) == self.num_samples
        ), f"Indices length {len(indices)} and and sample number {self.num_samples} mismatched"

        return iter(indices)

    def __len__(self):
        return self.num_samples


def get_tpu_sampler(dataset: Dataset):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for 🤗 Transformers.

    Args:
        model (:class:`~transformers.PreTrainedModel`, `optional`):
            The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
        args (:class:`~transformers.TrainingArguments`, `optional`):
            The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
            with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
        data_collator (:obj:`DataCollator`, `optional`):
            The function to use to form a batch from a list of elements of :obj:`train_dataset` or
            :obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is
            provided, an instance of :func:`~transformers.DataCollatorWithPadding` otherwise.
        train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for training. If it is an :obj:`nlp.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for evaluation. If it is an :obj:`nlp.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
        tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to
            the maximum length when batching inputs, and it will be saved along with the model to make it easier to
            rerun an interrupted training or reuse the fine-tuned model.
        model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
            A function that instantiates the model to be used. If provided, each call to
            :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
        compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take a
            :class:`~transformers.EvalPrediction` and return a dictionary mapping metric names to metric values.
        tb_writer (:obj:`SummaryWriter`, `optional`):
            Object to write to TensorBoard.
        optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, `optional`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of
            :class:`~transformers.AdamW` on your model and a scheduler given by
            :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
        kwargs:
            Deprecated keyword arguments.
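
    Example (a minimal usage sketch; the checkpoint name and the ``train_ds``/``eval_ds`` datasets are
    placeholders supplied by the user)::

        from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

        model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
        args = TrainingArguments(output_dir="tmp_trainer", num_train_epochs=1)
        trainer = Trainer(model=model, args=args, train_dataset=train_ds, eval_dataset=eval_ds)
        trainer.train()
        metrics = trainer.evaluate()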
    """

    def __init__(
        self,
        model: PreTrainedModel = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        **kwargs,
    ):
        if args is None:
            logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.")
            args = TrainingArguments("tmp_trainer")
        self.args = args
        # Seed must be set before instantiating the model when using model_init.
        set_seed(self.args.seed)
        assert (
            model is not None or model_init is not None
        ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
        if model is None and model_init is not None:
            model = model_init()
        self.model = model.to(args.device) if model is not None else None
        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer
        self.model_init = model_init
        self.compute_metrics = compute_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument."
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        self.tb_writer = tb_writer
        if "prediction_loss_only" in kwargs:
            warnings.warn(
                "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
                FutureWarning,
            )
            self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create output directory if needed
        if self.is_world_process_zero():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available():
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant way and not need to do this in the future.
            self.model.config.xla_device = True
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            self.data_collator = self.data_collator.collate_batch
            warnings.warn(
                (
                    "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
                    + "with a `collate_batch` are deprecated and won't be supported in a future version."
                ),
                FutureWarning,
            )

        if is_nlp_available():
            if isinstance(train_dataset, nlp.Dataset):
                self._remove_unused_columns(self.train_dataset, description="training")
            if isinstance(eval_dataset, nlp.Dataset):
                self._remove_unused_columns(self.eval_dataset, description="evaluation")

        self.global_step = None
        self.epoch = None
        if self.args.fp16 and _use_native_amp:
            self.scaler = torch.cuda.amp.GradScaler()
        self.hp_search_backend = None
        self.use_tune_checkpoints = False

    def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[str] = None):
        if not self.args.remove_unused_columns:
            return
        # Inspect model forward signature to keep only the arguments it accepts.
        signature = inspect.signature(self.model.forward)
        signature_columns = list(signature.parameters.keys())
        # Labels may be named label or label_ids, the default data collator handles that.
        signature_columns += ["label", "label_ids"]
        columns = [k for k in signature_columns if k in dataset.column_names]
        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        dset_description = "" if description is None else f"in the {description} set "
        logger.info(
            f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
        )
        dataset.set_format(type=dataset.format["type"], columns=columns)

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return get_tpu_sampler(self.train_dataset)
        else:
            return (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
        (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
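
        Example (a sketch of overriding this method in a subclass to plug in a custom sampler; ``weights``
        is assumed to be provided by the user)::

            class MyTrainer(Trainer):
                def get_train_dataloader(self):
                    sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(self.train_dataset))
                    return DataLoader(
                        self.train_dataset,
                        batch_size=self.args.train_batch_size,
                        sampler=sampler,
                        collate_fn=self.data_collator,
                        drop_last=self.args.dataloader_drop_last,
                    )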
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = self._get_train_sampler()

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            return SequentialDistributedSampler(eval_dataset)
        else:
            return SequentialSampler(eval_dataset)

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
        sampler (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`nlp.Dataset`, columns not
                accepted by the ``model.forward()`` method are automatically removed.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        elif eval_dataset is not None and is_nlp_available() and isinstance(eval_dataset, nlp.Dataset):
            self._remove_unused_columns(eval_dataset, description="evaluation")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
        sampler (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (:obj:`torch.utils.data.dataset.Dataset`):
                The test dataset to use. If it is an :obj:`nlp.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed.
        """
        if is_nlp_available() and isinstance(test_dataset, nlp.Dataset):
            self._remove_unused_columns(test_dataset, description="test")
        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Set up the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method.
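
        Example (a sketch of supplying a custom optimizer/scheduler pair instead of the default; ``model``,
        ``args`` and ``train_ds`` are assumed to be defined by the user)::

            from torch.optim import SGD
            from torch.optim.lr_scheduler import LambdaLR

            optimizer = SGD(model.parameters(), lr=1e-3)
            scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
            trainer = Trainer(model=model, args=args, train_dataset=train_ds, optimizers=(optimizer, scheduler))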
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            self.optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
            )
        if self.lr_scheduler is None:
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
            )

    def setup_wandb(self):
        """
        Set up the optional Weights & Biases (`wandb`) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
                or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
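
        Example (a sketch; the project name is a placeholder and the variables must be set before training starts)::

            import os

            os.environ["WANDB_PROJECT"] = "my-project"
            os.environ["WANDB_WATCH"] = "false"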
        """
        if hasattr(self, "_setup_wandb"):
            warnings.warn(
                "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
                FutureWarning,
            )
            return self._setup_wandb()

        if self.is_world_process_zero():
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
            wandb.init(
                project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
            )
            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                wandb.watch(
                    self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
                )

    def setup_comet(self):
        """
        Set up the optional Comet.ml integration.

        Environment:
            COMET_MODE:
                (Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
            COMET_PROJECT_NAME:
                (Optional): str - Comet.ml project name for experiments
            COMET_OFFLINE_DIRECTORY:
                (Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"

        For a number of configurable items in the environment,
        see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
        """
        if self.is_world_master():
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            experiment = None
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**args)
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**args)
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            if experiment is not None:
                experiment._set_model_graph(self.model, framework="transformers")
                experiment._log_parameters(self.args, prefix="args/", framework="transformers")
                experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
        """
        return len(dataloader.dataset)

    def _setup_loggers(self):
        if self._loggers_initialized:
            return
        if is_wandb_available():
            self.setup_wandb()
        elif os.environ.get("WANDB_DISABLED") != "true":
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
        if is_comet_available():
            self.setup_comet()
        elif os.environ.get("COMET_MODE") != "DISABLED":
            logger.info(
                "To use comet_ml logging, run `pip/conda install comet_ml` "
                "see https://www.comet.ml/docs/python-sdk/huggingface/"
            )
        self._loggers_initialized = True

    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """ HP search setup code """
        if self.hp_search_backend is None or trial is None:
            return
        params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
        for key, value in params.items():
            if not hasattr(self.args, key):
                raise AttributeError(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
                )
            old_attr = getattr(self.args, key, None)
            # Casting value to the proper type
            if old_attr is not None:
                value = type(old_attr)(value)
            setattr(self.args, key, value)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info("Trial:", trial.params)

    def _report_to_hp_search(
        self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
    ):
        if self.hp_search_backend is None or trial is None:
            return
        self.objective = self.compute_objective(metrics)
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            trial.report(self.objective, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
        elif self.hp_search_backend == HPSearchBackend.RAY:
            if self.global_step % self.args.save_steps == 0:
                self._tune_save_checkpoint()
            tune.report(objective=self.objective, **metrics)

    def _tune_save_checkpoint(self):
        if not self.use_tune_checkpoints:
            return
        with tune.checkpoint_dir(step=self.global_step) as checkpoint_dir:
            self.args.output_dir = checkpoint_dir
            output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
            self.save_model(output_dir)
            if self.is_world_master():
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            set_seed(self.args.seed)
            model = self.model_init()
            self.model = model.to(self.args.device)

            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (
                self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
            )
        else:
            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs
            self.args.max_steps = t_total

        self.create_optimizer_and_scheduler(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            self.optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16 and _use_apex:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
                epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = torch.tensor(0.0).to(self.args.device)
        logging_loss_scalar = 0.0
        model.zero_grad()
        disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
        train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
        for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    epoch_pbar.update(1)
                    continue

                tr_loss += self.training_step(model, inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)
                    elif self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    self.lr_scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs: Dict[str, float] = {}
                        tr_loss_scalar = tr_loss.item()
                        logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            self.lr_scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else self.lr_scheduler.get_lr()[0]
                        )
                        logging_loss_scalar = tr_loss_scalar

                        self.log(logs)

                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                        metrics = self.evaluate()
                        self._report_to_hp_search(trial, epoch, metrics)

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert (
                                model.module is self.model
                            ), f"Module {model.module} should be a reference to self.model"
                        else:
                            assert model is self.model, f"Model {model} should be a reference to self.model"
                        # Save model checkpoint
                        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
                        if self.hp_search_backend is not None and trial is not None:
                            run_id = (
                                trial.number
                                if self.hp_search_backend == HPSearchBackend.OPTUNA
                                else tune.get_trial_id()
                            )
                            checkpoint_folder += f"-run-{run_id}"
                        output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

                        self.save_model(output_dir)

                        if self.is_world_process_zero():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_process_zero():
                            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

                epoch_pbar.update(1)
                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                    break
            epoch_pbar.close()
            train_pbar.update(1)
            if self.args.tpu_metrics_debug or self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                break

        train_pbar.close()
        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return TrainOutput(self.global_step, tr_loss.item() / self.global_step)

    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: str = "minimize",
        backend: Optional[Union["str", HPSearchBackend]] = None,
        **kwargs
    ) -> BestRun:
        """
        Launch a hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
        :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
        the sum of all metrics otherwise.

        .. warning::

            To use this method, you need to have provided a ``model_init`` when initializing your
            :class:`~transformers.Trainer`: we need to reinitialize the model at each new run. This is incompatible
            with the ``optimizers`` argument, so you need to subclass :class:`~transformers.Trainer` and override the
            method :meth:`~transformers.Trainer.create_optimizer_and_scheduler` for custom optimizer/scheduler.

        Args:
            hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
                A function that defines the hyperparameter search space. Will default to
                :func:`~transformers.trainer_utils.default_hp_space_optuna` or
                :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
            compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
                A function computing the objective to minimize or maximize from the metrics returned by the
                :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
            n_trials (:obj:`int`, `optional`, defaults to 20):
                The number of trial runs to test.
            direction (:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
                Whether to optimize for a greater or lower objective. Can be :obj:`"minimize"` or :obj:`"maximize"`, you should
                pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
                several metrics.
            backend (:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
                The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
                one is installed. If both are installed, will default to optuna.
            kwargs:
                Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
                more information see:

                - the documentation of `optuna.create_study <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
                - the documentation of `tune.run <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__

        Returns:
            :class:`transformers.trainer_utils.BestRun`: All the information about the best run.
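
        Example (a sketch using the optuna backend; ``model_init``, ``args`` and the datasets are assumed to be
        defined by the user)::

            def my_hp_space(trial):
                return {"learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 1e-3)}

            trainer = Trainer(model_init=model_init, args=args, train_dataset=train_ds, eval_dataset=eval_ds)
            best_run = trainer.hyperparameter_search(hp_space=my_hp_space, n_trials=10, direction="minimize")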
        """
        if backend is None:
            backend = default_hp_search_backend()
            if backend is None:
                raise RuntimeError(
                    "At least one of optuna or ray should be installed. "
                    "To install optuna run `pip install optuna`."
                    "To install ray run `pip install ray[tune]`."
                )
        backend = HPSearchBackend(backend)
        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
            raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
        if backend == HPSearchBackend.RAY and not is_ray_available():
            raise RuntimeError(
                "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
            )
        self.hp_search_backend = backend

        if self.model_init is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model through a model_init function."
            )

        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
        best_run = run_hp_search(self, n_trials, direction, **kwargs)

        self.hp_search_backend = None
        return best_run

    def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        """
        Log :obj:`logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (:obj:`Dict[str, float]`):
                The values to log.
            iterator (:obj:`tqdm`, `optional`):
                A potential tqdm progress bar to write the logs on.
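
        Example (a sketch of overriding ``log`` in a subclass to also append the values to a local file;
        the file name is a placeholder)::

            import json

            class MyTrainer(Trainer):
                def log(self, logs, iterator=None):
                    super().log(logs, iterator=iterator)
                    with open("trainer_logs.jsonl", "a") as f:
                        print(json.dumps({**logs, "step": self.global_step}), file=f)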
        """
        # Set up loggers like W&B or Comet ML
        self._setup_loggers()

        if hasattr(self, "_log"):
            warnings.warn(
                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
                FutureWarning,
            )
            return self._log(logs, iterator=iterator)

        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_process_zero():
                wandb.log(logs, step=self.global_step)
        if is_comet_available():
            if self.is_world_process_zero():
                experiment = comet_ml.config.get_global_experiment()
                if experiment is not None:
                    experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
        output = {**logs, **{"step": self.global_step}}
        if iterator is not None:
            iterator.write(output)
        else:
            print(output)

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.
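
        Example (a sketch of overriding this method in a subclass to re-weight the loss returned by the model;
        the 0.5 factor is arbitrary and mixed-precision handling is omitted for brevity)::

            class MyTrainer(Trainer):
                def training_step(self, model, inputs):
                    model.train()
                    inputs = self._prepare_inputs(inputs)
                    loss = model(**inputs)[0] * 0.5  # hypothetical re-weighting of the loss
                    if self.args.gradient_accumulation_steps > 1:
                        loss = loss / self.args.gradient_accumulation_steps
                    loss.backward()
                    return loss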
        """
        if hasattr(self, "_training_step"):
            warnings.warn(
                "The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
                FutureWarning,
            )
            return self._training_step(model, inputs, self.optimizer)

        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.args.fp16 and _use_native_amp:
            with autocast():
                outputs = model(**inputs)
                loss = outputs[0]
        else:
            outputs = model(**inputs)
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs[0]

        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16 and _use_native_amp:
            self.scaler.scale(loss).backward()
        elif self.args.fp16 and _use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss
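    # Hedged sketch of the "subclass and override" pattern for `training_step`, e.g. to
    # apply a custom loss weight. `MyTrainer` and the 0.5 factor are hypothetical, and the
    # sketch deliberately omits the mixed-precision and multi-GPU branches handled above.
    #
    #     class MyTrainer(Trainer):
    #         def training_step(self, model, inputs):
    #             model.train()
    #             inputs = self._prepare_inputs(inputs)
    #             loss = model(**inputs)[0] * 0.5  # custom weighting
    #             if self.args.gradient_accumulation_steps > 1:
    #                 loss = loss / self.args.gradient_accumulation_steps
    #             loss.backward()
    #             return loss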

    def is_local_master(self) -> bool:
        """
        Whether or not this process is the local main process (e.g., the main process on its machine when training
        in a distributed fashion on several machines).

        .. warning::

            This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
        """
        warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning)
        return self.is_local_process_zero()

    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local main process (e.g., the main process on its machine when training
        in a distributed fashion on several machines).
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

    def is_world_master(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).

        .. warning::

            This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
        """
        warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning)
        return self.is_world_process_zero()

    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
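    # Illustrative guard for side effects that should happen exactly once per run (the
    # file name is a made-up example):
    #
    #     if trainer.is_world_process_zero():
    #         with open("run_summary.txt", "w") as f:
    #             f.write("written by the main process only")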

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        Will only save from the world main process (unless running on TPU).
        """

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_process_zero():
            self._save(output_dir)
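    # Hedged reload sketch: a checkpoint written by `save_model` can be reloaded with
    # `from_pretrained`. The model class and directory name are assumptions for
    # illustration; use whatever class the Trainer was built with.
    #
    #     trainer.save_model("./my_output_dir")
    #     model = AutoModelForSequenceClassification.from_pretrained("./my_output_dir")
    #     tokenizer = AutoTokenizer.from_pretrained("./my_output_dir")  # if a tokenizer was passed to the Trainer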

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")

        xm.rendezvous("saving_checkpoint")
        self.model.save_pretrained(output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")
        self.model.save_pretrained(output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
            shutil.rmtree(checkpoint)
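    # Worked example of the rotation rule (directory names are illustrative): with
    # `save_total_limit=2` and checkpoint-500, checkpoint-1000 and checkpoint-1500 on
    # disk, `_sorted_checkpoints` orders them by step and `_rotate_checkpoints` deletes
    # checkpoint-500, keeping the two most recent checkpoints.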

    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`nlp.Dataset`,
                columns not accepted by the ``model.forward()`` method are automatically removed.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self.prediction_loop(eval_dataloader, description="Evaluation")

        self.log(output.metrics)

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics
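    # Hedged usage sketch: `compute_metrics` receives an :obj:`EvalPrediction` and returns a
    # dict; keys that do not already start with ``eval_`` get that prefix added by
    # `prediction_loop`. The accuracy computation assumes a classification setup.
    #
    #     def compute_metrics(eval_pred):
    #         preds = np.argmax(eval_pred.predictions, axis=-1)
    #         return {"accuracy": (preds == eval_pred.label_ids).mean()}
    #
    #     trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=compute_metrics)
    #     metrics = trainer.evaluate()  # e.g. {"eval_loss": ..., "eval_accuracy": ...}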

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset to run the predictions on. If it is an :obj:`nlp.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed.

        Returns:
            `NamedTuple`:
            predictions (:obj:`np.ndarray`):
                The predictions on :obj:`test_dataset`.
            label_ids (:obj:`np.ndarray`, `optional`):
                The labels (if the dataset contained some).
            metrics (:obj:`Dict[str, float]`, `optional`):
                The potential dictionary of metrics (if the dataset contained labels).
        """
        test_dataloader = self.get_test_dataloader(test_dataset)

        return self.prediction_loop(test_dataloader, description="Prediction")
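    # Illustrative call sketch (the argmax post-processing assumes a classification model):
    #
    #     output = trainer.predict(test_dataset)
    #     class_ids = np.argmax(output.predictions, axis=-1)
    #     if output.metrics:  # populated when the dataset has labels
    #         print(output.metrics)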

    def prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if hasattr(self, "_prediction_loop"):
            warnings.warn(
                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
                FutureWarning,
            )
            return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)

        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        else:
            model = self.model
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: torch.Tensor = None
        label_ids: torch.Tensor = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
        samples_count = 0
        for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
            batch_size = inputs[list(inputs.keys())[0]].shape[0]
            samples_count += batch_size
            if loss is not None:
                eval_losses.append(loss * batch_size)
            if logits is not None:
                preds = logits if preds is None else torch.cat((preds, logits), dim=0)
            if labels is not None:
                label_ids = labels if label_ids is None else torch.cat((label_ids, labels), dim=0)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
            if eval_losses is not None:
                eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
            if samples_count is not None:
                samples_count = sum(xm.mesh_reduce("samples_count", torch.tensor([samples_count]), torch.cat).tolist())

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.sum(eval_losses) / samples_count

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
        assert self.args.local_rank != -1

        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        output = concat[:num_total_examples]
        return output
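    # Sketch of the gather-and-truncate idea behind `distributed_concat` (shapes are
    # illustrative): each rank contributes its local tensor, the gathered pieces are
    # concatenated, and the padding rows added by the distributed sampler are dropped.
    #
    #     # rank 0 tensor: (8, num_labels), rank 1 tensor: (8, num_labels)
    #     # after all_gather + cat: (16, num_labels)
    #     # with num_total_examples=14, the returned tensor is (14, num_labels)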

    def prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            A tuple with the loss, logits and labels (each being optional).
        """
        has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            outputs = model(**inputs)
            if has_labels:
                loss, logits = outputs[:2]
                loss = loss.mean().item()
            else:
                loss = None
                logits = outputs[0]
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        labels = inputs.get("labels")
        if labels is not None:
            labels = labels.detach()
        return (loss, logits.detach(), labels)
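    # Hedged sketch of overriding `prediction_step` in a subclass, e.g. for a model whose
    # outputs are ordered differently. `MyTrainer` and the output indexing are hypothetical
    # and skip the `past_index` bookkeeping handled above.
    #
    #     class MyTrainer(Trainer):
    #         def prediction_step(self, model, inputs, prediction_loss_only):
    #             has_labels = "labels" in inputs
    #             inputs = self._prepare_inputs(inputs)
    #             with torch.no_grad():
    #                 outputs = model(**inputs)
    #             loss = outputs[0].mean().item() if has_labels else None
    #             if prediction_loss_only:
    #                 return (loss, None, None)
    #             logits = outputs[1] if has_labels else outputs[0]
    #             labels = inputs.get("labels")
    #             return (loss, logits.detach(), labels.detach() if labels is not None else None)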