Unverified Commit aa413f7e authored by Nikita Titov and committed by GitHub

[python] remove workaround for `UnboundLocalError` in early stopping callback for Python 2 (#4855)
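Background: assigning to a name inside a nested function makes that name local to the function, so Python 2 closures that needed to rebind enclosing-scope state used one-element lists as mutable cells (e.g. `enabled[0] = ...`) to avoid `UnboundLocalError`. Python 3's `nonlocal` keyword makes that workaround unnecessary, which is what this change applies. A minimal standalone sketch of the two patterns (illustrative names, not taken from `callback.py`):

```python
def make_counter_py2_style():
    count = [0]  # one-element list: a mutable cell the nested function writes through

    def increment():
        count[0] += 1  # item assignment; `count` itself is never rebound
        return count[0]

    return increment


def make_counter_py3_style():
    count = 0

    def increment():
        nonlocal count  # without this line, `count += 1` raises UnboundLocalError
        count += 1
        return count

    return increment


counter = make_counter_py3_style()
assert counter() == 1 and counter() == 2
```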

parent 12915d58
@@ -220,13 +220,19 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
     best_iter = []
     best_score_list: list = []
     cmp_op = []
-    enabled = [True]
-    first_metric = ['']
+    enabled = True
+    first_metric = ''
 
     def _init(env: CallbackEnv) -> None:
-        enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
-                             in _ConfigAliases.get("boosting"))
-        if not enabled[0]:
+        nonlocal best_score
+        nonlocal best_iter
+        nonlocal best_score_list
+        nonlocal cmp_op
+        nonlocal enabled
+        nonlocal first_metric
+        enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
+                          in _ConfigAliases.get("boosting"))
+        if not enabled:
             _log_warning('Early stopping is not available in dart mode')
             return
         if not env.evaluation_result_list:
@@ -263,7 +269,7 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
             deltas = [min_delta] * n_datasets * n_metrics
 
         # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
-        first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
+        first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
         for eval_ret, delta in zip(env.evaluation_result_list, deltas):
             best_iter.append(0)
             best_score_list.append(None)
@@ -275,6 +281,8 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
                 cmp_op.append(partial(_lt_delta, delta=delta))
 
     def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
+        nonlocal best_iter
+        nonlocal best_score_list
         if env.iteration == env.end_iteration - 1:
             if verbose:
                 best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
@@ -285,9 +293,15 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
             raise EarlyStopException(best_iter[i], best_score_list[i])
 
     def _callback(env: CallbackEnv) -> None:
+        nonlocal best_score
+        nonlocal best_iter
+        nonlocal best_score_list
+        nonlocal cmp_op
+        nonlocal enabled
+        nonlocal first_metric
         if not cmp_op:
             _init(env)
-        if not enabled[0]:
+        if not enabled:
             return
         for i in range(len(env.evaluation_result_list)):
             score = env.evaluation_result_list[i][2]
@@ -297,7 +311,7 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
                 best_score_list[i] = env.evaluation_result_list
             # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
             eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
-            if first_metric_only and first_metric[0] != eval_name_splitted[-1]:
+            if first_metric_only and first_metric != eval_name_splitted[-1]:
                 continue  # use only the first metric for early stopping
             if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
                     or env.evaluation_result_list[i][0] == env.model._train_data_name)):
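For context, the `early_stopping` factory patched here returns the stateful `_callback` closure over these variables, and it is passed to training like any other LightGBM callback. A minimal usage sketch with synthetic data (dataset shapes and parameters are illustrative):

```python
import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
X, y = rng.random((500, 10)), rng.random(500)
train_set = lgb.Dataset(X[:400], label=y[:400])
valid_set = lgb.Dataset(X[400:], label=y[400:], reference=train_set)

booster = lgb.train(
    {"objective": "regression", "metric": "l2"},
    train_set,
    num_boost_round=100,
    valid_sets=[valid_set],
    # lgb.early_stopping() is the callback factory shown in this diff; it stops
    # training once the validation metric fails to improve for 5 rounds.
    callbacks=[lgb.early_stopping(stopping_rounds=5)],
)
print(booster.best_iteration)
```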