"src/vscode:/vscode.git/clone" did not exist on "a97c444b4cf9d2755bd888911ce65ace1fe13e4b"
Unverified Commit d45dca70 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] reorganize early stopping callback (#6114)

parent f175cebd
......@@ -229,7 +229,12 @@ class _ResetParameterCallback:
if new_param != env.params.get(key, None):
new_parameters[key] = new_param
if new_parameters:
if isinstance(env.model, Booster):
env.model.reset_parameter(new_parameters)
else:
# CVBooster holds a list of Booster objects, each needs to be updated
for booster in env.model.boosters:
booster.reset_parameter(new_parameters)
env.params.update(new_parameters)
......@@ -267,6 +272,10 @@ class _EarlyStoppingCallback:
verbose: bool = True,
min_delta: Union[float, List[float]] = 0.0
) -> None:
if not isinstance(stopping_rounds, int) or stopping_rounds <= 0:
raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}")
self.order = 30
self.before_iteration = False
......@@ -291,33 +300,46 @@ class _EarlyStoppingCallback:
def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
return curr_score < best_score - delta
def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name
def _is_train_set(self, ds_name: str, eval_name: str, env: CallbackEnv) -> bool:
"""Check, by name, if a given Dataset is the training data."""
# for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
# and those metrics are considered for early stopping
if ds_name == "cv_agg" and eval_name == "train":
return True
# for lgb.train(), it's possible to pass the training data via valid_sets with any eval_name
if isinstance(env.model, Booster) and ds_name == env.model._train_data_name:
return True
return False
def _init(self, env: CallbackEnv) -> None:
if env.evaluation_result_list is None or env.evaluation_result_list == []:
raise ValueError(
"For early stopping, at least one dataset and eval metric is required for evaluation"
)
is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
if is_dart:
self.enabled = False
_log_warning('Early stopping is not available in dart mode')
return
# validation sets are guaranteed to not be identical to the training data in cv()
if isinstance(env.model, Booster):
only_train_set = (
len(env.evaluation_result_list) == 1
and self._is_train_set(
ds_name=env.evaluation_result_list[0][0],
eval_name=env.evaluation_result_list[0][1].split(" ")[0],
train_name=env.model._train_data_name)
env=env
)
self.enabled = not is_dart and not only_train_set
if not self.enabled:
if is_dart:
_log_warning('Early stopping is not available in dart mode')
elif only_train_set:
)
if only_train_set:
self.enabled = False
_log_warning('Only training set found, disabling early stopping.')
return
if self.stopping_rounds <= 0:
raise ValueError("stopping_rounds should be greater than zero.")
if self.verbose:
_log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")
......@@ -395,7 +417,11 @@ class _EarlyStoppingCallback:
eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
continue # use only the first metric for early stopping
if self._is_train_set(env.evaluation_result_list[i][0], eval_name_splitted[0], env.model._train_data_name):
if self._is_train_set(
ds_name=env.evaluation_result_list[i][0],
eval_name=eval_name_splitted[0],
env=env
):
continue # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
if self.verbose:
......
......@@ -21,6 +21,17 @@ def test_early_stopping_callback_is_picklable(serializer):
assert callback.stopping_rounds == rounds
def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors():
with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: 0"):
lgb.early_stopping(stopping_rounds=0)
with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: -1"):
lgb.early_stopping(stopping_rounds=-1)
with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: neverrrr"):
lgb.early_stopping(stopping_rounds="neverrrr")
@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_log_evaluation_callback_is_picklable(serializer):
periods = 42
......
......@@ -4501,9 +4501,9 @@ def test_train_raises_informative_error_if_any_valid_sets_are_not_dataset_object
def test_train_raises_informative_error_for_params_of_wrong_type():
X, y = make_synthetic_regression()
params = {"early_stopping_round": "too-many"}
params = {"num_leaves": "too-many"}
dtrain = lgb.Dataset(X, label=y)
with pytest.raises(lgb.basic.LightGBMError, match="Parameter early_stopping_round should be of type int, got \"too-many\""):
with pytest.raises(lgb.basic.LightGBMError, match="Parameter num_leaves should be of type int, got \"too-many\""):
lgb.train(params, dtrain)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment