Unverified Commit 08ba4b49 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Trainer callbacks (#7596)



* Initial callback proposal

* Finish various callbacks

* Post-rebase conflicts

* Fix tests

* Don't use something that's not set

* Documentation

* Remove unwanted print.

* Document all models can work

* Add tests + small fixes

* Update docs/source/internal/trainer_utils.rst
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>

* Address review comments

* Fix TF tests

* Real fix this time

* This one should work

* Fix typo

* Really fix typo
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
parent 8fa0c956
...@@ -213,6 +213,7 @@ conversion utilities for the following models: ...@@ -213,6 +213,7 @@ conversion utilities for the following models:
:maxdepth: 2 :maxdepth: 2
:caption: Main Classes :caption: Main Classes
main_classes/callback
main_classes/configuration main_classes/configuration
main_classes/logging main_classes/logging
main_classes/model main_classes/model
...@@ -270,3 +271,4 @@ conversion utilities for the following models: ...@@ -270,3 +271,4 @@ conversion utilities for the following models:
internal/modeling_utils internal/modeling_utils
internal/pipelines_utils internal/pipelines_utils
internal/tokenization_utils internal/tokenization_utils
internal/trainer_utils
Utilities for Trainer
-----------------------------------------------------------------------------------------------------------------------
This page lists all the utility functions used by :class:`~transformers.Trainer`.
Most of those are only useful if you are studying the code of the Trainer in the library.
Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.EvalPrediction
.. autofunction:: transformers.set_seed
.. autofunction:: transformers.torch_distributed_zero_first
Callbacks internals
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.trainer_callback.CallbackHandler
Callbacks
-----------------------------------------------------------------------------------------------------------------------
Callbacks are objects that can customize the behavior of the training loop in the PyTorch
:class:`~transformers.Trainer` (this feature is not yet implemented in TensorFlow) that can inspect the training loop
state (for progress reporting, logging on TensorBoard or other ML platforms...) and take decisions (like early
stopping).
Callbacks are "read only" pieces of code, apart from the :class:`~transformers.TrainerControl` object they return, they
cannot change anything in the training loop. For customizations that require changes in the training loop, you should
subclass :class:`~transformers.Trainer` and override the methods you need (see :doc:`trainer` for examples).
By default a :class:`~transformers.Trainer` will use the following callbacks:
- :class:`~transformers.DefaultFlowCallback` which handles the default beahvior for logging, saving and evaluation.
- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProrgressCallback` to display progress and print the
logs (the first one is used if you deactivate tqdm through the :class:`~transformers.TrainingArguments`, otherwise
it's the second one).
- :class:`~transformers.integrations.TensorBoardCallback` if tensorboard is accessible (either through PyTorch >= 1.4
or tensorboardX).
- :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
- :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.
The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
Trainer's internal state via :class:`~transformers.TrainerState`, and can take some actions on the training loop via
:class:`~transformers.TrainerControl`.
Available Callbacks
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Here is the list of the available :class:`~transformers.TrainerCallback` in the library:
.. autoclass:: transformers.integrations.CometCallback
:members: setup
.. autoclass:: transformers.DefaultFlowCallback
.. autoclass:: transformers.PrinterCallback
.. autoclass:: transformers.ProgressCallback
.. autoclass:: transformers.integrations.TensorBoardCallback
.. autoclass:: transformers.integrations.WandbCallback
:members: setup
TrainerCallback
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TrainerCallback
:members:
TrainerState
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TrainerState
:members:
TrainerControl
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TrainerControl
:members:
...@@ -18,7 +18,6 @@ previous features. To inject custom behavior you can subclass them and override ...@@ -18,7 +18,6 @@ previous features. To inject custom behavior you can subclass them and override
- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaulation DataLoader (PyTorch) or TF Dataset. - **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaulation DataLoader (PyTorch) or TF Dataset.
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset. - **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
- **log** -- Logs information on the various objects watching training. - **log** -- Logs information on the various objects watching training.
- **setup_wandb** -- Setups wandb (see `here <https://docs.wandb.com/huggingface>`__ for more information).
- **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at - **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at
init. init.
- **compute_loss** - Computes the loss on a batch of training inputs. - **compute_loss** - Computes the loss on a batch of training inputs.
...@@ -40,6 +39,10 @@ Here is an example of how to customize :class:`~transformers.Trainer` using a cu ...@@ -40,6 +39,10 @@ Here is an example of how to customize :class:`~transformers.Trainer` using a cu
logits = outputs[0] logits = outputs[0]
return my_custom_loss(logits, labels) return my_custom_loss(logits, labels)
Another way to customize the training loop behavior for the PyTorch :class:`~transformers.Trainer` is to use
:doc:`callbacks <callback>` that can inspect the training loop state (for progress reporting, logging on TensorBoard or
other ML platforms...) and take decisions (like early stopping).
Trainer Trainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...@@ -47,29 +50,23 @@ Trainer ...@@ -47,29 +50,23 @@ Trainer
.. autoclass:: transformers.Trainer .. autoclass:: transformers.Trainer
:members: :members:
TFTrainer TFTrainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFTrainer .. autoclass:: transformers.TFTrainer
:members: :members:
TrainingArguments TrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TrainingArguments .. autoclass:: transformers.TrainingArguments
:members: :members:
TFTrainingArguments TFTrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFTrainingArguments .. autoclass:: transformers.TFTrainingArguments
:members: :members:
Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.EvalPrediction
.. autofunction:: transformers.set_seed
.. autofunction:: transformers.torch_distributed_zero_first
...@@ -9,7 +9,7 @@ from transformers import Trainer ...@@ -9,7 +9,7 @@ from transformers import Trainer
from transformers.configuration_fsmt import FSMTConfig from transformers.configuration_fsmt import FSMTConfig
from transformers.file_utils import is_torch_tpu_available from transformers.file_utils import is_torch_tpu_available
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
from transformers.trainer import get_tpu_sampler from transformers.trainer_pt_utils import get_tpu_sampler
try: try:
......
...@@ -4,7 +4,8 @@ import tempfile ...@@ -4,7 +4,8 @@ import tempfile
from unittest.mock import patch from unittest.mock import patch
from transformers.testing_utils import slow from transformers.testing_utils import slow
from transformers.trainer_utils import TrainerState, set_seed from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
from .finetune_trainer import main from .finetune_trainer import main
from .test_seq2seq_examples import MBART_TINY from .test_seq2seq_examples import MBART_TINY
......
...@@ -205,7 +205,15 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer ...@@ -205,7 +205,15 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
# Trainer # Trainer
from .trainer_utils import EvalPrediction, TrainerState, set_seed from .trainer_callback import (
DefaultFlowCallback,
PrinterCallback,
ProgressCallback,
TrainerCallback,
TrainerControl,
TrainerState,
)
from .trainer_utils import EvalPrediction, EvaluationStrategy, set_seed
from .training_args import TrainingArguments from .training_args import TrainingArguments
from .training_args_tf import TFTrainingArguments from .training_args_tf import TFTrainingArguments
from .utils import logging from .utils import logging
...@@ -529,7 +537,8 @@ if is_torch_available(): ...@@ -529,7 +537,8 @@ if is_torch_available():
from .tokenization_marian import MarianTokenizer from .tokenization_marian import MarianTokenizer
# Trainer # Trainer
from .trainer import EvalPrediction, Trainer, set_seed, torch_distributed_zero_first from .trainer import Trainer
from .trainer_pt_utils import torch_distributed_zero_first
else: else:
from .utils.dummy_pt_objects import * from .utils.dummy_pt_objects import *
......
...@@ -2,6 +2,11 @@ ...@@ -2,6 +2,11 @@
import math import math
import os import os
from .file_utils import is_torch_tpu_available
from .trainer_callback import TrainerCallback
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun
from .utils import logging
try: try:
import comet_ml # noqa: F401 import comet_ml # noqa: F401
...@@ -36,15 +41,6 @@ try: ...@@ -36,15 +41,6 @@ try:
except (ImportError): except (ImportError):
_has_ray = False _has_ray = False
# No ML framework or transformer imports above this point
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun # isort:skip
from .utils import logging # isort:skip
logger = logging.get_logger(__name__)
try: try:
from torch.utils.tensorboard import SummaryWriter # noqa: F401 from torch.utils.tensorboard import SummaryWriter # noqa: F401
...@@ -57,9 +53,10 @@ except ImportError: ...@@ -57,9 +53,10 @@ except ImportError:
except ImportError: except ImportError:
_has_tensorboard = False _has_tensorboard = False
# Integration functions: logger = logging.get_logger(__name__)
# Integration functions:
def is_wandb_available(): def is_wandb_available():
return _has_wandb return _has_wandb
...@@ -128,8 +125,8 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR ...@@ -128,8 +125,8 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
# The model and TensorBoard writer do not pickle so we have to remove them (if they exists) # The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
# while doing the ray hp search. # while doing the ray hp search.
_tb_writer = trainer.tb_writer
trainer.tb_writer = None _tb_writer = trainer.pop_callback(TensorBoardCallback)
trainer.model = None trainer.model = None
# Setup default `resources_per_trial` and `reporter`. # Setup default `resources_per_trial` and `reporter`.
if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0: if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
...@@ -182,5 +179,159 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR ...@@ -182,5 +179,159 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs) analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3]) best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config) best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
trainer.tb_writer = _tb_writer if _tb_writer is not None:
trainer.add_callback(_tb_writer)
return best_run return best_run
class TensorBoardCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
<https://www.tensorflow.org/tensorboard>`__.
Args:
tb_writer (:obj:`SummaryWriter`, `optional`):
The writer to use. Will instatiate one if not set.
"""
def __init__(self, tb_writer=None):
assert (
_has_tensorboard
), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
self.tb_writer = tb_writer
def on_init_end(self, args, state, control, **kwargs):
if self.tb_writer is None and state.is_world_process_zero:
self.tb_writer = SummaryWriter(log_dir=args.logging_dir)
def on_train_begin(self, args, state, control, **kwargs):
if self.tb_writer is not None:
self.tb_writer.add_text("args", args.to_json_string())
self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
def on_log(self, args, state, control, logs=None, **kwargs):
if self.tb_writer:
for k, v in logs.items():
if isinstance(v, (int, float)):
self.tb_writer.add_scalar(k, v, state.global_step)
else:
logger.warning(
"Trainer is attempting to log a value of "
'"%s" of type %s for key "%s" as a scalar. '
"This invocation of Tensorboard's writer.add_scalar() "
"is incorrect so we dropped this attribute.",
v,
type(v),
k,
)
self.tb_writer.flush()
def on_train_end(self, args, state, control, **kwargs):
if self.tb_writer:
self.tb_writer.close()
class WandbCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that sends the logs to `Weight and Biases
<https://www.wandb.com/>`__.
"""
def __init__(self):
assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
self._initialized = False
def setup(self, args, state, model):
"""
Setup the optional Weights & Biases (`wandb`) integration.
One can subclass and override this method to customize the setup if needed. Find more information
`here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
Environment:
WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
logging or :obj:`"all"` to log gradients and parameters.
WANDB_PROJECT (:obj:`str`, `optional`, defaults to :obj:`"huggingface"`):
Set this to a custom string to store results in a different project.
WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to disable wandb entirely.
"""
self._initialized = True
if state.is_world_process_zero:
logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
combined_dict = {**args.to_sanitized_dict()}
if hasattr(model, "config"):
combined_dict = {**model.config.to_dict(), **combined_dict}
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=args.run_name)
# keep track of model topology and gradients, unsupported on TPU
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
def on_train_begin(self, args, state, control, model=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
if state.is_world_process_zero:
wandb.log(logs, step=state.global_step)
class CometCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that sends the logs to `Comet ML
<https://www.comet.ml/site/>`__.
"""
def __init__(self):
assert _has_comet, "CometCallback requires comet-ml to be installed. Run `pip install comet-ml`."
self._initialized = False
def setup(self, args, state, model):
"""
Setup the optional Comet.ml integration.
Environment:
COMET_MODE (:obj:`str`, `optional`):
"OFFLINE", "ONLINE", or "DISABLED"
COMET_PROJECT_NAME (:obj:`str`, `optional`):
Comet.ml project name for experiments
COMET_OFFLINE_DIRECTORY (:obj:`str`, `optional`):
Folder to use for saving offline experiments when :obj:`COMET_MODE` is "OFFLINE"
For a number of configurable items in the environment,
see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__.
"""
self._initialized = True
if state.is_world_process_zero:
comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
experiment = None
if comet_mode == "ONLINE":
experiment = comet_ml.Experiment(**args)
logger.info("Automatic Comet.ml online logging enabled")
elif comet_mode == "OFFLINE":
args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
experiment = comet_ml.OfflineExperiment(**args)
logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
if experiment is not None:
experiment._set_model_graph(model, framework="transformers")
experiment._log_parameters(args, prefix="args/", framework="transformers")
if hasattr(model, "config"):
experiment._log_parameters(model.config, prefix="config/", framework="transformers")
def on_train_begin(self, args, state, control, model=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
if state.is_world_process_zero:
experiment = comet_ml.config.get_global_experiment()
if experiment is not None:
experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""
import inspect import inspect
import math
import os import os
import re import re
import shutil import shutil
import warnings import warnings
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
...@@ -15,8 +31,7 @@ from torch import nn ...@@ -15,8 +31,7 @@ from torch import nn
from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler from torch.utils.data.sampler import RandomSampler, SequentialSampler
from tqdm.auto import tqdm, trange
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
...@@ -34,23 +49,35 @@ from .modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING ...@@ -34,23 +49,35 @@ from .modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .modeling_utils import PreTrainedModel from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup from .optimization import AdamW, get_linear_schedule_with_warmup
from .tokenization_utils_base import PreTrainedTokenizerBase from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
PrinterCallback,
ProgressCallback,
TrainerCallback,
TrainerControl,
TrainerState,
)
from .trainer_pt_utils import (
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
get_tpu_sampler,
nested_concat,
nested_detach,
nested_numpify,
nested_xla_mesh_reduce,
reissue_pt_warnings,
)
from .trainer_utils import ( from .trainer_utils import (
PREFIX_CHECKPOINT_DIR, PREFIX_CHECKPOINT_DIR,
BestRun, BestRun,
EvalPrediction, EvalPrediction,
EvaluationStrategy,
HPSearchBackend, HPSearchBackend,
PredictionOutput, PredictionOutput,
TrainerState,
TrainOutput, TrainOutput,
default_compute_objective, default_compute_objective,
default_hp_space, default_hp_space,
distributed_broadcast_scalars,
distributed_concat,
nested_concat,
nested_detach,
nested_numpify,
nested_xla_mesh_reduce,
set_seed, set_seed,
) )
from .training_args import TrainingArguments from .training_args import TrainingArguments
...@@ -60,7 +87,8 @@ from .utils import logging ...@@ -60,7 +87,8 @@ from .utils import logging
_use_native_amp = False _use_native_amp = False
_use_apex = False _use_apex = False
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler." DEFAULT_CALLBACKS = [DefaultFlowCallback]
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex # Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"): if version.parse(torch.__version__) < version.parse("1.6"):
...@@ -82,16 +110,20 @@ if is_torch_tpu_available(): ...@@ -82,16 +110,20 @@ if is_torch_tpu_available():
import torch_xla.distributed.parallel_loader as pl import torch_xla.distributed.parallel_loader as pl
if is_tensorboard_available(): if is_tensorboard_available():
try: from .integrations import TensorBoardCallback
from torch.utils.tensorboard import SummaryWriter
except ImportError: DEFAULT_CALLBACKS.append(TensorBoardCallback)
from tensorboardX import SummaryWriter
if is_wandb_available(): if is_wandb_available():
import wandb from .integrations import WandbCallback
DEFAULT_CALLBACKS.append(WandbCallback)
if is_comet_available(): if is_comet_available():
import comet_ml from .integrations import CometCallback
DEFAULT_CALLBACKS.append(CometCallback)
if is_optuna_available(): if is_optuna_available():
import optuna import optuna
...@@ -102,91 +134,20 @@ if is_ray_available(): ...@@ -102,91 +134,20 @@ if is_ray_available():
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def reissue_pt_warnings(caught_warnings):
# Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
if len(caught_warnings) > 1:
for w in caught_warnings:
if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
warnings.warn(w.message, w.category)
@contextmanager
def torch_distributed_zero_first(local_rank: int):
"""
Decorator to make all processes in distributed training wait for each local_master to do something.
Args:
local_rank (:obj:`int`): The rank of the local process.
"""
if local_rank not in [-1, 0]:
torch.distributed.barrier()
yield
if local_rank == 0:
torch.distributed.barrier()
class SequentialDistributedSampler(Sampler):
"""
Distributed Sampler that subsamples indicies sequentially,
making it easier to collate all results at the end.
Even though we only use this sampler for eval and predict (no training),
which means that the model params won't have to be synced (i.e. will not hang
for synchronization even if varied number of forward passes), we still add extra
samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
"""
def __init__(self, dataset, num_replicas=None, rank=None):
if num_replicas is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = torch.distributed.get_world_size()
if rank is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = torch.distributed.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[: (self.total_size - len(indices))]
assert (
len(indices) == self.total_size
), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
# subsample
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
assert (
len(indices) == self.num_samples
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
return iter(indices)
def __len__(self):
return self.num_samples
def get_tpu_sampler(dataset: Dataset):
if xm.xrt_world_size() <= 1:
return RandomSampler(dataset)
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
class Trainer: class Trainer:
""" """
Trainer is a simple but feature-complete training and eval loop for PyTorch, Trainer is a simple but feature-complete training and eval loop for PyTorch,
optimized for 🤗 Transformers. optimized for 🤗 Transformers.
Args: Args:
model (:class:`~transformers.PreTrainedModel`, `optional`): model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`):
The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed. The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
.. note::
:class:`~transformers.Trainer` is optimized to work with the :class:`~transformers.PreTrainedModel`
provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as
they work the same way as the 🤗 Transformers models.
args (:class:`~transformers.TrainingArguments`, `optional`): args (:class:`~transformers.TrainingArguments`, `optional`):
The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments` The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided. with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
...@@ -210,8 +171,11 @@ class Trainer: ...@@ -210,8 +171,11 @@ class Trainer:
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`): compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
The function that will be used to compute metrics at evaluation. Must take a The function that will be used to compute metrics at evaluation. Must take a
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values. :class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
tb_writer (:obj:`SummaryWriter`, `optional`): callbacks (List of :obj:`~transformers.TrainerCallback`, `optional`):
Object to write to TensorBoard. A list of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in :doc:`here <callback>`.
If you want to remove one of the default callbacks used, use the :meth:`Trainer.remove_callback` method.
optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`): optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of A tuple containing the optimizer and the scheduler to use. Will default to an instance of
:class:`~transformers.AdamW` on your model and a scheduler given by :class:`~transformers.AdamW` on your model and a scheduler given by
...@@ -222,7 +186,7 @@ class Trainer: ...@@ -222,7 +186,7 @@ class Trainer:
def __init__( def __init__(
self, self,
model: PreTrainedModel = None, model: Union[PreTrainedModel, torch.nn.Module] = None,
args: TrainingArguments = None, args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None, data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None, train_dataset: Optional[Dataset] = None,
...@@ -230,7 +194,7 @@ class Trainer: ...@@ -230,7 +194,7 @@ class Trainer:
tokenizer: Optional["PreTrainedTokenizerBase"] = None, tokenizer: Optional["PreTrainedTokenizerBase"] = None,
model_init: Callable[[], PreTrainedModel] = None, model_init: Callable[[], PreTrainedModel] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
tb_writer: Optional["SummaryWriter"] = None, callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
**kwargs, **kwargs,
): ):
...@@ -259,7 +223,21 @@ class Trainer: ...@@ -259,7 +223,21 @@ class Trainer:
"Passing a `model_init` is incompatible with providing the `optimizers` argument." "Passing a `model_init` is incompatible with providing the `optimizers` argument."
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
) )
self.tb_writer = tb_writer callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
self.add_callback(PrinterCallback if self.args.disable_tqdm else ProgressCallback)
# Deprecated arguments
if "tb_writer" in kwargs:
warnings.warn(
"Passing `tb_writer` as a keyword argument is deprecated and won't be possible in a "
+ "future version. Use `TensorBoardCallback(tb_writer=...)` instead and pass it to the `callbacks`"
+ "argument",
FutureWarning,
)
tb_writer = kwargs.pop("tb_writer")
self.remove_callback(TensorBoardCallback)
self.add_callback(TensorBoardCallback(tb_writer=tb_writer))
if "prediction_loss_only" in kwargs: if "prediction_loss_only" in kwargs:
warnings.warn( warnings.warn(
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a " "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a "
...@@ -270,13 +248,6 @@ class Trainer: ...@@ -270,13 +248,6 @@ class Trainer:
self.args.prediction_loss_only = kwargs.pop("prediction_loss_only") self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
if not is_tensorboard_available():
logger.warning(
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
)
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`. # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
self._loggers_initialized = False self._loggers_initialized = False
...@@ -304,6 +275,7 @@ class Trainer: ...@@ -304,6 +275,7 @@ class Trainer:
self._remove_unused_columns(self.eval_dataset, description="evaluation") self._remove_unused_columns(self.eval_dataset, description="evaluation")
self.state = TrainerState() self.state = TrainerState()
self.control = TrainerControl()
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
# state at each call to self.log. # state at each call to self.log.
self._total_flos = None self._total_flos = None
...@@ -317,6 +289,45 @@ class Trainer: ...@@ -317,6 +289,45 @@ class Trainer:
else ["labels"] else ["labels"]
) )
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
def add_callback(self, callback):
"""
Add a callback to the current list of :class:`~transformer.TrainerCallback`.
Args:
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
In the first case, will instantiate a member of that class.
"""
self.callback_handler.add_callback(callback)
def pop_callback(self, callback):
"""
Remove a callback from the current list of :class:`~transformer.TrainerCallback` and returns it.
If the callback is not found, returns :obj:`None` (and no error is raised).
Args:
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
In the first case, will pop the first member of that class found in the list of callbacks.
Returns:
:class:`~transformer.TrainerCallback`: The callback removed, if found.
"""
return self.callback_handler.pop_callback(callback)
def remove_callback(self, callback):
"""
Remove a callback from the current list of :class:`~transformer.TrainerCallback`.
Args:
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
In the first case, will remove the first member of that class found in the list of callbacks.
"""
self.callback_handler.remove_callback(callback)
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns: if not self.args.remove_unused_columns:
...@@ -465,102 +476,12 @@ class Trainer: ...@@ -465,102 +476,12 @@ class Trainer:
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
) )
def setup_wandb(self):
"""
Setup the optional Weights & Biases (`wandb`) integration.
One can subclass and override this method to customize the setup if needed. Find more information
`here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
Environment:
WANDB_WATCH:
(Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
or "all" to log gradients and parameters
WANDB_PROJECT:
(Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
WANDB_DISABLED:
(Optional): boolean - defaults to false, set to "true" to disable wandb entirely
"""
if hasattr(self, "_setup_wandb"):
warnings.warn(
"The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
FutureWarning,
)
return self._setup_wandb()
if self.is_world_process_zero():
logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
combined_dict = {**self.args.to_sanitized_dict()}
if isinstance(self.model, PreTrainedModel):
combined_dict = {**self.model.config.to_dict(), **combined_dict}
wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
)
# keep track of model topology and gradients, unsupported on TPU
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
wandb.watch(
self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
)
def setup_comet(self):
"""
Setup the optional Comet.ml integration.
Environment:
COMET_MODE:
(Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
COMET_PROJECT_NAME:
(Optional): str - Comet.ml project name for experiments
COMET_OFFLINE_DIRECTORY:
(Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"
For a number of configurable items in the environment,
see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
"""
if self.is_world_master():
comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
experiment = None
if comet_mode == "ONLINE":
experiment = comet_ml.Experiment(**args)
logger.info("Automatic Comet.ml online logging enabled")
elif comet_mode == "OFFLINE":
args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
experiment = comet_ml.OfflineExperiment(**args)
logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
if experiment is not None:
experiment._set_model_graph(self.model, framework="transformers")
experiment._log_parameters(self.args, prefix="args/", framework="transformers")
if isinstance(self.model, PreTrainedModel):
experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
def num_examples(self, dataloader: DataLoader) -> int: def num_examples(self, dataloader: DataLoader) -> int:
""" """
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset. Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
""" """
return len(dataloader.dataset) return len(dataloader.dataset)
def _setup_loggers(self):
if self._loggers_initialized:
return
if is_wandb_available():
self.setup_wandb()
elif os.environ.get("WANDB_DISABLED") != "true":
logger.info(
"You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
"run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
)
if is_comet_available():
self.setup_comet()
elif os.environ.get("COMET_MODE") != "DISABLED":
logger.info(
"To use comet_ml logging, run `pip/conda install comet_ml` "
"see https://www.comet.ml/docs/python-sdk/huggingface/"
)
self._loggers_initialized = True
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]): def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
""" HP search setup code """ """ HP search setup code """
if self.hp_search_backend is None or trial is None: if self.hp_search_backend is None or trial is None:
...@@ -661,7 +582,7 @@ class Trainer: ...@@ -661,7 +582,7 @@ class Trainer:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
reissue_pt_warnings(caught_warnings) reissue_pt_warnings(caught_warnings)
# Moxed precision training with apex (torch < 1.6) # Mixed precision training with apex (torch < 1.6)
model = self.model model = self.model
if self.args.fp16 and _use_apex: if self.args.fp16 and _use_apex:
if not is_apex_available(): if not is_apex_available():
...@@ -687,10 +608,6 @@ class Trainer: ...@@ -687,10 +608,6 @@ class Trainer:
# find_unused_parameters breaks checkpointing as per # find_unused_parameters breaks checkpointing as per
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
if self.tb_writer is not None:
self.tb_writer.add_text("args", self.args.to_json_string())
self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
# Train! # Train!
if is_torch_tpu_available(): if is_torch_tpu_available():
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size() total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
...@@ -723,17 +640,25 @@ class Trainer: ...@@ -723,17 +640,25 @@ class Trainer:
logger.info(" Continuing training from global step %d", self.state.global_step) logger.info(" Continuing training from global step %d", self.state.global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
# Update the references
self.callback_handler.model = self.model
self.callback_handler.optimizer = self.optimizer
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
# This should be the same if the state has been saved but in case the training arguments changed, it's safer # This should be the same if the state has been saved but in case the training arguments changed, it's safer
# to set this after the load. # to set this after the load.
self.state.max_steps = max_steps self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
tr_loss = torch.tensor(0.0).to(self.args.device) tr_loss = torch.tensor(0.0).to(self.args.device)
self._logging_loss_scalar = 0
self._total_flos = self.state.total_flos self._total_flos = self.state.total_flos
logging_loss_scalar = 0.0
model.zero_grad() model.zero_grad()
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
train_pbar = trange(epochs_trained, num_train_epochs, desc="Epoch", disable=disable_tqdm) self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)
for epoch in range(epochs_trained, num_train_epochs): for epoch in range(epochs_trained, num_train_epochs):
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch) train_dataloader.sampler.set_epoch(epoch)
...@@ -750,15 +675,18 @@ class Trainer: ...@@ -750,15 +675,18 @@ class Trainer:
if self.args.past_index >= 0: if self.args.past_index >= 0:
self._past = None self._past = None
epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm) self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
for step, inputs in enumerate(epoch_iterator): for step, inputs in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training # Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0: if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1 steps_trained_in_current_epoch -= 1
epoch_pbar.update(1)
continue continue
if (step + 1) % self.args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
tr_loss += self.training_step(model, inputs) tr_loss += self.training_step(model, inputs)
self._total_flos += self.floating_point_ops(inputs) self._total_flos += self.floating_point_ops(inputs)
...@@ -787,50 +715,15 @@ class Trainer: ...@@ -787,50 +715,15 @@ class Trainer:
model.zero_grad() model.zero_grad()
self.state.global_step += 1 self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / len(epoch_iterator) self.state.epoch = epoch + (step + 1) / len(epoch_iterator)
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or ( if self.control.should_epoch_stop or self.control.should_training_stop:
self.state.global_step == 1 and self.args.logging_first_step
):
logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.item()
logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
# backward compatibility for pytorch schedulers
logs["learning_rate"] = (
self.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else self.lr_scheduler.get_lr()[0]
)
logging_loss_scalar = tr_loss_scalar
self.log(logs)
if (
self.args.evaluation_strategy == EvaluationStrategy.STEPS
and self.state.global_step % self.args.eval_steps == 0
):
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
if self.args.load_best_model_at_end:
self._save_training(model, trial, metrics=metrics)
if (
not self.args.load_best_model_at_end
and self.args.save_steps > 0
and self.state.global_step % self.args.save_steps == 0
):
self._save_training(model, trial)
epoch_pbar.update(1)
if self.state.global_step >= max_steps:
break break
epoch_pbar.close()
train_pbar.update(1)
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH: self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
metrics = self.evaluate() self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
self._report_to_hp_search(trial, epoch, metrics)
if self.args.load_best_model_at_end:
self._save_training(model, trial, metrics=metrics)
if self.args.tpu_metrics_debug or self.args.debug: if self.args.tpu_metrics_debug or self.args.debug:
if is_torch_tpu_available(): if is_torch_tpu_available():
...@@ -841,12 +734,9 @@ class Trainer: ...@@ -841,12 +734,9 @@ class Trainer:
"You enabled PyTorch/XLA debug metrics but you don't have a TPU " "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
"configured. Check your training configuration if this is unexpected." "configured. Check your training configuration if this is unexpected."
) )
if self.state.global_step >= max_steps: if self.control.should_training_stop:
break break
train_pbar.close()
if self.tb_writer:
self.tb_writer.close()
if self.args.past_index and hasattr(self, "_past"): if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training # Clean the state at the end of training
delattr(self, "_past") delattr(self, "_past")
...@@ -863,9 +753,36 @@ class Trainer: ...@@ -863,9 +753,36 @@ class Trainer:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)) state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
self.model.load_state_dict(state_dict) self.model.load_state_dict(state_dict)
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step) return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
def _save_training(self, model, trial, metrics=None): def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
if self.control.should_log:
logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.item()
logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / self.args.logging_steps
# backward compatibility for pytorch schedulers
logs["learning_rate"] = (
self.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else self.lr_scheduler.get_lr()[0]
)
self._logging_loss_scalar = tr_loss_scalar
self.log(logs)
metrics = None
if self.control.should_evaluate:
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
def _save_checkpoint(self, model, trial, metrics=None):
# In all cases (even distributed/parallel), self.model is always a reference # In all cases (even distributed/parallel), self.model is always a reference
# to the model we want to save. # to the model we want to save.
if hasattr(model, "module"): if hasattr(model, "module"):
...@@ -896,7 +813,7 @@ class Trainer: ...@@ -896,7 +813,7 @@ class Trainer:
reissue_pt_warnings(caught_warnings) reissue_pt_warnings(caught_warnings)
# Determine the new best metric / best model checkpoint # Determine the new best metric / best model checkpoint
if metrics is not None: if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"): if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}" metric_to_check = f"eval_{metric_to_check}"
...@@ -998,7 +915,7 @@ class Trainer: ...@@ -998,7 +915,7 @@ class Trainer:
self.hp_search_backend = None self.hp_search_backend = None
return best_run return best_run
def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None: def log(self, logs: Dict[str, float]) -> None:
""" """
Log :obj:`logs` on the various objects watching training. Log :obj:`logs` on the various objects watching training.
...@@ -1007,55 +924,22 @@ class Trainer: ...@@ -1007,55 +924,22 @@ class Trainer:
Args: Args:
logs (:obj:`Dict[str, float]`): logs (:obj:`Dict[str, float]`):
The values to log. The values to log.
iterator (:obj:`tqdm`, `optional`):
A potential tqdm progress bar to write the logs on.
""" """
# Set up loggers like W&B or Comet ML
self._setup_loggers()
if hasattr(self, "_log"): if hasattr(self, "_log"):
warnings.warn( warnings.warn(
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.", "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
FutureWarning, FutureWarning,
) )
return self._log(logs, iterator=iterator) return self._log(logs)
if self.state.epoch is not None: if self.state.epoch is not None:
logs["epoch"] = self.state.epoch logs["epoch"] = self.state.epoch
if self._total_flos is not None: if self._total_flos is not None:
self.store_flos() self.store_flos()
logs["total_flos"] = self.state.total_flos logs["total_flos"] = self.state.total_flos
if self.tb_writer: self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
for k, v in logs.items():
if isinstance(v, (int, float)):
self.tb_writer.add_scalar(k, v, self.state.global_step)
else:
logger.warning(
"Trainer is attempting to log a value of "
'"%s" of type %s for key "%s" as a scalar. '
"This invocation of Tensorboard's writer.add_scalar() "
"is incorrect so we dropped this attribute.",
v,
type(v),
k,
)
self.tb_writer.flush()
if is_wandb_available():
if self.is_world_process_zero():
wandb.log(logs, step=self.state.global_step)
if is_comet_available():
if self.is_world_process_zero():
experiment = comet_ml.config.get_global_experiment()
if experiment is not None:
experiment._log_metrics(
logs, step=self.state.global_step, epoch=self.state.epoch, framework="transformers"
)
output = {**logs, **{"step": self.state.global_step}} output = {**logs, **{"step": self.state.global_step}}
self.state.log_history.append(output) self.state.log_history.append(output)
if iterator is not None:
iterator.write(output)
else:
print(output)
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
""" """
...@@ -1372,8 +1256,9 @@ class Trainer: ...@@ -1372,8 +1256,9 @@ class Trainer:
if self.args.past_index >= 0: if self.args.past_index >= 0:
self._past = None self._past = None
disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm self.callback_handler.eval_dataloader = dataloader
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
for inputs in dataloader:
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only) loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
batch_size = inputs[list(inputs.keys())[0]].shape[0] batch_size = inputs[list(inputs.keys())[0]].shape[0]
if loss is not None: if loss is not None:
...@@ -1382,6 +1267,7 @@ class Trainer: ...@@ -1382,6 +1267,7 @@ class Trainer:
preds = logits if preds is None else nested_concat(preds, logits, dim=0) preds = logits if preds is None else nested_concat(preds, logits, dim=0)
if labels is not None: if labels is not None:
label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0) label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
if self.args.past_index and hasattr(self, "_past"): if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop # Clean the state at the end of the evaluation loop
......
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Callbacks to use with the Trainer class and customize the training loop.
"""
import dataclasses
import json
from dataclasses import dataclass
from typing import Dict, List, Optional
from tqdm.auto import tqdm
from .trainer_utils import EvaluationStrategy
from .training_args import TrainingArguments
from .utils import logging
logger = logging.get_logger(__name__)
@dataclass
class TrainerState:
"""
A class containing the :class:`~transformers.Trainer` inner state that will be saved along the model and optimizer
when checkpointing and passed to the :class:`~transformers.TrainerCallback`.
.. note::
In all this class, one step is to be understood as one update step. When using gradient accumulation, one
update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`,
then one update step requires going throuch `n` batches.
Args:
epoch (:obj:`float`, `optional`):
Only set during training, will represent the epoch the training is at (the decimal part being the
percentage of the current epoch completed).
global_step (:obj:`int`, `optional`, defaults to 0):
During training, represents the number of update steps completed.
max_steps (:obj:`int`, `optional`, defaults to 0):
The number of update steps to do during the current training.
total_flos (:obj:`int`, `optional`, defaults to 0):
The total number of floating operations done by the model since the beginning of training.
log_history (:obj:`List[Dict[str, float]]`, `optional`):
The list of logs done since the beginning of training.
best_metric (:obj:`float`, `optional`):
When tracking the best model, the value of the best metric encountered so far.
best_model_checkpoint (:obj:`str`, `optional`):
When tracking the best model, the value of the name of the checkpoint for the best model encountered so
far.
is_local_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
several machines) main process.
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not this process is the global main process (when training in a distributed fashion on
several machines, this is only going to be :obj:`True` for one process).
"""
epoch: Optional[float] = None
global_step: int = 0
max_steps: int = 0
num_train_epochs: int = 0
total_flos: int = 0
log_history: List[Dict[str, float]] = None
best_metric: Optional[float] = None
best_model_checkpoint: Optional[str] = None
is_local_process_zero: bool = True
is_world_process_zero: bool = True
def __post_init__(self):
if self.log_history is None:
self.log_history = []
def save_to_json(self, json_path: str):
""" Save the content of this instance in JSON format inside :obj:`json_path`."""
json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
@classmethod
def load_from_json(cls, json_path: str):
""" Create an instance from the content of :obj:`json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))
@dataclass
class TrainerControl:
"""
A class that handles the :class:`~transformers.Trainer` control flow. This class is used by the
:class:`~transformers.TrainerCallback` to activate some switches in the training loop.
Args:
should_training_stop (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the training should be interrupted.
If :obj:`True`, this variable will not be set back to :obj:`False`. The training will just stop.
should_epoch_stop (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the current epoch should be interrupted.
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next epoch.
should_save (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should be saved at this step.
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
should_evaluate (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should be evaluated at this step.
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
should_log (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the logs should be reported at this step.
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
"""
should_training_stop: bool = False
should_epoch_stop: bool = False
should_save: bool = False
should_evaluate: bool = False
should_log: bool = False
def _new_training(self):
""" Internal method that resets the variable for a new training. """
self.should_training_stop = False
def _new_epoch(self):
""" Internal method that resets the variable for a new epoch. """
self.should_epoch_stop = False
def _new_step(self):
""" Internal method that resets the variable for a new step. """
self.should_save_model = False
self.should_evaluate = False
self.should_log = False
class TrainerCallback:
"""
A class for objects that will inspect the state of the training loop at some events and take some decisions. At
each of those events the following arguments are available:
Args:
args (:class:`~transformers.TrainingArguments`):
The training arguments used to instantiate the :class:`~transformers.Trainer`.
state (:class:`~transformers.TrainerState`):
The current state of the :class:`~transformers.Trainer`.
control (:class:`~transformers.TrainerControl`):
The object that is returned to the :class:`~transformers.Trainer` and can be used to make some decisions.
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`):
The model being trained.
optimizer (:obj:`torch.optim.Optimizer`):
The optimizer used for the training steps.
lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
The scheduler used for setting the learning rate.
train_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
The current dataloader used for training.
eval_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
The current dataloader used for training.
metrics (:obj:`Dict[str, float]`):
The metrics computed by the last evaluation phase.
Those are only accessible in the event :obj:`on_evaluate`.
logs (:obj:`Dict[str, float]`):
The values to log.
Those are only accessible in the event :obj:`on_log`.
The :obj:`control` object is the only one that can be changed by the callback, in which case the event that changes
it should return the modified version.
The argument :obj:`args`, :obj:`state` and :obj:`control` are positionals for all events, all the others are
grouped in :obj:`kwargs`. You can unpack the ones you need in the signature of the event using them. As an example,
see the code of the simple :class:`~transformer.PrinterCallback`.
Example::
class PrinterCallback(TrainerCallback):
def on_log(self, args, state, control, logs=None, **kwargs):
_ = logs.pop("total_flos", None)
if state.is_local_process_zero:
print(logs)
"""
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the end of the initialization of the :class:`~transformers.Trainer`.
"""
pass
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the beginning of training.
"""
pass
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the end of training.
"""
pass
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the beginning of an epoch.
"""
pass
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the end of an epoch.
"""
pass
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the beginning of a training step. If using gradient accumulation, one training step might take
several inputs.
"""
pass
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the end of a training step. If using gradient accumulation, one training step might take
several inputs.
"""
pass
def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called after an evaluation phase.
"""
pass
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called after a checkpoint save.
"""
pass
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called after logging the last logs.
"""
pass
def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called after a prediction step.
"""
pass
class CallbackHandler(TrainerCallback):
""" Internal class that just calls the list of callbacks in order. """
def __init__(self, callbacks, model, optimizer, lr_scheduler):
self.callbacks = []
for cb in callbacks:
self.add_callback(cb)
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.train_dataloader = None
self.eval_dataloader = None
if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks):
logger.warn(
"The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n"
+ "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of"
+ "callbacks is\n:"
+ self.callback_list
)
def add_callback(self, callback):
cb = callback() if isinstance(callback, type) else callback
cb_class = callback if isinstance(callback, type) else callback.__class__
if cb_class in [c.__class__ for c in self.callbacks]:
logger.warn(
f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current"
+ "list of callbacks is\n:"
+ self.callback_list
)
self.callbacks.append(cb)
def pop_callback(self, callback):
if isinstance(callback, type):
for cb in self.callbacks:
if isinstance(cb, callback):
self.callbacks.remove(cb)
return cb
else:
for cb in self.callbacks:
if cb == callback:
self.callbacks.remove(cb)
return cb
def remove_callback(self, callback):
if isinstance(callback, type):
for cb in self.callbacks:
if isinstance(cb, callback):
self.callbacks.remove(cb)
return
else:
self.callbacks.remove(callback)
@property
def callback_list(self):
return "\n".join(self.callbacks)
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_init_end", args, state, control)
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
control.should_training_stop = False
return self.call_event("on_train_begin", args, state, control)
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_train_end", args, state, control)
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
control.should_epoch_stop = False
return self.call_event("on_epoch_begin", args, state, control)
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_epoch_end", args, state, control)
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
control.should_log = False
control.should_evaluate = False
control.should_save = False
return self.call_event("on_step_begin", args, state, control)
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_step_end", args, state, control)
def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
control.should_evaluate = False
return self.call_event("on_evaluate", args, state, control, metrics=metrics)
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
control.should_save = False
return self.call_event("on_save", args, state, control)
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs):
control.should_log = False
return self.call_event("on_log", args, state, control, logs=logs)
def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_prediction_step", args, state, control)
def call_event(self, event, args, state, control, **kwargs):
for callback in self.callbacks:
result = getattr(callback, event)(
args,
state,
control,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
train_dataloader=self.train_dataloader,
eval_dataloader=self.eval_dataloader,
**kwargs,
)
# A Callback can skip the return of `control` if it doesn't change it.
if result is not None:
control = result
return control
class DefaultFlowCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that handles the default flow of the training loop for logs, evaluation
and checkpoints.
"""
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
# Log
if state.global_step == 1 and args.logging_first_step:
control.should_log = True
if args.logging_steps > 0 and state.global_step % args.logging_steps == 0:
control.should_log = True
# Evaluate
if args.evaluation_strategy == EvaluationStrategy.STEPS and state.global_step % args.eval_steps == 0:
control.should_evaluate = True
if args.load_best_model_at_end:
control.should_save = True
# Save
if not args.load_best_model_at_end and args.save_steps > 0 and state.global_step % args.save_steps == 0:
control.should_save = True
# End training
if state.global_step >= state.max_steps:
control.should_training_stop = True
return control
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if args.evaluation_strategy == EvaluationStrategy.EPOCH:
control.should_evaluate = True
if args.load_best_model_at_end:
control.should_save = True
return control
class ProgressCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that displays the progress of training or evaluation.
"""
def __init__(self):
self.training_bar = None
self.prediction_bar = None
def on_train_begin(self, args, state, control, **kwargs):
if state.is_local_process_zero:
self.training_bar = tqdm(total=state.max_steps)
def on_step_end(self, args, state, control, **kwargs):
if state.is_local_process_zero:
self.training_bar.update(1)
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
if state.is_local_process_zero:
if self.prediction_bar is None:
self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
self.prediction_bar.update(1)
def on_evaluate(self, args, state, control, **kwargs):
if state.is_local_process_zero:
self.prediction_bar.close()
self.prediction_bar = None
def on_log(self, args, state, control, logs=None, **kwargs):
if state.is_local_process_zero and self.training_bar is not None:
_ = logs.pop("total_flos", None)
self.training_bar.write(str(logs))
def on_train_end(self, args, state, control, **kwargs):
if state.is_local_process_zero:
self.training_bar.close()
self.training_bar = None
class PrinterCallback(TrainerCallback):
"""
A bare :class:`~transformers.TrainerCallback` that just prints the logs.
"""
def on_log(self, args, state, control, logs=None, **kwargs):
_ = logs.pop("total_flos", None)
if state.is_local_process_zero:
print(logs)
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Torch utilities for the Trainer class.
"""
import math
import warnings
from contextlib import contextmanager
from typing import List, Optional, Union
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler
from .file_utils import is_torch_tpu_available
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
def nested_concat(tensors, new_tensors, dim=0):
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
assert type(tensors) == type(
new_tensors
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
return torch.cat((tensors, new_tensors), dim=dim)
def nested_numpify(tensors):
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_numpify(t) for t in tensors)
return tensors.cpu().numpy()
def nested_detach(tensors):
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t) for t in tensors)
return tensors.detach()
def nested_xla_mesh_reduce(tensors, name):
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
return xm.mesh_reduce(name, tensors, torch.cat)
else:
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> torch.Tensor:
try:
if isinstance(tensor, (tuple, list)):
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
if num_total_examples is not None:
concat = concat[:num_total_examples]
return concat
except AssertionError:
raise AssertionError("Not currently using distributed training")
def distributed_broadcast_scalars(
scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
) -> torch.Tensor:
try:
tensorized_scalar = torch.tensor(scalars).cuda()
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensorized_scalar)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
if num_total_examples is not None:
concat = concat[:num_total_examples]
return concat
except AssertionError:
raise AssertionError("Not currently using distributed training")
def reissue_pt_warnings(caught_warnings):
# Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
if len(caught_warnings) > 1:
for w in caught_warnings:
if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
warnings.warn(w.message, w.category)
@contextmanager
def torch_distributed_zero_first(local_rank: int):
"""
Decorator to make all processes in distributed training wait for each local_master to do something.
Args:
local_rank (:obj:`int`): The rank of the local process.
"""
if local_rank not in [-1, 0]:
torch.distributed.barrier()
yield
if local_rank == 0:
torch.distributed.barrier()
class SequentialDistributedSampler(Sampler):
"""
Distributed Sampler that subsamples indicies sequentially,
making it easier to collate all results at the end.
Even though we only use this sampler for eval and predict (no training),
which means that the model params won't have to be synced (i.e. will not hang
for synchronization even if varied number of forward passes), we still add extra
samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
"""
def __init__(self, dataset, num_replicas=None, rank=None):
if num_replicas is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = torch.distributed.get_world_size()
if rank is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = torch.distributed.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[: (self.total_size - len(indices))]
assert (
len(indices) == self.total_size
), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
# subsample
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
assert (
len(indices) == self.num_samples
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
return iter(indices)
def __len__(self):
return self.num_samples
def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
if xm.xrt_world_size() <= 1:
return RandomSampler(dataset)
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
import dataclasses # coding=utf-8
import json # Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch and TensorFlow.
"""
import random import random
from dataclasses import dataclass from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
import numpy as np import numpy as np
from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available from .file_utils import is_tf_available, is_torch_available
from .tokenization_utils_base import ExplicitEnum from .tokenization_utils_base import ExplicitEnum
if is_torch_available():
import torch
def set_seed(seed: int): def set_seed(seed: int):
""" """
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
...@@ -139,144 +150,3 @@ default_hp_space = { ...@@ -139,144 +150,3 @@ default_hp_space = {
HPSearchBackend.OPTUNA: default_hp_space_optuna, HPSearchBackend.OPTUNA: default_hp_space_optuna,
HPSearchBackend.RAY: default_hp_space_ray, HPSearchBackend.RAY: default_hp_space_ray,
} }
def nested_concat(tensors, new_tensors, dim=0):
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
if is_torch_available():
assert type(tensors) == type(
new_tensors
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
return torch.cat((tensors, new_tensors), dim=dim)
else:
raise ImportError("Torch must be installed to use `nested_concat`")
def nested_deatch(tensors):
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t) for t in tensors)
return tensors.detach()
def nested_numpify(tensors):
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_numpify(t) for t in tensors)
return tensors.cpu().numpy()
def nested_detach(tensors):
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t) for t in tensors)
return tensors.detach()
def nested_xla_mesh_reduce(tensors, name):
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
return xm.mesh_reduce(name, tensors, torch.cat)
else:
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor":
if is_torch_available():
try:
if isinstance(tensor, (tuple, list)):
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
if num_total_examples is not None:
concat = concat[:num_total_examples]
return concat
except AssertionError:
raise AssertionError("Not currently using distributed training")
else:
raise ImportError("Torch must be installed to use `distributed_concat`")
def distributed_broadcast_scalars(
scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
) -> "torch.Tensor":
if is_torch_available():
try:
tensorized_scalar = torch.tensor(scalars).cuda()
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensorized_scalar)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
if num_total_examples is not None:
concat = concat[:num_total_examples]
return concat
except AssertionError:
raise AssertionError("Not currently using distributed training")
else:
raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")
@dataclass
class TrainerState:
"""
A class containing the `Trainer` inner state that will be saved along the model and optimizer.
.. note::
In all this class, one step is to be understood as one update step. When using gradient accumulation, one
update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`,
then one update step requires going throuch `n` batches.
Args:
epoch (:obj:`float`, `optional`):
Only set during training, will represent the epoch the training is at (the decimal part being the
percentage of the current epoch completed).
global_step (:obj:`int`, `optional`, defaults to 0):
During training, represents the number of update steps completed.
max_steps (:obj:`int`, `optional`, defaults to 0):
The number of update steps to do during the current training.
total_flos (:obj:`int`, `optional`, defaults to 0):
The total number of floating operations done by the model since the beginning of training.
log_history (:obj:`List[Dict[str, float]]`, `optional`):
The list of logs done since the beginning of training.
best_metric (:obj:`float`, `optional`):
When tracking the best model, the value of the best metric encountered so far.
best_model_checkpoint (:obj:`str`, `optional`):
When tracking the best model, the value of the name of the checkpoint for the best model encountered so
far.
"""
epoch: Optional[float] = None
global_step: int = 0
max_steps: int = 0
num_train_epochs: int = 0
total_flos: int = 0
log_history: List[Dict[str, float]] = None
best_metric: Optional[float] = None
best_model_checkpoint: Optional[str] = None
def __post_init__(self):
if self.log_history is None:
self.log_history = []
def save_to_json(self, json_path: str):
""" Save the content of this instance in JSON format inside :obj:`json_path`."""
json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
@classmethod
def load_from_json(cls, json_path: str):
""" Create an instance from the content of :obj:`json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))
...@@ -54,7 +54,7 @@ class TrainingArguments: ...@@ -54,7 +54,7 @@ class TrainingArguments:
:obj:`"no"`. :obj:`"no"`.
do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`): do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to run predictions on the test set or not. Whether to run predictions on the test set or not.
evaluation_strategy(:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`): evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
The evaluation strategy to adopt during training. Possible values are: The evaluation strategy to adopt during training. Possible values are:
* :obj:`"no"`: No evaluation is done during training. * :obj:`"no"`: No evaluation is done during training.
......
...@@ -1869,19 +1869,10 @@ class MarianTokenizer: ...@@ -1869,19 +1869,10 @@ class MarianTokenizer:
requires_pytorch(self) requires_pytorch(self)
class EvalPrediction:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class Trainer: class Trainer:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_pytorch(self) requires_pytorch(self)
def set_seed(*args, **kwargs):
requires_pytorch(set_seed)
def torch_distributed_zero_first(*args, **kwargs): def torch_distributed_zero_first(*args, **kwargs):
requires_pytorch(torch_distributed_zero_first) requires_pytorch(torch_distributed_zero_first)
import shutil
import tempfile
import unittest
from transformers import (
DefaultFlowCallback,
EvaluationStrategy,
PrinterCallback,
ProgressCallback,
Trainer,
TrainerCallback,
TrainingArguments,
is_torch_available,
)
from transformers.testing_utils import require_torch
if is_torch_available():
from transformers.trainer import DEFAULT_CALLBACKS
from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
class TestTrainerCallback(TrainerCallback):
"A callback that registers the events that goes through."
def __init__(self):
self.events = []
def on_init_end(self, args, state, control, **kwargs):
self.events.append("on_init_end")
def on_train_begin(self, args, state, control, **kwargs):
self.events.append("on_train_begin")
def on_train_end(self, args, state, control, **kwargs):
self.events.append("on_train_end")
def on_epoch_begin(self, args, state, control, **kwargs):
self.events.append("on_epoch_begin")
def on_epoch_end(self, args, state, control, **kwargs):
self.events.append("on_epoch_end")
def on_step_begin(self, args, state, control, **kwargs):
self.events.append("on_step_begin")
def on_step_end(self, args, state, control, **kwargs):
self.events.append("on_step_end")
def on_evaluate(self, args, state, control, **kwargs):
self.events.append("on_evaluate")
def on_save(self, args, state, control, **kwargs):
self.events.append("on_save")
def on_log(self, args, state, control, **kwargs):
self.events.append("on_log")
def on_prediction_step(self, args, state, control, **kwargs):
self.events.append("on_prediction_step")
@require_torch
class TrainerCallbackTest(unittest.TestCase):
def setUp(self):
self.output_dir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.output_dir)
def get_trainer(self, a=0, b=0, train_len=64, eval_len=64, callbacks=None, disable_tqdm=False, **kwargs):
# disable_tqdm in TrainingArguments has a flaky default since it depends on the level of logging. We make sure
# its set to False since the tests later on depend on its value.
train_dataset = RegressionDataset(length=train_len)
eval_dataset = RegressionDataset(length=eval_len)
config = RegressionModelConfig(a=a, b=b)
model = RegressionPreTrainedModel(config)
args = TrainingArguments(self.output_dir, disable_tqdm=disable_tqdm, **kwargs)
return Trainer(
model,
args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=callbacks,
)
def check_callbacks_equality(self, cbs1, cbs2):
self.assertEqual(len(cbs1), len(cbs2))
# Order doesn't matter
cbs1 = list(sorted(cbs1, key=lambda cb: cb.__name__ if isinstance(cb, type) else cb.__class__.__name__))
cbs2 = list(sorted(cbs2, key=lambda cb: cb.__name__ if isinstance(cb, type) else cb.__class__.__name__))
for cb1, cb2 in zip(cbs1, cbs2):
if isinstance(cb1, type) and isinstance(cb2, type):
self.assertEqual(cb1, cb2)
elif isinstance(cb1, type) and not isinstance(cb2, type):
self.assertEqual(cb1, cb2.__class__)
elif not isinstance(cb1, type) and isinstance(cb2, type):
self.assertEqual(cb1.__class__, cb2)
else:
self.assertEqual(cb1, cb2)
def get_expected_events(self, trainer):
expected_events = ["on_init_end", "on_train_begin"]
step = 0
train_dl_len = len(trainer.get_eval_dataloader())
evaluation_events = ["on_prediction_step"] * len(trainer.get_eval_dataloader()) + ["on_log", "on_evaluate"]
for _ in range(trainer.state.num_train_epochs):
expected_events.append("on_epoch_begin")
for _ in range(train_dl_len):
step += 1
expected_events += ["on_step_begin", "on_step_end"]
if step % trainer.args.logging_steps == 0:
expected_events.append("on_log")
if (
trainer.args.evaluation_strategy == EvaluationStrategy.STEPS
and step % trainer.args.eval_steps == 0
):
expected_events += evaluation_events.copy()
if step % trainer.args.save_steps == 0:
expected_events.append("on_save")
expected_events.append("on_epoch_end")
if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH:
expected_events += evaluation_events.copy()
expected_events.append("on_train_end")
return expected_events
def test_init_callback(self):
trainer = self.get_trainer()
expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
# Callbacks passed at init are added to the default callbacks
trainer = self.get_trainer(callbacks=[TestTrainerCallback])
expected_callbacks.append(TestTrainerCallback)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
# TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback
trainer = self.get_trainer(disable_tqdm=True)
expected_callbacks = DEFAULT_CALLBACKS.copy() + [PrinterCallback]
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
def test_add_remove_callback(self):
expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
trainer = self.get_trainer()
# We can add, pop, or remove by class name
trainer.remove_callback(DefaultFlowCallback)
expected_callbacks.remove(DefaultFlowCallback)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
trainer = self.get_trainer()
cb = trainer.pop_callback(DefaultFlowCallback)
self.assertEqual(cb.__class__, DefaultFlowCallback)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
trainer.add_callback(DefaultFlowCallback)
expected_callbacks.insert(0, DefaultFlowCallback)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
# We can also add, pop, or remove by instance
trainer = self.get_trainer()
cb = trainer.callback_handler.callbacks[0]
trainer.remove_callback(cb)
expected_callbacks.remove(DefaultFlowCallback)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
trainer = self.get_trainer()
cb1 = trainer.callback_handler.callbacks[0]
cb2 = trainer.pop_callback(cb1)
self.assertEqual(cb1, cb2)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
trainer.add_callback(cb1)
expected_callbacks.insert(0, DefaultFlowCallback)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
def test_event_flow(self):
trainer = self.get_trainer(callbacks=[TestTrainerCallback])
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
# Independent log/save/eval
trainer = self.get_trainer(callbacks=[TestTrainerCallback], logging_steps=5)
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
trainer = self.get_trainer(callbacks=[TestTrainerCallback], save_steps=5)
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
trainer = self.get_trainer(callbacks=[TestTrainerCallback], eval_steps=5, evaluation_strategy="steps")
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
trainer = self.get_trainer(callbacks=[TestTrainerCallback], evaluation_strategy="epoch")
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
# A bit of everything
trainer = self.get_trainer(
callbacks=[TestTrainerCallback], logging_steps=3, save_steps=10, eval_steps=5, evaluation_strategy="steps"
)
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment