"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1f9dcfc1ef1f2f21b29e61ea4dd4e440bd7c7963"
Unverified Commit ad697f18 authored by Zach Mueller, committed by GitHub

Introduce Stateful Callbacks (#29666)



* Introduce saveable callbacks

* Add note

* Test for non-present and flag

* Support early stopping and refusing to train further

* Update docstring

* More saving

* Import oopsie

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Make it go through TrainerArguments

* Document

* Fix test

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Rework to allow for duplicates

* Clean

* Fix failing tests

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 86f25697
@@ -78,6 +78,7 @@ from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
     DefaultFlowCallback,
+    ExportableState,
     PrinterCallback,
     ProgressCallback,
     TrainerCallback,
@@ -649,12 +650,15 @@ class Trainer:
         else:
             self.label_smoother = None
+        self.control = TrainerControl()
         self.state = TrainerState(
             is_local_process_zero=self.is_local_process_zero(),
             is_world_process_zero=self.is_world_process_zero(),
+            stateful_callbacks=[
+                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+            ],
         )
-        self.control = TrainerControl()
         # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
         # returned to 0 every time flos need to be logged
         self.current_flos = 0
@@ -1499,6 +1503,8 @@ class Trainer:
         output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
         self.save_model(output_dir, _internal_call=True)
         if self.args.should_save:
+            # Update the `TrainerControl` state to where we are currently
+            self.state.stateful_callbacks["TrainerControl"] = self.control.state()
             self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
             torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
@@ -1970,7 +1976,11 @@ class Trainer:
         if not delay_optimizer_creation:
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)

-        self.state = TrainerState()
+        self.state = TrainerState(
+            stateful_callbacks=[
+                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+            ]
+        )
         self.state.is_hyper_param_search = trial is not None
         self.state.train_batch_size = self._train_batch_size
@@ -2079,6 +2089,7 @@ class Trainer:
         ):
             self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
             self.compare_trainer_and_checkpoint_args(self.args, self.state)
+            self._load_callback_state()
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
             if not args.ignore_data_skip:
                 steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
@@ -2786,6 +2797,8 @@ class Trainer:
         # Save the Trainer state
         if self.args.should_save:
+            # Update the `TrainerControl` state to where we are currently
+            self.state.stateful_callbacks["TrainerControl"] = self.control.state()
             self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

         if self.args.push_to_hub:
@@ -2970,6 +2983,45 @@ class Trainer:
                 self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
             reissue_pt_warnings(caught_warnings)

+    def _load_callback_state(self):
+        """If callback states exist and were passed in, restore their states if enabled"""
+        if not self.args.restore_callback_states_from_checkpoint:
+            return
+        # Callback states are stored in stateful_callbacks
+        not_found = []
+        new_callbacks = []
+        original_callbacks = self.callback_handler.callbacks + [self.control]
+        for stored_callback, data in self.state.stateful_callbacks.items():
+            if not isinstance(data, list):
+                data = [data]
+            if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks):
+                # We can load/restore from multiple callbacks of the same type.
+                duplicates = [
+                    callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
+                ]
+                for callback, callback_data in zip(duplicates, data):
+                    args = callback_data.get("args", {})
+                    attributes = callback_data.get("attributes", {})
+                    new_callback = type(callback)(**args)
+                    for attribute, value in attributes.items():
+                        setattr(new_callback, attribute, value)
+                    if isinstance(callback, TrainerControl):
+                        # Specifically for restoring the `control` state
+                        self.control = new_callback
+                    else:
+                        new_callbacks.append(new_callback)
+                    # We remove the existing callback and add it to the list of new callbacks
+                    self.callback_handler.remove_callback(type(new_callback))
+                logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in")
+            else:
+                not_found.append(stored_callback)
+        if len(not_found) > 0:
+            logger.warning(
+                f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})"
+            )
+        for callback in new_callbacks:
+            self.callback_handler.add_callback(callback)
+
     def hyperparameter_search(
         self,
         hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
......
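To make the restore path above concrete: `_load_callback_state` re-instantiates each saved callback from its `"args"` and then copies the saved `"attributes"` onto the new instance. A minimal sketch of that recipe outside of `Trainer`, assuming a payload shaped like the one `EarlyStoppingCallback.state()` produces (see the callback changes below):

```python
from transformers import EarlyStoppingCallback

# Hypothetical payload, shaped like one entry of `TrainerState.stateful_callbacks`.
saved = {
    "args": {"early_stopping_patience": 5, "early_stopping_threshold": 0.2},
    "attributes": {"early_stopping_patience_counter": 1},
}

# Same recipe as `_load_callback_state`: re-init from "args", then restore "attributes".
restored = EarlyStoppingCallback(**saved["args"])
for attribute, value in saved["attributes"].items():
    setattr(restored, attribute, value)

assert restored.early_stopping_patience == 5
assert restored.early_stopping_patience_counter == 1
```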
@@ -84,6 +84,9 @@ class TrainerState:
         is_hyper_param_search (`bool`, *optional*, defaults to `False`):
             Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will
             impact the way data will be logged in TensorBoard.
+        stateful_callbacks (`List[StatefulTrainerCallback]`, *optional*):
+            Callbacks attached to the `Trainer` that should have their states saved or restored.
+            Relevant callbacks should implement `state` and `from_state` functions.
     """

     epoch: Optional[float] = None
@@ -104,10 +107,34 @@ class TrainerState:
     is_hyper_param_search: bool = False
     trial_name: str = None
     trial_params: Dict[str, Union[str, float, int, bool]] = None
+    stateful_callbacks: List["TrainerCallback"] = None

     def __post_init__(self):
         if self.log_history is None:
             self.log_history = []
+        if self.stateful_callbacks is None:
+            self.stateful_callbacks = {}
+        elif isinstance(self.stateful_callbacks, dict):
+            # We are loading the callbacks in from the state file, no need to process them
+            pass
+        else:
+            # Saveable callbacks get stored as dict of kwargs
+            stateful_callbacks = {}
+            for callback in self.stateful_callbacks:
+                if not isinstance(callback, (ExportableState)):
+                    raise TypeError(
+                        f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}"
+                    )
+                name = callback.__class__.__name__
+                if name in stateful_callbacks:
+                    # We can have multiple versions of the same callback
+                    # if so, we store them as a list of states to restore
+                    if not isinstance(stateful_callbacks[name], list):
+                        stateful_callbacks[name] = [stateful_callbacks[name]]
+                    stateful_callbacks[name].append(callback.state())
+                else:
+                    stateful_callbacks[name] = callback.state()
+            self.stateful_callbacks = stateful_callbacks

     def save_to_json(self, json_path: str):
         """Save the content of this instance in JSON format inside `json_path`."""
@@ -123,8 +150,52 @@ class TrainerState:
         return cls(**json.loads(text))
+
+
+class ExportableState:
+    """
+    A class for objects that include the ability to have their state
+    be saved during `Trainer._save_checkpoint` and loaded back in during
+    `Trainer._load_from_checkpoint`.
+
+    These must implement a `state` function that gets called during the respective
+    Trainer function call. It should only include parameters and attributes needed to
+    recreate the state at a particular time, avoiding pickle and keeping to standard
+    file I/O.
+
+    Example:
+
+    ```python
+    class EarlyStoppingCallback(TrainerCallback, ExportableState):
+        def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
+            self.early_stopping_patience = early_stopping_patience
+            self.early_stopping_threshold = early_stopping_threshold
+            # early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
+            self.early_stopping_patience_counter = 0
+
+        def state(self) -> dict:
+            return {
+                "args": {
+                    "early_stopping_patience": self.early_stopping_patience,
+                    "early_stopping_threshold": self.early_stopping_threshold,
+                },
+                "attributes": {
+                    "early_stopping_patience_counter": self.early_stopping_patience_counter,
+                },
+            }
+    ```"""
+
+    def state(self) -> dict:
+        raise NotImplementedError("You must implement a `state` function to utilize this class.")
+
+    @classmethod
+    def from_state(cls, state):
+        instance = cls(**state["args"])
+        for k, v in state["attributes"].items():
+            setattr(instance, k, v)
+        return instance
+
 @dataclass
-class TrainerControl:
+class TrainerControl(ExportableState):
     """
     A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
     switches in the training loop.
@@ -172,6 +243,18 @@ class TrainerControl:
         self.should_evaluate = False
         self.should_log = False

+    def state(self) -> dict:
+        return {
+            "args": {
+                "should_training_stop": self.should_training_stop,
+                "should_epoch_stop": self.should_epoch_stop,
+                "should_save": self.should_save,
+                "should_evaluate": self.should_evaluate,
+                "should_log": self.should_log,
+            },
+            "attributes": {},
+        }
+

 class TrainerCallback:
     # no-format
@@ -546,7 +629,7 @@ class PrinterCallback(TrainerCallback):
             print(logs)


-class EarlyStoppingCallback(TrainerCallback):
+class EarlyStoppingCallback(TrainerCallback, ExportableState):
     """
     A [`TrainerCallback`] that handles early stopping.

@@ -605,3 +688,14 @@ class EarlyStoppingCallback(TrainerCallback):
         self.check_metric_value(args, state, control, metric_value)
         if self.early_stopping_patience_counter >= self.early_stopping_patience:
             control.should_training_stop = True
+
+    def state(self) -> dict:
+        return {
+            "args": {
+                "early_stopping_patience": self.early_stopping_patience,
+                "early_stopping_threshold": self.early_stopping_threshold,
+            },
+            "attributes": {
+                "early_stopping_patience_counter": self.early_stopping_patience_counter,
+            },
+        }
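Taken together, `TrainerState.__post_init__` flattens any `ExportableState` callbacks it is given into a JSON-serializable dict keyed by class name, which is what ends up in `trainer_state.json`. A small sketch against the classes in this diff (the values in the comments are what the `state()` implementations above return):

```python
from transformers import EarlyStoppingCallback
from transformers.trainer_callback import TrainerControl, TrainerState

state = TrainerState(
    stateful_callbacks=[
        EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2),
        TrainerControl(),
    ]
)

# __post_init__ replaced the callback objects with their exported states.
print(state.stateful_callbacks["EarlyStoppingCallback"]["args"])
# {'early_stopping_patience': 5, 'early_stopping_threshold': 0.2}
print(state.stateful_callbacks["TrainerControl"]["args"]["should_training_stop"])
# False
```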
@@ -357,6 +357,9 @@ class TrainingArguments:
             Note that when this is true, you won't be able to resume training from checkpoint.
             This enables you to save storage by not storing the optimizer, scheduler & rng state.
             You can only load the model using `from_pretrained` with this option set to `True`.
+        restore_callback_states_from_checkpoint (`bool`, *optional*, defaults to `False`):
+            Whether to restore the callback states from the checkpoint. If `True`, will override
+            callbacks passed to the `Trainer` if they exist in the checkpoint.
         use_cpu (`bool`, *optional*, defaults to `False`):
             Whether or not to use cpu. If set to False, we will use cuda or mps device if available.
         seed (`int`, *optional*, defaults to 42):
@@ -951,6 +954,12 @@ class TrainingArguments:
             )
         },
     )
+    restore_callback_states_from_checkpoint: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to restore the callback states from the checkpoint. If `True`, will override callbacks passed to the `Trainer` if they exist in the checkpoint."
+        },
+    )
     no_cuda: bool = field(
         default=False,
         metadata={"help": "This argument is deprecated. It will be removed in version 5.0 of 🤗 Transformers."},
......
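On the arguments side, the only new knob is this boolean flag; everything else in the snippet mirrors settings already used in the tests below. A hedged sketch (the `output_dir` value is a placeholder):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",  # placeholder path
    save_strategy="steps",
    save_steps=2,
    eval_strategy="steps",
    eval_steps=2,
    load_best_model_at_end=True,
    restore_callback_states_from_checkpoint=True,  # new in this PR, defaults to False
)
print(args.restore_callback_states_from_checkpoint)  # True
```

Passed to a `Trainer` together with `callbacks=[EarlyStoppingCallback(...)]` and `trainer.train(resume_from_checkpoint=...)`, the callback's constructor arguments and patience counter are then restored from the checkpoint rather than taken from whatever was passed in.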
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import shutil
 import tempfile
 import unittest
@@ -19,28 +21,44 @@ from unittest.mock import patch
 from transformers import (
     DefaultFlowCallback,
+    EarlyStoppingCallback,
     IntervalStrategy,
     PrinterCallback,
     ProgressCallback,
     Trainer,
     TrainerCallback,
+    TrainerState,
     TrainingArguments,
     is_torch_available,
 )
 from transformers.testing_utils import require_torch
+from transformers.trainer_callback import ExportableState


 if is_torch_available():
-    from transformers.trainer import DEFAULT_CALLBACKS
+    from transformers.trainer import DEFAULT_CALLBACKS, TRAINER_STATE_NAME

     from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
+
+class MyTestExportableCallback(TrainerCallback, ExportableState):
+    def __init__(self, my_test_state="test"):
+        self.my_test_state = my_test_state
+
+    def state(self):
+        return {
+            "args": {
+                "my_test_state": self.my_test_state,
+            },
+        }
+
+
 class MyTestTrainerCallback(TrainerCallback):
     "A callback that registers the events that goes through."

-    def __init__(self):
+    def __init__(self, my_test_state="test"):
         self.events = []
+        self.my_test_state = my_test_state

     def on_init_end(self, args, state, control, **kwargs):
         self.events.append("on_init_end")
@@ -243,3 +261,160 @@ class TrainerCallbackTest(unittest.TestCase):
                 callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
             )
             assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]
+
+    def test_stateful_callbacks(self):
+        # Use something with non-defaults
+        cb = EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2)
+        trainer = self.get_trainer(
+            callbacks=[cb],
+            load_best_model_at_end=True,
+            save_strategy="steps",
+            eval_strategy="steps",
+            save_steps=2,
+            eval_steps=2,
+            max_steps=2,
+        )
+        trainer.train()
+
+        # Create a new trainer with defaults
+        trainer = self.get_trainer(
+            callbacks=[EarlyStoppingCallback()],
+            load_best_model_at_end=True,
+            save_strategy="steps",
+            eval_strategy="steps",
+            save_steps=2,
+            eval_steps=2,
+            max_steps=2,
+            restore_callback_states_from_checkpoint=True,
+        )
+        # Load it back in and verify values
+        checkpoint = os.path.join(self.output_dir, "checkpoint-2")
+        trainer.train(resume_from_checkpoint=checkpoint)
+        cb = [
+            callback for callback in trainer.callback_handler.callbacks if isinstance(callback, EarlyStoppingCallback)
+        ][0]
+        assert cb.early_stopping_patience == 5
+        assert cb.early_stopping_threshold == 0.2
+
+    def test_stateful_mixed_callbacks(self):
+        # Use two callbacks, one stateful one not
+        # Use something with non-defaults
+        cbs = [
+            MyTestTrainerCallback(my_test_state="another value"),
+            EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2),
+        ]
+        trainer = self.get_trainer(
+            callbacks=cbs,
+            load_best_model_at_end=True,
+            save_strategy="steps",
+            eval_strategy="steps",
+            save_steps=2,
+            eval_steps=2,
+            max_steps=2,
+        )
+        trainer.train()
+
+        # Create a new trainer with defaults
+        trainer = self.get_trainer(
+            callbacks=[EarlyStoppingCallback(), MyTestTrainerCallback()],
+            load_best_model_at_end=True,
+            save_strategy="steps",
+            eval_strategy="steps",
+            save_steps=2,
+            eval_steps=2,
+            max_steps=2,
+            restore_callback_states_from_checkpoint=True,
+        )
+        # Load it back in and verify values
+        checkpoint = os.path.join(self.output_dir, "checkpoint-2")
+        trainer.train(resume_from_checkpoint=checkpoint)
+        cbs = [
+            callback
+            for callback in trainer.callback_handler.callbacks
+            if isinstance(callback, (EarlyStoppingCallback, MyTestTrainerCallback))
+        ]
+        assert len(cbs) == 2
+        my_test, early_stopping = cbs
+        assert early_stopping.early_stopping_patience == 5
+        assert early_stopping.early_stopping_threshold == 0.2
+        assert my_test.my_test_state == "test"
+
+    def test_stateful_duplicate_callbacks(self):
+        # Use something with non-defaults
+        cbs = [MyTestExportableCallback("first"), MyTestExportableCallback("second")]
+        trainer = self.get_trainer(
+            callbacks=cbs,
+            load_best_model_at_end=True,
+            save_strategy="steps",
+            eval_strategy="steps",
+            save_steps=2,
+            eval_steps=2,
+            max_steps=2,
+        )
+        trainer.train()
+
+        # Create a new trainer with defaults
+        trainer = self.get_trainer(
+            callbacks=[MyTestExportableCallback(), MyTestExportableCallback()],
+            load_best_model_at_end=True,
+            save_strategy="steps",
+            eval_strategy="steps",
+            save_steps=2,
+            eval_steps=2,
+            max_steps=2,
+            restore_callback_states_from_checkpoint=True,
+        )
+        # Load it back in and verify values
+        checkpoint = os.path.join(self.output_dir, "checkpoint-2")
+        trainer.train(resume_from_checkpoint=checkpoint)
+        cbs = [
+            callback
+            for callback in trainer.callback_handler.callbacks
+            if isinstance(callback, MyTestExportableCallback)
+        ]
+        assert len(cbs) == 2
+        assert cbs[0].my_test_state == "first"
+        assert cbs[1].my_test_state == "second"
+
+    def test_missing_stateful_callback(self):
+        cb = EarlyStoppingCallback()
+        trainer = self.get_trainer(
+            callbacks=[cb],
+            load_best_model_at_end=True,
+            save_strategy="steps",
+            eval_strategy="steps",
+            save_steps=2,
+            eval_steps=2,
+            max_steps=2,
+        )
+        trainer.train()
+
+        # Create a new trainer with defaults
+        trainer = self.get_trainer(
+            save_strategy="steps",
+            eval_strategy="steps",
+            save_steps=2,
+            eval_steps=2,
+            max_steps=2,
+            restore_callback_states_from_checkpoint=True,
+        )
+        # Load it back in and verify values
+        checkpoint = os.path.join(self.output_dir, "checkpoint-2")
+        # warning should be emitted for not-present callbacks
+        with patch("transformers.trainer.logger.warning") as warn_mock:
+            trainer.train(resume_from_checkpoint=checkpoint)
+            assert "EarlyStoppingCallback" in warn_mock.call_args[0][0]
+
+    def test_stateful_control(self):
+        trainer = self.get_trainer(
+            max_steps=2,
+            save_strategy="steps",
+            save_steps=2,
+        )
+        trainer.train()
+        # Load it back in and verify values
+        trainer = self.get_trainer(max_steps=2, restore_callback_states_from_checkpoint=True)
+        checkpoint = os.path.join(self.output_dir, "checkpoint-2")
+        trainer.state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
+        trainer._load_callback_state()
+        assert trainer.control.should_training_stop
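The last test relies on `TrainerControl` now implementing the same protocol, which is how `should_training_stop` survives a resume. A minimal round-trip sketch using the `state()`/`from_state()` pair defined in this diff:

```python
from transformers.trainer_callback import TrainerControl

control = TrainerControl()
control.should_training_stop = True

# `state()` captures the control flags under "args"; `from_state()` rebuilds the object.
restored = TrainerControl.from_state(control.state())
assert restored.should_training_stop
assert not restored.should_epoch_stop
```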