Unverified Commit 08ba4b49 authored by Sylvain Gugger, committed by GitHub

Trainer callbacks (#7596)



* Initial callback proposal

* Finish various callbacks

* Post-rebase conflicts

* Fix tests

* Don't use something that's not set

* Documentation

* Remove unwanted print.

* Document all models can work

* Add tests + small fixes

* Update docs/source/internal/trainer_utils.rst
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Address review comments

* Fix TF tests

* Real fix this time

* This one should work

* Fix typo

* Really fix typo
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 8fa0c956
@@ -213,6 +213,7 @@ conversion utilities for the following models:
:maxdepth: 2
:caption: Main Classes
main_classes/callback
main_classes/configuration
main_classes/logging
main_classes/model
@@ -270,3 +271,4 @@ conversion utilities for the following models:
internal/modeling_utils
internal/pipelines_utils
internal/tokenization_utils
internal/trainer_utils
Utilities for Trainer
-----------------------------------------------------------------------------------------------------------------------
This page lists all the utility functions used by :class:`~transformers.Trainer`.
Most of those are only useful if you are studying the code of the Trainer in the library.
Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.EvalPrediction
.. autofunction:: transformers.set_seed
.. autofunction:: transformers.torch_distributed_zero_first
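A minimal sketch of how these utilities are typically combined (the :obj:`local_rank` value and the dataset-loading
helper are illustrative, not part of the library):

.. code-block:: python

    from transformers import set_seed
    from transformers.trainer_pt_utils import torch_distributed_zero_first

    set_seed(42)  # seeds python, numpy and torch (and tf when it is available)

    local_rank = -1  # -1 means no distributed training; use the actual rank in distributed runs
    with torch_distributed_zero_first(local_rank):
        # Only the local main process runs this first, the other processes wait at the barrier.
        dataset = load_and_cache_dataset()  # hypothetical helper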
Callbacks internals
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.trainer_callback.CallbackHandler
Callbacks
-----------------------------------------------------------------------------------------------------------------------
Callbacks are objects that can customize the behavior of the training loop in the PyTorch
:class:`~transformers.Trainer` (this feature is not yet implemented in TensorFlow). They can inspect the training loop
state (for progress reporting, logging to TensorBoard or other ML platforms...) and make decisions (like early
stopping).
Callbacks are "read only" pieces of code, apart from the :class:`~transformers.TrainerControl` object they return, they
cannot change anything in the training loop. For customizations that require changes in the training loop, you should
subclass :class:`~transformers.Trainer` and override the methods you need (see :doc:`trainer` for examples).
By default a :class:`~transformers.Trainer` will use the following callbacks (a short sketch of how to add or remove
them is given after the list):
- :class:`~transformers.DefaultFlowCallback` which handles the default behavior for logging, saving and evaluation.
- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProgressCallback` to display progress and print the
logs (the first one is used if you deactivate tqdm through the :class:`~transformers.TrainingArguments`, otherwise
it's the second one).
- :class:`~transformers.integrations.TensorBoardCallback` if tensorboard is accessible (either through PyTorch >= 1.4
or tensorboardX).
- :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
- :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.
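To change this default set, :class:`~transformers.Trainer` exposes the :obj:`add_callback`, :obj:`pop_callback` and
:obj:`remove_callback` methods, which accept either a :class:`~transformers.TrainerCallback` class or an instance of
one. A minimal sketch (the model, training arguments and dataset are assumed to already exist):

.. code-block:: python

    from transformers import PrinterCallback, Trainer
    from transformers.integrations import TensorBoardCallback

    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)

    # Drop the TensorBoard integration, keeping the instance around in case we want to add it back.
    tb_callback = trainer.pop_callback(TensorBoardCallback)

    # Register an extra callback (a class or an instance both work).
    trainer.add_callback(PrinterCallback)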
The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
Trainer's internal state via :class:`~transformers.TrainerState`, and can take some actions on the training loop via
:class:`~transformers.TrainerControl`.
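As an illustration, here is a minimal sketch of a custom callback that interrupts training after a given number of
update steps (the class name and its :obj:`stop_step` argument are made up for this example; only the API described
above is assumed):

.. code-block:: python

    from transformers import TrainerCallback


    class StopAfterStepCallback(TrainerCallback):
        "A hypothetical callback that stops training once `stop_step` update steps have been done."

        def __init__(self, stop_step=100):
            self.stop_step = stop_step

        def on_step_end(self, args, state, control, **kwargs):
            if state.global_step >= self.stop_step:
                # `control` is the only object a callback may modify; returning it propagates the change.
                control.should_training_stop = True
            return control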
Available Callbacks
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Here is the list of the available :class:`~transformers.TrainerCallback` in the library:
.. autoclass:: transformers.integrations.CometCallback
:members: setup
.. autoclass:: transformers.DefaultFlowCallback
.. autoclass:: transformers.PrinterCallback
.. autoclass:: transformers.ProgressCallback
.. autoclass:: transformers.integrations.TensorBoardCallback
.. autoclass:: transformers.integrations.WandbCallback
:members: setup
TrainerCallback
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TrainerCallback
:members:
TrainerState
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TrainerState
:members:
TrainerControl
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TrainerControl
:members:
@@ -18,7 +18,6 @@ previous features. To inject custom behavior you can subclass them and override
- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaluation DataLoader (PyTorch) or TF Dataset.
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
- **log** -- Logs information on the various objects watching training.
- **setup_wandb** -- Sets up wandb (see `here <https://docs.wandb.com/huggingface>`__ for more information).
- **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at
init.
- **compute_loss** - Computes the loss on a batch of training inputs.
@@ -40,6 +39,10 @@ Here is an example of how to customize :class:`~transformers.Trainer` using a cu
logits = outputs[0]
return my_custom_loss(logits, labels)
Another way to customize the training loop behavior for the PyTorch :class:`~transformers.Trainer` is to use
:doc:`callbacks <callback>` that can inspect the training loop state (for progress reporting, logging to TensorBoard or
other ML platforms...) and make decisions (like early stopping).
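A callback can be passed when the :class:`~transformers.Trainer` is instantiated or registered afterwards. A minimal
sketch, assuming a :obj:`MyCallback` subclass of :class:`~transformers.TrainerCallback` and an existing model, training
arguments and datasets:

.. code-block:: python

    from transformers import Trainer

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        callbacks=[MyCallback],  # classes and instances are both accepted
    )

    # or, equivalently, after the Trainer has been created:
    trainer.add_callback(MyCallback())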
Trainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -47,29 +50,23 @@ Trainer
.. autoclass:: transformers.Trainer
:members:
TFTrainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFTrainer
:members:
TrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TrainingArguments
:members:
TFTrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFTrainingArguments
:members:
Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.EvalPrediction
.. autofunction:: transformers.set_seed
.. autofunction:: transformers.torch_distributed_zero_first
@@ -9,7 +9,7 @@ from transformers import Trainer
from transformers.configuration_fsmt import FSMTConfig
from transformers.file_utils import is_torch_tpu_available
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
from transformers.trainer import get_tpu_sampler
from transformers.trainer_pt_utils import get_tpu_sampler
try:
......
@@ -4,7 +4,8 @@ import tempfile
from unittest.mock import patch
from transformers.testing_utils import slow
from transformers.trainer_utils import TrainerState, set_seed
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
from .finetune_trainer import main
from .test_seq2seq_examples import MBART_TINY
......
@@ -205,7 +205,15 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
# Trainer
from .trainer_utils import EvalPrediction, TrainerState, set_seed
from .trainer_callback import (
DefaultFlowCallback,
PrinterCallback,
ProgressCallback,
TrainerCallback,
TrainerControl,
TrainerState,
)
from .trainer_utils import EvalPrediction, EvaluationStrategy, set_seed
from .training_args import TrainingArguments
from .training_args_tf import TFTrainingArguments
from .utils import logging
@@ -529,7 +537,8 @@ if is_torch_available():
from .tokenization_marian import MarianTokenizer
# Trainer
from .trainer import EvalPrediction, Trainer, set_seed, torch_distributed_zero_first
from .trainer import Trainer
from .trainer_pt_utils import torch_distributed_zero_first
else:
from .utils.dummy_pt_objects import *
......
@@ -2,6 +2,11 @@
import math
import os
from .file_utils import is_torch_tpu_available
from .trainer_callback import TrainerCallback
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun
from .utils import logging
try:
import comet_ml # noqa: F401
@@ -36,15 +41,6 @@ try:
except (ImportError):
_has_ray = False
# No ML framework or transformer imports above this point
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun # isort:skip
from .utils import logging # isort:skip
logger = logging.get_logger(__name__)
try:
from torch.utils.tensorboard import SummaryWriter # noqa: F401
@@ -57,9 +53,10 @@ except ImportError:
except ImportError:
_has_tensorboard = False
# Integration functions:
logger = logging.get_logger(__name__)
# Integration functions:
def is_wandb_available():
return _has_wandb
@@ -128,8 +125,8 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
# The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
# while doing the ray hp search.
_tb_writer = trainer.tb_writer
trainer.tb_writer = None
_tb_writer = trainer.pop_callback(TensorBoardCallback)
trainer.model = None
# Setup default `resources_per_trial` and `reporter`.
if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
@@ -182,5 +179,159 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
trainer.tb_writer = _tb_writer
if _tb_writer is not None:
trainer.add_callback(_tb_writer)
return best_run
class TensorBoardCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
<https://www.tensorflow.org/tensorboard>`__.
Args:
tb_writer (:obj:`SummaryWriter`, `optional`):
The writer to use. Will instantiate one if not set.
"""
def __init__(self, tb_writer=None):
assert (
_has_tensorboard
), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
self.tb_writer = tb_writer
def on_init_end(self, args, state, control, **kwargs):
if self.tb_writer is None and state.is_world_process_zero:
self.tb_writer = SummaryWriter(log_dir=args.logging_dir)
def on_train_begin(self, args, state, control, **kwargs):
if self.tb_writer is not None:
self.tb_writer.add_text("args", args.to_json_string())
self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
def on_log(self, args, state, control, logs=None, **kwargs):
if self.tb_writer:
for k, v in logs.items():
if isinstance(v, (int, float)):
self.tb_writer.add_scalar(k, v, state.global_step)
else:
logger.warning(
"Trainer is attempting to log a value of "
'"%s" of type %s for key "%s" as a scalar. '
"This invocation of Tensorboard's writer.add_scalar() "
"is incorrect so we dropped this attribute.",
v,
type(v),
k,
)
self.tb_writer.flush()
def on_train_end(self, args, state, control, **kwargs):
if self.tb_writer:
self.tb_writer.close()
class WandbCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that sends the logs to `Weights & Biases
<https://www.wandb.com/>`__.
"""
def __init__(self):
assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
self._initialized = False
def setup(self, args, state, model):
"""
Set up the optional Weights & Biases (`wandb`) integration.
One can subclass and override this method to customize the setup if needed. Find more information
`here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
Environment:
WANDB_WATCH (:obj:`str`, `optional`, defaults to :obj:`"gradients"`):
Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
logging or :obj:`"all"` to log gradients and parameters.
WANDB_PROJECT (:obj:`str`, `optional`, defaults to :obj:`"huggingface"`):
Set this to a custom string to store results in a different project.
WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to disable wandb entirely.
"""
self._initialized = True
if state.is_world_process_zero:
logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
combined_dict = {**args.to_sanitized_dict()}
if hasattr(model, "config"):
combined_dict = {**model.config.to_dict(), **combined_dict}
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=args.run_name)
# keep track of model topology and gradients, unsupported on TPU
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
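# A configuration sketch (values are illustrative): the environment variables documented in `setup` above only
# take effect if they are set before training starts, e.g.
#   os.environ["WANDB_PROJECT"] = "my-project"   # log to a custom W&B project
#   os.environ["WANDB_WATCH"] = "false"          # disable gradient logging
#   os.environ["WANDB_DISABLED"] = "true"        # opt out of the integration entirely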
def on_train_begin(self, args, state, control, model=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
if state.is_world_process_zero:
wandb.log(logs, step=state.global_step)
class CometCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that sends the logs to `Comet ML
<https://www.comet.ml/site/>`__.
"""
def __init__(self):
assert _has_comet, "CometCallback requires comet-ml to be installed. Run `pip install comet-ml`."
self._initialized = False
def setup(self, args, state, model):
"""
Set up the optional Comet.ml integration.
Environment:
COMET_MODE (:obj:`str`, `optional`):
"OFFLINE", "ONLINE", or "DISABLED"
COMET_PROJECT_NAME (:obj:`str`, `optional`):
Comet.ml project name for experiments
COMET_OFFLINE_DIRECTORY (:obj:`str`, `optional`):
Folder to use for saving offline experiments when :obj:`COMET_MODE` is "OFFLINE"
For a number of configurable items in the environment,
see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__.
"""
self._initialized = True
if state.is_world_process_zero:
comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
experiment = None
if comet_mode == "ONLINE":
experiment = comet_ml.Experiment(**args)
logger.info("Automatic Comet.ml online logging enabled")
elif comet_mode == "OFFLINE":
args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
experiment = comet_ml.OfflineExperiment(**args)
logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
if experiment is not None:
experiment._set_model_graph(model, framework="transformers")
experiment._log_parameters(args, prefix="args/", framework="transformers")
if hasattr(model, "config"):
experiment._log_parameters(model.config, prefix="config/", framework="transformers")
def on_train_begin(self, args, state, control, model=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
if state.is_world_process_zero:
experiment = comet_ml.config.get_global_experiment()
if experiment is not None:
experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Callbacks to use with the Trainer class and customize the training loop.
"""
import dataclasses
import json
from dataclasses import dataclass
from typing import Dict, List, Optional
from tqdm.auto import tqdm
from .trainer_utils import EvaluationStrategy
from .training_args import TrainingArguments
from .utils import logging
logger = logging.get_logger(__name__)
@dataclass
class TrainerState:
"""
A class containing the :class:`~transformers.Trainer` inner state that will be saved along the model and optimizer
when checkpointing and passed to the :class:`~transformers.TrainerCallback`.
.. note::
In all this class, one step is to be understood as one update step. When using gradient accumulation, one
update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`,
then one update step requires going through `n` batches.
Args:
epoch (:obj:`float`, `optional`):
Only set during training, will represent the epoch the training is at (the decimal part being the
percentage of the current epoch completed).
global_step (:obj:`int`, `optional`, defaults to 0):
During training, represents the number of update steps completed.
max_steps (:obj:`int`, `optional`, defaults to 0):
The number of update steps to do during the current training.
total_flos (:obj:`int`, `optional`, defaults to 0):
The total number of floating operations done by the model since the beginning of training.
log_history (:obj:`List[Dict[str, float]]`, `optional`):
The list of logs done since the beginning of training.
best_metric (:obj:`float`, `optional`):
When tracking the best model, the value of the best metric encountered so far.
best_model_checkpoint (:obj:`str`, `optional`):
When tracking the best model, the name of the checkpoint for the best model encountered so far.
is_local_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
several machines) main process.
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not this process is the global main process (when training in a distributed fashion on
several machines, this is only going to be :obj:`True` for one process).
"""
epoch: Optional[float] = None
global_step: int = 0
max_steps: int = 0
num_train_epochs: int = 0
total_flos: int = 0
log_history: List[Dict[str, float]] = None
best_metric: Optional[float] = None
best_model_checkpoint: Optional[str] = None
is_local_process_zero: bool = True
is_world_process_zero: bool = True
def __post_init__(self):
if self.log_history is None:
self.log_history = []
def save_to_json(self, json_path: str):
""" Save the content of this instance in JSON format inside :obj:`json_path`."""
json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
@classmethod
def load_from_json(cls, json_path: str):
""" Create an instance from the content of :obj:`json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))
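# Minimal usage sketch for the two JSON helpers above (the path is illustrative):
#   state = TrainerState(global_step=10)
#   state.save_to_json("trainer_state.json")
#   state = TrainerState.load_from_json("trainer_state.json")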
@dataclass
class TrainerControl:
"""
A class that handles the :class:`~transformers.Trainer` control flow. This class is used by the
:class:`~transformers.TrainerCallback` to activate some switches in the training loop.
Args:
should_training_stop (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the training should be interrupted.
If :obj:`True`, this variable will not be set back to :obj:`False`. The training will just stop.
should_epoch_stop (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the current epoch should be interrupted.
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next epoch.
should_save (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should be saved at this step.
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
should_evaluate (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should be evaluated at this step.
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
should_log (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the logs should be reported at this step.
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
"""
should_training_stop: bool = False
should_epoch_stop: bool = False
should_save: bool = False
should_evaluate: bool = False
should_log: bool = False
def _new_training(self):
""" Internal method that resets the variable for a new training. """
self.should_training_stop = False
def _new_epoch(self):
""" Internal method that resets the variable for a new epoch. """
self.should_epoch_stop = False
def _new_step(self):
""" Internal method that resets the variable for a new step. """
self.should_save = False
self.should_evaluate = False
self.should_log = False
class TrainerCallback:
"""
A class for objects that will inspect the state of the training loop at some events and make some decisions. At
each of those events the following arguments are available:
Args:
args (:class:`~transformers.TrainingArguments`):
The training arguments used to instantiate the :class:`~transformers.Trainer`.
state (:class:`~transformers.TrainerState`):
The current state of the :class:`~transformers.Trainer`.
control (:class:`~transformers.TrainerControl`):
The object that is returned to the :class:`~transformers.Trainer` and can be used to make some decisions.
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`):
The model being trained.
optimizer (:obj:`torch.optim.Optimizer`):
The optimizer used for the training steps.
lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
The scheduler used for setting the learning rate.
train_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
The current dataloader used for training.
eval_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
The current dataloader used for evaluation.
metrics (:obj:`Dict[str, float]`):
The metrics computed by the last evaluation phase.
Those are only accessible in the event :obj:`on_evaluate`.
logs (:obj:`Dict[str, float]`):
The values to log.
Those are only accessible in the event :obj:`on_log`.
The :obj:`control` object is the only one that can be changed by the callback, in which case the event that changes
it should return the modified version.
The arguments :obj:`args`, :obj:`state` and :obj:`control` are positional for all events; all the others are
grouped in :obj:`kwargs`. You can unpack the ones you need in the signature of the event using them. As an example,
see the code of the simple :class:`~transformers.PrinterCallback`.
Example::
class PrinterCallback(TrainerCallback):
def on_log(self, args, state, control, logs=None, **kwargs):
_ = logs.pop("total_flos", None)
if state.is_local_process_zero:
print(logs)
"""
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the end of the initialization of the :class:`~transformers.Trainer`.
"""
pass
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the beginning of training.
"""
pass
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the end of training.
"""
pass
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the beginning of an epoch.
"""
pass
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the end of an epoch.
"""
pass
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the beginning of a training step. If using gradient accumulation, one training step might take
several inputs.
"""
pass
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the end of a training step. If using gradient accumulation, one training step might take
several inputs.
"""
pass
def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called after an evaluation phase.
"""
pass
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called after a checkpoint save.
"""
pass
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called after logging the last logs.
"""
pass
def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called after a prediction step.
"""
pass
class CallbackHandler(TrainerCallback):
""" Internal class that just calls the list of callbacks in order. """
def __init__(self, callbacks, model, optimizer, lr_scheduler):
self.callbacks = []
for cb in callbacks:
self.add_callback(cb)
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.train_dataloader = None
self.eval_dataloader = None
if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks):
logger.warn(
"The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n"
+ "should add one before training with `trainer.add_callback(DefaultFlowCallback)`. The current list of "
+ "callbacks is:\n"
+ self.callback_list
)
def add_callback(self, callback):
cb = callback() if isinstance(callback, type) else callback
cb_class = callback if isinstance(callback, type) else callback.__class__
if cb_class in [c.__class__ for c in self.callbacks]:
logger.warn(
f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current "
+ "list of callbacks is:\n"
+ self.callback_list
)
self.callbacks.append(cb)
def pop_callback(self, callback):
if isinstance(callback, type):
for cb in self.callbacks:
if isinstance(cb, callback):
self.callbacks.remove(cb)
return cb
else:
for cb in self.callbacks:
if cb == callback:
self.callbacks.remove(cb)
return cb
def remove_callback(self, callback):
if isinstance(callback, type):
for cb in self.callbacks:
if isinstance(cb, callback):
self.callbacks.remove(cb)
return
else:
self.callbacks.remove(callback)
@property
def callback_list(self):
return "\n".join(self.callbacks)
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_init_end", args, state, control)
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
control.should_training_stop = False
return self.call_event("on_train_begin", args, state, control)
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_train_end", args, state, control)
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
control.should_epoch_stop = False
return self.call_event("on_epoch_begin", args, state, control)
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_epoch_end", args, state, control)
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
control.should_log = False
control.should_evaluate = False
control.should_save = False
return self.call_event("on_step_begin", args, state, control)
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_step_end", args, state, control)
def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
control.should_evaluate = False
return self.call_event("on_evaluate", args, state, control, metrics=metrics)
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
control.should_save = False
return self.call_event("on_save", args, state, control)
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs):
control.should_log = False
return self.call_event("on_log", args, state, control, logs=logs)
def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_prediction_step", args, state, control)
def call_event(self, event, args, state, control, **kwargs):
for callback in self.callbacks:
result = getattr(callback, event)(
args,
state,
control,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
train_dataloader=self.train_dataloader,
eval_dataloader=self.eval_dataloader,
**kwargs,
)
# A Callback can skip the return of `control` if it doesn't change it.
if result is not None:
control = result
return control
class DefaultFlowCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that handles the default flow of the training loop for logs, evaluation
and checkpoints.
"""
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
# Log
if state.global_step == 1 and args.logging_first_step:
control.should_log = True
if args.logging_steps > 0 and state.global_step % args.logging_steps == 0:
control.should_log = True
# Evaluate
if args.evaluation_strategy == EvaluationStrategy.STEPS and state.global_step % args.eval_steps == 0:
control.should_evaluate = True
if args.load_best_model_at_end:
control.should_save = True
# Save
if not args.load_best_model_at_end and args.save_steps > 0 and state.global_step % args.save_steps == 0:
control.should_save = True
# End training
if state.global_step >= state.max_steps:
control.should_training_stop = True
return control
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if args.evaluation_strategy == EvaluationStrategy.EPOCH:
control.should_evaluate = True
if args.load_best_model_at_end:
control.should_save = True
return control
class ProgressCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that displays the progress of training or evaluation.
"""
def __init__(self):
self.training_bar = None
self.prediction_bar = None
def on_train_begin(self, args, state, control, **kwargs):
if state.is_local_process_zero:
self.training_bar = tqdm(total=state.max_steps)
def on_step_end(self, args, state, control, **kwargs):
if state.is_local_process_zero:
self.training_bar.update(1)
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
if state.is_local_process_zero:
if self.prediction_bar is None:
self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
self.prediction_bar.update(1)
def on_evaluate(self, args, state, control, **kwargs):
if state.is_local_process_zero:
self.prediction_bar.close()
self.prediction_bar = None
def on_log(self, args, state, control, logs=None, **kwargs):
if state.is_local_process_zero and self.training_bar is not None:
_ = logs.pop("total_flos", None)
self.training_bar.write(str(logs))
def on_train_end(self, args, state, control, **kwargs):
if state.is_local_process_zero:
self.training_bar.close()
self.training_bar = None
class PrinterCallback(TrainerCallback):
"""
A bare :class:`~transformers.TrainerCallback` that just prints the logs.
"""
def on_log(self, args, state, control, logs=None, **kwargs):
_ = logs.pop("total_flos", None)
if state.is_local_process_zero:
print(logs)
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Torch utilities for the Trainer class.
"""
import math
import warnings
from contextlib import contextmanager
from typing import List, Optional, Union
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler
from .file_utils import is_torch_tpu_available
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
def nested_concat(tensors, new_tensors, dim=0):
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
assert type(tensors) == type(
new_tensors
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
return torch.cat((tensors, new_tensors), dim=dim)
def nested_numpify(tensors):
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_numpify(t) for t in tensors)
return tensors.cpu().numpy()
def nested_detach(tensors):
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t) for t in tensors)
return tensors.detach()
def nested_xla_mesh_reduce(tensors, name):
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
return xm.mesh_reduce(name, tensors, torch.cat)
else:
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> torch.Tensor:
try:
if isinstance(tensor, (tuple, list)):
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
if num_total_examples is not None:
concat = concat[:num_total_examples]
return concat
except AssertionError:
raise AssertionError("Not currently using distributed training")
def distributed_broadcast_scalars(
scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
) -> torch.Tensor:
try:
tensorized_scalar = torch.tensor(scalars).cuda()
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensorized_scalar)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
if num_total_examples is not None:
concat = concat[:num_total_examples]
return concat
except AssertionError:
raise AssertionError("Not currently using distributed training")
def reissue_pt_warnings(caught_warnings):
# Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
if len(caught_warnings) > 1:
for w in caught_warnings:
if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
warnings.warn(w.message, w.category)
@contextmanager
def torch_distributed_zero_first(local_rank: int):
"""
Context manager to make all processes in distributed training wait for the local main process to do something first.
Args:
local_rank (:obj:`int`): The rank of the local process.
"""
if local_rank not in [-1, 0]:
torch.distributed.barrier()
yield
if local_rank == 0:
torch.distributed.barrier()
class SequentialDistributedSampler(Sampler):
"""
Distributed Sampler that subsamples indices sequentially,
making it easier to collate all results at the end.
Even though we only use this sampler for eval and predict (no training),
which means that the model params won't have to be synced (i.e. will not hang
for synchronization even if varied number of forward passes), we still add extra
samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
"""
def __init__(self, dataset, num_replicas=None, rank=None):
if num_replicas is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = torch.distributed.get_world_size()
if rank is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = torch.distributed.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[: (self.total_size - len(indices))]
assert (
len(indices) == self.total_size
), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
# subsample
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
assert (
len(indices) == self.num_samples
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
return iter(indices)
def __len__(self):
return self.num_samples
def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
if xm.xrt_world_size() <= 1:
return RandomSampler(dataset)
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
import dataclasses
import json
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch and TensorFlow.
"""
import random
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
import numpy as np
from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
from .file_utils import is_tf_available, is_torch_available
from .tokenization_utils_base import ExplicitEnum
if is_torch_available():
import torch
def set_seed(seed: int):
"""
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
@@ -139,144 +150,3 @@ default_hp_space = {
HPSearchBackend.OPTUNA: default_hp_space_optuna,
HPSearchBackend.RAY: default_hp_space_ray,
}
def nested_concat(tensors, new_tensors, dim=0):
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
if is_torch_available():
assert type(tensors) == type(
new_tensors
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
return torch.cat((tensors, new_tensors), dim=dim)
else:
raise ImportError("Torch must be installed to use `nested_concat`")
def nested_deatch(tensors):
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t) for t in tensors)
return tensors.detach()
def nested_numpify(tensors):
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_numpify(t) for t in tensors)
return tensors.cpu().numpy()
def nested_detach(tensors):
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t) for t in tensors)
return tensors.detach()
def nested_xla_mesh_reduce(tensors, name):
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
return xm.mesh_reduce(name, tensors, torch.cat)
else:
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor":
if is_torch_available():
try:
if isinstance(tensor, (tuple, list)):
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
if num_total_examples is not None:
concat = concat[:num_total_examples]
return concat
except AssertionError:
raise AssertionError("Not currently using distributed training")
else:
raise ImportError("Torch must be installed to use `distributed_concat`")
def distributed_broadcast_scalars(
scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
) -> "torch.Tensor":
if is_torch_available():
try:
tensorized_scalar = torch.tensor(scalars).cuda()
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensorized_scalar)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
if num_total_examples is not None:
concat = concat[:num_total_examples]
return concat
except AssertionError:
raise AssertionError("Not currently using distributed training")
else:
raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")
@dataclass
class TrainerState:
"""
A class containing the `Trainer` inner state that will be saved along the model and optimizer.
.. note::
In all this class, one step is to be understood as one update step. When using gradient accumulation, one
update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`,
then one update step requires going through `n` batches.
Args:
epoch (:obj:`float`, `optional`):
Only set during training, will represent the epoch the training is at (the decimal part being the
percentage of the current epoch completed).
global_step (:obj:`int`, `optional`, defaults to 0):
During training, represents the number of update steps completed.
max_steps (:obj:`int`, `optional`, defaults to 0):
The number of update steps to do during the current training.
total_flos (:obj:`int`, `optional`, defaults to 0):
The total number of floating operations done by the model since the beginning of training.
log_history (:obj:`List[Dict[str, float]]`, `optional`):
The list of logs done since the beginning of training.
best_metric (:obj:`float`, `optional`):
When tracking the best model, the value of the best metric encountered so far.
best_model_checkpoint (:obj:`str`, `optional`):
When tracking the best model, the value of the name of the checkpoint for the best model encountered so
far.
"""
epoch: Optional[float] = None
global_step: int = 0
max_steps: int = 0
num_train_epochs: int = 0
total_flos: int = 0
log_history: List[Dict[str, float]] = None
best_metric: Optional[float] = None
best_model_checkpoint: Optional[str] = None
def __post_init__(self):
if self.log_history is None:
self.log_history = []
def save_to_json(self, json_path: str):
""" Save the content of this instance in JSON format inside :obj:`json_path`."""
json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
@classmethod
def load_from_json(cls, json_path: str):
""" Create an instance from the content of :obj:`json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))
@@ -54,7 +54,7 @@ class TrainingArguments:
:obj:`"no"`.
do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to run predictions on the test set or not.
evaluation_strategy(:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
The evaluation strategy to adopt during training. Possible values are:
* :obj:`"no"`: No evaluation is done during training.
......
@@ -1869,19 +1869,10 @@ class MarianTokenizer:
requires_pytorch(self)
class EvalPrediction:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class Trainer:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
def set_seed(*args, **kwargs):
requires_pytorch(set_seed)
def torch_distributed_zero_first(*args, **kwargs):
requires_pytorch(torch_distributed_zero_first)
import shutil
import tempfile
import unittest
from transformers import (
DefaultFlowCallback,
EvaluationStrategy,
PrinterCallback,
ProgressCallback,
Trainer,
TrainerCallback,
TrainingArguments,
is_torch_available,
)
from transformers.testing_utils import require_torch
if is_torch_available():
from transformers.trainer import DEFAULT_CALLBACKS
from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
class TestTrainerCallback(TrainerCallback):
"A callback that registers the events that goes through."
def __init__(self):
self.events = []
def on_init_end(self, args, state, control, **kwargs):
self.events.append("on_init_end")
def on_train_begin(self, args, state, control, **kwargs):
self.events.append("on_train_begin")
def on_train_end(self, args, state, control, **kwargs):
self.events.append("on_train_end")
def on_epoch_begin(self, args, state, control, **kwargs):
self.events.append("on_epoch_begin")
def on_epoch_end(self, args, state, control, **kwargs):
self.events.append("on_epoch_end")
def on_step_begin(self, args, state, control, **kwargs):
self.events.append("on_step_begin")
def on_step_end(self, args, state, control, **kwargs):
self.events.append("on_step_end")
def on_evaluate(self, args, state, control, **kwargs):
self.events.append("on_evaluate")
def on_save(self, args, state, control, **kwargs):
self.events.append("on_save")
def on_log(self, args, state, control, **kwargs):
self.events.append("on_log")
def on_prediction_step(self, args, state, control, **kwargs):
self.events.append("on_prediction_step")
@require_torch
class TrainerCallbackTest(unittest.TestCase):
def setUp(self):
self.output_dir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.output_dir)
def get_trainer(self, a=0, b=0, train_len=64, eval_len=64, callbacks=None, disable_tqdm=False, **kwargs):
# disable_tqdm in TrainingArguments has a flaky default since it depends on the level of logging. We make sure
# it's set to False since the tests later on depend on its value.
train_dataset = RegressionDataset(length=train_len)
eval_dataset = RegressionDataset(length=eval_len)
config = RegressionModelConfig(a=a, b=b)
model = RegressionPreTrainedModel(config)
args = TrainingArguments(self.output_dir, disable_tqdm=disable_tqdm, **kwargs)
return Trainer(
model,
args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=callbacks,
)
def check_callbacks_equality(self, cbs1, cbs2):
self.assertEqual(len(cbs1), len(cbs2))
# Order doesn't matter
cbs1 = list(sorted(cbs1, key=lambda cb: cb.__name__ if isinstance(cb, type) else cb.__class__.__name__))
cbs2 = list(sorted(cbs2, key=lambda cb: cb.__name__ if isinstance(cb, type) else cb.__class__.__name__))
for cb1, cb2 in zip(cbs1, cbs2):
if isinstance(cb1, type) and isinstance(cb2, type):
self.assertEqual(cb1, cb2)
elif isinstance(cb1, type) and not isinstance(cb2, type):
self.assertEqual(cb1, cb2.__class__)
elif not isinstance(cb1, type) and isinstance(cb2, type):
self.assertEqual(cb1.__class__, cb2)
else:
self.assertEqual(cb1, cb2)
def get_expected_events(self, trainer):
expected_events = ["on_init_end", "on_train_begin"]
step = 0
train_dl_len = len(trainer.get_train_dataloader())
evaluation_events = ["on_prediction_step"] * len(trainer.get_eval_dataloader()) + ["on_log", "on_evaluate"]
for _ in range(trainer.state.num_train_epochs):
expected_events.append("on_epoch_begin")
for _ in range(train_dl_len):
step += 1
expected_events += ["on_step_begin", "on_step_end"]
if step % trainer.args.logging_steps == 0:
expected_events.append("on_log")
if (
trainer.args.evaluation_strategy == EvaluationStrategy.STEPS
and step % trainer.args.eval_steps == 0
):
expected_events += evaluation_events.copy()
if step % trainer.args.save_steps == 0:
expected_events.append("on_save")
expected_events.append("on_epoch_end")
if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH:
expected_events += evaluation_events.copy()
expected_events.append("on_train_end")
return expected_events
def test_init_callback(self):
trainer = self.get_trainer()
expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
# Callbacks passed at init are added to the default callbacks
trainer = self.get_trainer(callbacks=[TestTrainerCallback])
expected_callbacks.append(TestTrainerCallback)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
# TrainingArguments.disable_tqdm controls whether we use ProgressCallback or PrinterCallback
trainer = self.get_trainer(disable_tqdm=True)
expected_callbacks = DEFAULT_CALLBACKS.copy() + [PrinterCallback]
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
def test_add_remove_callback(self):
expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
trainer = self.get_trainer()
# We can add, pop, or remove by class name
trainer.remove_callback(DefaultFlowCallback)
expected_callbacks.remove(DefaultFlowCallback)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
trainer = self.get_trainer()
cb = trainer.pop_callback(DefaultFlowCallback)
self.assertEqual(cb.__class__, DefaultFlowCallback)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
trainer.add_callback(DefaultFlowCallback)
expected_callbacks.insert(0, DefaultFlowCallback)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
# We can also add, pop, or remove by instance
trainer = self.get_trainer()
cb = trainer.callback_handler.callbacks[0]
trainer.remove_callback(cb)
expected_callbacks.remove(DefaultFlowCallback)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
trainer = self.get_trainer()
cb1 = trainer.callback_handler.callbacks[0]
cb2 = trainer.pop_callback(cb1)
self.assertEqual(cb1, cb2)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
trainer.add_callback(cb1)
expected_callbacks.insert(0, DefaultFlowCallback)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
def test_event_flow(self):
trainer = self.get_trainer(callbacks=[TestTrainerCallback])
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
# Independent log/save/eval
trainer = self.get_trainer(callbacks=[TestTrainerCallback], logging_steps=5)
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
trainer = self.get_trainer(callbacks=[TestTrainerCallback], save_steps=5)
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
trainer = self.get_trainer(callbacks=[TestTrainerCallback], eval_steps=5, evaluation_strategy="steps")
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
trainer = self.get_trainer(callbacks=[TestTrainerCallback], evaluation_strategy="epoch")
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))
# A bit of everything
trainer = self.get_trainer(
callbacks=[TestTrainerCallback], logging_steps=3, save_steps=10, eval_steps=5, evaluation_strategy="steps"
)
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))