Unverified Commit d45dca70 authored by James Lamb, committed by GitHub

[python-package] reorganize early stopping callback (#6114)

parent f175cebd
@@ -229,7 +229,12 @@ class _ResetParameterCallback:
             if new_param != env.params.get(key, None):
                 new_parameters[key] = new_param
         if new_parameters:
-            env.model.reset_parameter(new_parameters)
+            if isinstance(env.model, Booster):
+                env.model.reset_parameter(new_parameters)
+            else:
+                # CVBooster holds a list of Booster objects, each needs to be updated
+                for booster in env.model.boosters:
+                    booster.reset_parameter(new_parameters)
             env.params.update(new_parameters)
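
For context, a minimal sketch of the code path this hunk adds (the data and the learning-rate schedule here are illustrative, not from the commit): with lgb.cv(), env.model is a CVBooster rather than a Booster, so the callback must reset the parameter on each fold's underlying Booster.

import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(42)
X = rng.random((100, 5))
y = rng.random(100)
dtrain = lgb.Dataset(X, label=y)

# reset_parameter() inside lgb.cv() exercises the new else-branch:
# env.model is a CVBooster whose .boosters list holds one Booster per fold.
lgb.cv(
    {"objective": "regression", "verbosity": -1},
    dtrain,
    num_boost_round=10,
    callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 * (0.99 ** i))],
)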
@@ -267,6 +272,10 @@ class _EarlyStoppingCallback:
         verbose: bool = True,
         min_delta: Union[float, List[float]] = 0.0
     ) -> None:
+        if not isinstance(stopping_rounds, int) or stopping_rounds <= 0:
+            raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}")
+
         self.order = 30
         self.before_iteration = False
@@ -291,33 +300,46 @@ class _EarlyStoppingCallback:
     def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
         return curr_score < best_score - delta

-    def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
-        return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name
+    def _is_train_set(self, ds_name: str, eval_name: str, env: CallbackEnv) -> bool:
+        """Check, by name, if a given Dataset is the training data."""
+        # for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
+        # and those metrics are considered for early stopping
+        if ds_name == "cv_agg" and eval_name == "train":
+            return True
+
+        # for lgb.train(), it's possible to pass the training data via valid_sets with any eval_name
+        if isinstance(env.model, Booster) and ds_name == env.model._train_data_name:
+            return True
+
+        return False

     def _init(self, env: CallbackEnv) -> None:
         if env.evaluation_result_list is None or env.evaluation_result_list == []:
             raise ValueError(
                 "For early stopping, at least one dataset and eval metric is required for evaluation"
             )

         is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
-        only_train_set = (
-            len(env.evaluation_result_list) == 1
-            and self._is_train_set(
-                ds_name=env.evaluation_result_list[0][0],
-                eval_name=env.evaluation_result_list[0][1].split(" ")[0],
-                train_name=env.model._train_data_name)
-        )
-        self.enabled = not is_dart and not only_train_set
-        if not self.enabled:
-            if is_dart:
-                _log_warning('Early stopping is not available in dart mode')
-            elif only_train_set:
-                _log_warning('Only training set found, disabling early stopping.')
-            return
-
-        if self.stopping_rounds <= 0:
-            raise ValueError("stopping_rounds should be greater than zero.")
+        if is_dart:
+            self.enabled = False
+            _log_warning('Early stopping is not available in dart mode')
+            return
+
+        # validation sets are guaranteed to not be identical to the training data in cv()
+        if isinstance(env.model, Booster):
+            only_train_set = (
+                len(env.evaluation_result_list) == 1
+                and self._is_train_set(
+                    ds_name=env.evaluation_result_list[0][0],
+                    eval_name=env.evaluation_result_list[0][1].split(" ")[0],
+                    env=env
+                )
+            )
+            if only_train_set:
+                self.enabled = False
+                _log_warning('Only training set found, disabling early stopping.')
+                return

         if self.verbose:
             _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")
@@ -395,7 +417,11 @@ class _EarlyStoppingCallback:
             eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
             if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                 continue  # use only the first metric for early stopping
-            if self._is_train_set(env.evaluation_result_list[i][0], eval_name_splitted[0], env.model._train_data_name):
+            if self._is_train_set(
+                ds_name=env.evaluation_result_list[i][0],
+                eval_name=eval_name_splitted[0],
+                env=env
+            ):
                 continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
             elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                 if self.verbose:
...
@@ -21,6 +21,17 @@ def test_early_stopping_callback_is_picklable(serializer):
     assert callback.stopping_rounds == rounds


+def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors():
+    with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: 0"):
+        lgb.early_stopping(stopping_rounds=0)
+
+    with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: -1"):
+        lgb.early_stopping(stopping_rounds=-1)
+
+    with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: neverrrr"):
+        lgb.early_stopping(stopping_rounds="neverrrr")
+
+
 @pytest.mark.parametrize('serializer', SERIALIZERS)
 def test_log_evaluation_callback_is_picklable(serializer):
     periods = 42
...
@@ -4501,9 +4501,9 @@ def test_train_raises_informative_error_if_any_valid_sets_are_not_dataset_objects():
 def test_train_raises_informative_error_for_params_of_wrong_type():
     X, y = make_synthetic_regression()
-    params = {"early_stopping_round": "too-many"}
+    params = {"num_leaves": "too-many"}
     dtrain = lgb.Dataset(X, label=y)
-    with pytest.raises(lgb.basic.LightGBMError, match="Parameter early_stopping_round should be of type int, got \"too-many\""):
+    with pytest.raises(lgb.basic.LightGBMError, match="Parameter num_leaves should be of type int, got \"too-many\""):
         lgb.train(params, dtrain)
...
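
The test above switches from early_stopping_round to num_leaves, presumably because early_stopping_round supplied via params now hits the new Python-side constructor check before it ever reaches the C++ parameter parser, so it raises ValueError rather than LightGBMError. A sketch of that new Python-side behavior:

import pytest
import lightgbm as lgb

# The constructor check added in this commit rejects a non-int
# stopping_rounds immediately, before any training or C++ call happens.
with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0"):
    lgb.early_stopping(stopping_rounds="too-many")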