Unverified Commit 5f79626f authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] fix type annotations for eval result tracking (#5793)

parent 42a42670
......@@ -15,6 +15,10 @@ __all__ = [
_EvalResultDict = Dict[str, Dict[str, List[Any]]]
_EvalResultTuple = Union[
_LGBM_BoosterEvalMethodResultType,
Tuple[str, str, float, bool, float]
]
_ListOfEvalResultTuples = Union[
List[_LGBM_BoosterEvalMethodResultType],
List[Tuple[str, str, float, bool, float]]
]
......@@ -23,7 +27,7 @@ _EvalResultTuple = Union[
class EarlyStopException(Exception):
"""Exception of early stopping."""
def __init__(self, best_iteration: int, best_score: _EvalResultTuple) -> None:
def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
"""Create early stopping exception.
Parameters
......@@ -55,7 +59,7 @@ def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
return f"{value[0]}'s {value[1]}: {value[2]:g}"
elif len(value) == 5:
if show_stdv:
return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"
return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}" # type: ignore[misc]
else:
return f"{value[0]}'s {value[1]}: {value[2]:g}"
else:
......@@ -256,7 +260,7 @@ class _EarlyStoppingCallback:
def _reset_storages(self) -> None:
self.best_score: List[float] = []
self.best_iter: List[int] = []
self.best_score_list: List[Union[_EvalResultTuple, None]] = []
self.best_score_list: List[_ListOfEvalResultTuples] = []
self.cmp_op: List[Callable[[float, float], bool]] = []
self.first_metric = ''
......@@ -327,7 +331,6 @@ class _EarlyStoppingCallback:
self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
for eval_ret, delta in zip(env.evaluation_result_list, deltas):
self.best_iter.append(0)
self.best_score_list.append(None)
if eval_ret[3]: # greater is better
self.best_score.append(float('-inf'))
self.cmp_op.append(partial(self._gt_delta, delta=delta))
......@@ -350,12 +353,17 @@ class _EarlyStoppingCallback:
self._init(env)
if not self.enabled:
return
# self.best_score_list is initialized to an empty list
first_time_updating_best_score_list = (self.best_score_list == [])
for i in range(len(env.evaluation_result_list)):
score = env.evaluation_result_list[i][2]
if self.best_score_list[i] is None or self.cmp_op[i](score, self.best_score[i]):
if first_time_updating_best_score_list or self.cmp_op[i](score, self.best_score[i]):
self.best_score[i] = score
self.best_iter[i] = env.iteration
self.best_score_list[i] = env.evaluation_result_list
if first_time_updating_best_score_list:
self.best_score_list.append(env.evaluation_result_list)
else:
self.best_score_list[i] = env.evaluation_result_list
# split is needed for "<dataset type> <metric>" case (e.g. "train l1")
eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
......
......@@ -11,8 +11,9 @@ import numpy as np
from . import callback
from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor,
_LGBM_CategoricalFeatureConfiguration, _LGBM_CustomObjectiveFunction, _LGBM_EvalFunctionResultType,
_LGBM_FeatureNameConfiguration, _log_warning)
_LGBM_BoosterEvalMethodResultType, _LGBM_CategoricalFeatureConfiguration,
_LGBM_CustomObjectiveFunction, _LGBM_EvalFunctionResultType, _LGBM_FeatureNameConfiguration,
_log_warning)
from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold
__all__ = [
......@@ -257,7 +258,7 @@ def train(
booster.update(fobj=fobj)
evaluation_result_list = []
evaluation_result_list: List[_LGBM_BoosterEvalMethodResultType] = []
# check evaluation result.
if valid_sets is not None:
if is_valid_contain_train:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment