"src/objective/vscode:/vscode.git/clone" did not exist on "1d5f46f6e7704c1ce2a82cc1882e920d42cba7a3"
Unverified Commit aa413f7e authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] remove workaround for `UnboundLocalError` in early stopping callback for Python 2 (#4855)

parent 12915d58
......@@ -220,13 +220,19 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
best_iter = []
best_score_list: list = []
cmp_op = []
enabled = [True]
first_metric = ['']
enabled = True
first_metric = ''
def _init(env: CallbackEnv) -> None:
enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
nonlocal best_score
nonlocal best_iter
nonlocal best_score_list
nonlocal cmp_op
nonlocal enabled
nonlocal first_metric
enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
in _ConfigAliases.get("boosting"))
if not enabled[0]:
if not enabled:
_log_warning('Early stopping is not available in dart mode')
return
if not env.evaluation_result_list:
......@@ -263,7 +269,7 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
deltas = [min_delta] * n_datasets * n_metrics
# split is needed for "<dataset type> <metric>" case (e.g. "train l1")
first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
for eval_ret, delta in zip(env.evaluation_result_list, deltas):
best_iter.append(0)
best_score_list.append(None)
......@@ -275,6 +281,8 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
cmp_op.append(partial(_lt_delta, delta=delta))
def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
nonlocal best_iter
nonlocal best_score_list
if env.iteration == env.end_iteration - 1:
if verbose:
best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
......@@ -285,9 +293,15 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
raise EarlyStopException(best_iter[i], best_score_list[i])
def _callback(env: CallbackEnv) -> None:
nonlocal best_score
nonlocal best_iter
nonlocal best_score_list
nonlocal cmp_op
nonlocal enabled
nonlocal first_metric
if not cmp_op:
_init(env)
if not enabled[0]:
if not enabled:
return
for i in range(len(env.evaluation_result_list)):
score = env.evaluation_result_list[i][2]
......@@ -297,7 +311,7 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
best_score_list[i] = env.evaluation_result_list
# split is needed for "<dataset type> <metric>" case (e.g. "train l1")
eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
if first_metric_only and first_metric[0] != eval_name_splitted[-1]:
if first_metric_only and first_metric != eval_name_splitted[-1]:
continue # use only the first metric for early stopping
if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
or env.evaluation_result_list[i][0] == env.model._train_data_name)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment