Unverified Commit d8274346 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] reset storages in early stopping callback after finishing training (#4868)

parent 00f87c52
......@@ -221,6 +221,7 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
best_score_list: list = []
cmp_op = []
enabled = True
inited = False
first_metric = ''
def _init(env: CallbackEnv) -> None:
......@@ -229,6 +230,7 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
nonlocal best_score_list
nonlocal cmp_op
nonlocal enabled
nonlocal inited
nonlocal first_metric
enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
in _ConfigAliases.get("boosting"))
......@@ -242,6 +244,14 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
if verbose:
_log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")
# reset storages
best_score = []
best_iter = []
best_score_list = []
cmp_op = []
inited = True
first_metric = ''
n_metrics = len(set(m[1] for m in env.evaluation_result_list))
n_datasets = len(env.evaluation_result_list) // n_metrics
if isinstance(min_delta, list):
......@@ -283,6 +293,7 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
nonlocal best_iter
nonlocal best_score_list
nonlocal inited
if env.iteration == env.end_iteration - 1:
if verbose:
best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
......@@ -290,6 +301,7 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
if first_metric_only:
_log_info(f"Evaluated only: {eval_name_splitted[-1]}")
inited = False
raise EarlyStopException(best_iter[i], best_score_list[i])
def _callback(env: CallbackEnv) -> None:
......@@ -298,8 +310,9 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
nonlocal best_score_list
nonlocal cmp_op
nonlocal enabled
nonlocal inited
nonlocal first_metric
if not cmp_op:
if not inited:
_init(env)
if not enabled:
return
......@@ -323,6 +336,7 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
_log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
if first_metric_only:
_log_info(f"Evaluated only: {eval_name_splitted[-1]}")
inited = False
raise EarlyStopException(best_iter[i], best_score_list[i])
_final_iteration_check(env, eval_name_splitted, i)
_callback.order = 30 # type: ignore
......
......@@ -287,7 +287,7 @@ def test_grid_search():
reg_alpha=[0.01, 0.005])
fit_params = dict(eval_set=[(X_val, y_val)],
eval_metric=constant_metric,
early_stopping_rounds=2)
callbacks=[lgb.early_stopping(2)])
grid = GridSearchCV(estimator=lgb.LGBMClassifier(**params), param_grid=grid_params,
cv=2)
grid.fit(X_train, y_train, **fit_params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment