"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "ca85b6795002cc0c74d8e97d19fee88ca6ecc98e"
Unverified commit 8e729af3 authored by Nikita Titov, committed by GitHub

[python] reset storage in record evaluation callback each time before starting training (#4885)

* Update test_sklearn.py

* Update python_package.yml

* Update python_package.yml

* Update callback.py

* Update callback.py
parent 729ac43c
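
Editor's note: the change below moves `eval_result.clear()` into `_init` and triggers re-initialization when `env.iteration == env.begin_iteration` rather than when the dict happens to be empty. A minimal sketch of the symptom this fixes, reusing one dict across two training runs; the synthetic data and the default names ('valid_0', 'l2' for a regression objective) are illustrative, not part of this commit:

import numpy as np
import lightgbm as lgb

X = np.random.rand(100, 5)
y = np.random.rand(100)
train_set = lgb.Dataset(X, label=y)
valid_set = lgb.Dataset(X, label=y, reference=train_set)
params = {'objective': 'regression', 'verbosity': -1}

eval_result = {}
cb = lgb.record_evaluation(eval_result)

lgb.train(params, train_set, num_boost_round=5, valid_sets=[valid_set], callbacks=[cb])
assert len(eval_result['valid_0']['l2']) == 5

lgb.train(params, train_set, num_boost_round=5, valid_sets=[valid_set], callbacks=[cb])
# Before this commit, _init ran only while eval_result was empty, so a reused
# dict was never reset and the second run appended on top of the first
# (10 entries). With the fix, the storage is cleared at begin_iteration.
assert len(eval_result['valid_0']['l2']) == 5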
python-package/lightgbm/callback.py

@@ -128,15 +128,15 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
     """
     if not isinstance(eval_result, dict):
         raise TypeError('eval_result should be a dictionary')
-    eval_result.clear()

     def _init(env: CallbackEnv) -> None:
+        eval_result.clear()
         for data_name, eval_name, _, _ in env.evaluation_result_list:
             eval_result.setdefault(data_name, collections.OrderedDict())
             eval_result[data_name].setdefault(eval_name, [])

     def _callback(env: CallbackEnv) -> None:
-        if not eval_result:
+        if env.iteration == env.begin_iteration:
             _init(env)
         for data_name, eval_name, result, _ in env.evaluation_result_list:
             eval_result[data_name][eval_name].append(result)
@@ -221,7 +221,6 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
     best_score_list: list = []
     cmp_op = []
     enabled = True
-    inited = False
     first_metric = ''

     def _init(env: CallbackEnv) -> None:
@@ -230,7 +229,6 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
         nonlocal best_score_list
         nonlocal cmp_op
         nonlocal enabled
-        nonlocal inited
         nonlocal first_metric
         enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                           in _ConfigAliases.get("boosting"))
@@ -249,7 +247,6 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
         best_iter = []
         best_score_list = []
         cmp_op = []
-        inited = True
         first_metric = ''
         n_metrics = len(set(m[1] for m in env.evaluation_result_list))
@@ -293,7 +290,6 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
     def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
         nonlocal best_iter
         nonlocal best_score_list
-        nonlocal inited
         if env.iteration == env.end_iteration - 1:
             if verbose:
                 best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
@@ -301,7 +297,6 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
                           f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
                 if first_metric_only:
                     _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
-            inited = False
             raise EarlyStopException(best_iter[i], best_score_list[i])

     def _callback(env: CallbackEnv) -> None:
@@ -310,9 +305,8 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
         nonlocal best_score_list
         nonlocal cmp_op
         nonlocal enabled
-        nonlocal inited
         nonlocal first_metric
-        if not inited:
+        if env.iteration == env.begin_iteration:
             _init(env)
         if not enabled:
             return
@@ -336,7 +330,6 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
                     _log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
                 if first_metric_only:
                     _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
-                inited = False
                 raise EarlyStopException(best_iter[i], best_score_list[i])
             _final_iteration_check(env, eval_name_splitted, i)
     _callback.order = 30  # type: ignore
...
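
Editor's note: the early_stopping edits apply the same idea. The mutable `inited` flag was set in `_init` and cleared only on the `EarlyStopException` paths, so a run that ended without raising could leave the callback marked as initialized and the next run would inherit stale `best_iter`/`best_score_list` state. Comparing `env.iteration` to `env.begin_iteration` needs no flag at all. A sketch of the reuse pattern this supports (synthetic data; `stopping_rounds` chosen large enough that early stopping never fires here):

import numpy as np
import lightgbm as lgb

X = np.random.rand(100, 5)
y = np.random.rand(100)
params = {'objective': 'regression', 'verbosity': -1}

# One callback object shared across runs, mirroring the reuse pattern above.
stopper = lgb.early_stopping(stopping_rounds=100)

for num_round in (5, 10):
    train_set = lgb.Dataset(X, label=y)
    valid_set = lgb.Dataset(X, label=y, reference=train_set)
    lgb.train(params, train_set, num_boost_round=num_round,
              valid_sets=[valid_set], callbacks=[stopper])
    # With this commit, each run re-runs _init at its first iteration, so
    # best scores from one run cannot leak into the next.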
tests/python_package_test/test_sklearn.py

@@ -294,20 +294,23 @@ def test_stacking_regressor():
 def test_grid_search():
     X, y = load_iris(return_X_y=True)
     y = y.astype(str)  # utilize label encoder at it's max power
-    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1,
-                                                        random_state=42)
-    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1,
-                                                      random_state=42)
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
+    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=42)
     params = dict(subsample=0.8,
                   subsample_freq=1)
     grid_params = dict(boosting_type=['rf', 'gbdt'],
                        n_estimators=[4, 6],
                        reg_alpha=[0.01, 0.005])
-    fit_params = dict(eval_set=[(X_val, y_val)],
-                      eval_metric=constant_metric,
-                      callbacks=[lgb.early_stopping(2)])
-    grid = GridSearchCV(estimator=lgb.LGBMClassifier(**params), param_grid=grid_params,
-                        cv=2)
+    evals_result = {}
+    fit_params = dict(
+        eval_set=[(X_val, y_val)],
+        eval_metric=constant_metric,
+        callbacks=[
+            lgb.early_stopping(2),
+            lgb.record_evaluation(evals_result)
+        ]
+    )
+    grid = GridSearchCV(estimator=lgb.LGBMClassifier(**params), param_grid=grid_params, cv=2)
     grid.fit(X_train, y_train, **fit_params)
     score = grid.score(X_test, y_test)  # utilizes GridSearchCV default refit=True
     assert grid.best_params_['boosting_type'] in ['rf', 'gbdt']
@@ -319,6 +322,7 @@ def test_grid_search():
     assert grid.best_estimator_.best_score_['valid_0']['error'] == 0
     assert score >= 0.2
     assert score <= 1.
+    assert evals_result == grid.best_estimator_.evals_result_


 def test_random_search():
...
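
Editor's note: the new assertion relies on GridSearchCV's default refit=True. The best estimator is fit one final time after the search, and because record_evaluation now clears its storage at the start of every fit, the shared dict ends up holding exactly that last (refit) run's history, which is what the sklearn wrapper exposes as evals_result_. A stripped-down sketch of the same check, with a hypothetical parameter grid:

import lightgbm as lgb
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.1, random_state=42)

evals_result = {}
grid = GridSearchCV(estimator=lgb.LGBMClassifier(n_estimators=5),
                    param_grid={'num_leaves': [7, 31]}, cv=2)
grid.fit(X_train, y_train, eval_set=[(X_val, y_val)],
         callbacks=[lgb.record_evaluation(evals_result)])

# Every CV fit overwrote evals_result; the refit wrote last, so the shared
# dict matches the refit estimator's stored per-iteration history.
assert evals_result == grid.best_estimator_.evals_result_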