"doc/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "98046bb1b89907c4a87f1e5142361301ccbc3a86"
Unverified Commit 8ffc01a7 authored by Colin Brochtrup's avatar Colin Brochtrup Committed by GitHub
Browse files

Add early stopping callback to pytorch trainer (#8581)

* Add early stopping patience and minimum threshold metric must improve to prevent early stopping to pytorch trainer

* Add early stopping test

* Set patience counter to 0 if best metric not defined yet

* Make early stopping a callback. Add callback event for updating the best metric for early stopping callback to trigger on.

* Run make style

* make funciton name sensible

* Improve new argument docstring wording and hope that flakey CI test passes.

* Use on_evaluation callback instead of custom. Remove some debug printing

* Move early stopping arguments and state into early stopping callback

* Run make style

* Remove old code

* Fix docs formatting. make style went rogue on me.

* Remove copied attributes and fix variable

* Add assertions on training arguments instead of mutating them. Move comment out of public docs.

* Make separate test for early stopping callback. Add test of invalid arguments.

* Run make style... I remembered before CI this time!

* appease flake8

* Add EarlyStoppingCallback to callback docs

* Make docstring EarlyStoppingCallabck match other callbacks.

* Fix typo in docs
parent 367f497d
...@@ -44,6 +44,8 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the ...@@ -44,6 +44,8 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the
.. autoclass:: transformers.ProgressCallback .. autoclass:: transformers.ProgressCallback
.. autoclass:: transformers.EarlyStoppingCallback
.. autoclass:: transformers.integrations.TensorBoardCallback .. autoclass:: transformers.integrations.TensorBoardCallback
.. autoclass:: transformers.integrations.WandbCallback .. autoclass:: transformers.integrations.WandbCallback
......
...@@ -253,6 +253,7 @@ else: ...@@ -253,6 +253,7 @@ else:
# Trainer # Trainer
from .trainer_callback import ( from .trainer_callback import (
DefaultFlowCallback, DefaultFlowCallback,
EarlyStoppingCallback,
PrinterCallback, PrinterCallback,
ProgressCallback, ProgressCallback,
TrainerCallback, TrainerCallback,
......
...@@ -21,6 +21,7 @@ import json ...@@ -21,6 +21,7 @@ import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import numpy as np
from tqdm.auto import tqdm from tqdm.auto import tqdm
from .trainer_utils import EvaluationStrategy from .trainer_utils import EvaluationStrategy
...@@ -475,3 +476,62 @@ class PrinterCallback(TrainerCallback): ...@@ -475,3 +476,62 @@ class PrinterCallback(TrainerCallback):
_ = logs.pop("total_flos", None) _ = logs.pop("total_flos", None)
if state.is_local_process_zero: if state.is_local_process_zero:
print(logs) print(logs)
class EarlyStoppingCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that handles early stopping.
Args:
early_stopping_patience (:obj:`int`):
Use with :obj:`metric_for_best_model` to stop training when the specified metric worsens for
:obj:`early_stopping_patience` evaluation calls.
early_stopping_threshold(:obj:`float`, `optional`):
Use with TrainingArguments :obj:`metric_for_best_model` and :obj:`early_stopping_patience` to denote how
much the specified metric must improve to satisfy early stopping conditions. `
This callback depends on :class:`~transformers.TrainingArguments` argument `load_best_model_at_end` functionality
to set best_metric in :class:`~transformers.TrainerState`.
"""
def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
self.early_stopping_patience = early_stopping_patience
self.early_stopping_threshold = early_stopping_threshold
# early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
self.early_stopping_patience_counter = 0
def check_metric_value(self, args, state, control, metric_value):
# best_metric is set by code for load_best_model
operator = np.greater if args.greater_is_better else np.less
if state.best_metric is None or (
operator(metric_value, state.best_metric)
and abs(metric_value - state.best_metric) > self.early_stopping_threshold
):
self.early_stopping_patience_counter = 0
else:
self.early_stopping_patience_counter += 1
def on_train_begin(self, args, state, control, **kwargs):
assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True"
assert (
args.metric_for_best_model is not None
), "EarlyStoppingCallback requires metric_for_best_model is defined"
assert (
args.evaluation_strategy != EvaluationStrategy.NO
), "EarlyStoppingCallback requires EvaluationStrategy of steps or epoch"
def on_evaluate(self, args, state, control, metrics, **kwargs):
metric_to_check = args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
metric_value = metrics.get(metric_to_check)
if metric_value is None:
logger.warning(
f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping is disabled"
)
return
self.check_metric_value(args, state, control, metric_value)
if self.early_stopping_patience_counter >= self.early_stopping_patience:
control.should_training_stop = True
...@@ -42,6 +42,7 @@ if is_torch_available(): ...@@ -42,6 +42,7 @@ if is_torch_available():
AutoModelForMaskedLM, AutoModelForMaskedLM,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
DataCollatorForLanguageModeling, DataCollatorForLanguageModeling,
EarlyStoppingCallback,
GlueDataset, GlueDataset,
GlueDataTrainingArguments, GlueDataTrainingArguments,
GPT2Config, GPT2Config,
...@@ -765,6 +766,37 @@ class TrainerIntegrationTest(unittest.TestCase): ...@@ -765,6 +766,37 @@ class TrainerIntegrationTest(unittest.TestCase):
train_output = trainer.train() train_output = trainer.train()
self.assertEqual(train_output.global_step, int(self.n_epochs)) self.assertEqual(train_output.global_step, int(self.n_epochs))
def test_early_stopping_callback(self):
# early stopping stops training before num_training_epochs
trainer = get_regression_trainer(
num_train_epochs=20,
gradient_accumulation_steps=1,
per_device_train_batch_size=16,
load_best_model_at_end=True,
evaluation_strategy=EvaluationStrategy.EPOCH,
compute_metrics=AlmostAccuracy(),
metric_for_best_model="accuracy",
)
trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
train_output = trainer.train()
self.assertLess(train_output.global_step, 20 * 64 / 16)
# Invalid inputs to trainer with early stopping callback result in assertion error
trainer = get_regression_trainer(
num_train_epochs=20,
gradient_accumulation_steps=1,
per_device_train_batch_size=16,
evaluation_strategy=EvaluationStrategy.EPOCH,
compute_metrics=AlmostAccuracy(),
metric_for_best_model="accuracy",
)
trainer.add_callback(EarlyStoppingCallback(1))
self.assertEqual(trainer.state.global_step, 0)
try:
trainer.train()
except AssertionError:
self.assertEqual(trainer.state.global_step, 0)
def test_flos_extraction(self): def test_flos_extraction(self):
trainer = get_regression_trainer(learning_rate=0.1) trainer = get_regression_trainer(learning_rate=0.1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment