Unverified commit 5f79626f, authored by James Lamb, committed by GitHub

[python-package] fix type annotations for eval result tracking (#5793)

parent 42a42670
python-package/lightgbm/callback.py

@@ -15,6 +15,10 @@ __all__ = [
 _EvalResultDict = Dict[str, Dict[str, List[Any]]]
 _EvalResultTuple = Union[
+    _LGBM_BoosterEvalMethodResultType,
+    Tuple[str, str, float, bool, float]
+]
+_ListOfEvalResultTuples = Union[
     List[_LGBM_BoosterEvalMethodResultType],
     List[Tuple[str, str, float, bool, float]]
 ]
 
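For context: `_LGBM_BoosterEvalMethodResultType` is the 4-tuple `(dataset_name, eval_name, value, is_higher_better)` that `Booster.eval()` produces, while the 5-tuple variant carries an extra stdv field from `lgb.cv()`. A minimal sketch of the values the new aliases describe (the concrete dataset/metric names and numbers are illustrative, not from this diff):

from typing import Tuple, Union

_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
_EvalResultTuple = Union[
    _LGBM_BoosterEvalMethodResultType,
    Tuple[str, str, float, bool, float]
]

single: _EvalResultTuple = ("valid_0", "l2", 0.245, False)           # plain eval result
with_stdv: _EvalResultTuple = ("cv_agg", "l2", 0.245, False, 0.013)  # cv result + stdv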
@@ -23,7 +27,7 @@ _EvalResultTuple = Union[
 class EarlyStopException(Exception):
     """Exception of early stopping."""
 
-    def __init__(self, best_iteration: int, best_score: _EvalResultTuple) -> None:
+    def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
         """Create early stopping exception.
 
         Parameters
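The rename reflects how the exception is actually raised: callbacks pass the whole `env.evaluation_result_list`, a list of result tuples, not a single tuple. A hedged sketch of a custom callback (the five-round stopping rule is invented purely for illustration):

from lightgbm.callback import CallbackEnv, EarlyStopException

def stop_after_five_rounds(env: CallbackEnv) -> None:
    # best_score receives the full list of per-dataset result tuples
    if env.iteration >= 5:
        raise EarlyStopException(env.iteration, env.evaluation_result_list)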
@@ -55,7 +59,7 @@ def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
         return f"{value[0]}'s {value[1]}: {value[2]:g}"
     elif len(value) == 5:
         if show_stdv:
-            return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"
+            return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"  # type: ignore[misc]
         else:
             return f"{value[0]}'s {value[1]}: {value[2]:g}"
     else:
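The ignore is needed because `value[4]` only exists on the 5-tuple member of the union, and mypy (at least the versions current when this change was made) does not narrow a union of tuple types from a `len()` check, so it reports a `misc` error ("Tuple index out of range") for the 4-tuple member even inside the guarded branch. A standalone illustration of the same situation:

from typing import Tuple, Union

T4 = Tuple[str, str, float, bool]
T5 = Tuple[str, str, float, bool, float]

def stdv_of(value: Union[T4, T5]) -> float:
    if len(value) == 5:
        # mypy flags this despite the guard:
        # error: Tuple index out of range  [misc]
        return value[4]  # type: ignore[misc]
    raise ValueError("4-tuple results carry no stdv")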
@@ -256,7 +260,7 @@ class _EarlyStoppingCallback:
     def _reset_storages(self) -> None:
         self.best_score: List[float] = []
         self.best_iter: List[int] = []
-        self.best_score_list: List[Union[_EvalResultTuple, None]] = []
+        self.best_score_list: List[_ListOfEvalResultTuples] = []
         self.cmp_op: List[Callable[[float, float], bool]] = []
         self.first_metric = ''
 
@@ -327,7 +331,6 @@ class _EarlyStoppingCallback:
         self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
         for eval_ret, delta in zip(env.evaluation_result_list, deltas):
             self.best_iter.append(0)
-            self.best_score_list.append(None)
             if eval_ret[3]:  # greater is better
                 self.best_score.append(float('-inf'))
                 self.cmp_op.append(partial(self._gt_delta, delta=delta))
@@ -350,12 +353,17 @@ class _EarlyStoppingCallback:
             self._init(env)
         if not self.enabled:
             return
+        # self.best_score_list is initialized to an empty list
+        first_time_updating_best_score_list = (self.best_score_list == [])
         for i in range(len(env.evaluation_result_list)):
             score = env.evaluation_result_list[i][2]
-            if self.best_score_list[i] is None or self.cmp_op[i](score, self.best_score[i]):
+            if first_time_updating_best_score_list or self.cmp_op[i](score, self.best_score[i]):
                 self.best_score[i] = score
                 self.best_iter[i] = env.iteration
-                self.best_score_list[i] = env.evaluation_result_list
+                if first_time_updating_best_score_list:
+                    self.best_score_list.append(env.evaluation_result_list)
+                else:
+                    self.best_score_list[i] = env.evaluation_result_list
             # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
             eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
             if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
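The shape of this fix, as a self-contained sketch (the `track` helper and the names below are invented for illustration, not part of LightGBM): on the first call `best_score_list` is empty, so a snapshot is appended per tracked metric; on later calls the existing slot is overwritten in place. No `None` placeholder is ever stored, which is what lets the element type tighten from `Union[_EvalResultTuple, None]` to `_ListOfEvalResultTuples`:

from typing import List, Tuple

ResultTuple = Tuple[str, str, float, bool]  # (dataset, metric, value, is_higher_better)

best_score: List[float] = []
best_score_list: List[List[ResultTuple]] = []

def track(results: List[ResultTuple]) -> None:
    first_time = (best_score_list == [])  # mirrors the check added in this commit
    for i, (_, _, score, _) in enumerate(results):
        if first_time:
            best_score.append(score)
            best_score_list.append(results)  # grow the list on the first round
        elif score < best_score[i]:          # assume lower is better for this sketch
            best_score[i] = score
            best_score_list[i] = results     # overwrite in place afterwards

track([("valid_0", "l2", 0.30, False)])  # first call: appends
track([("valid_0", "l2", 0.25, False)])  # improvement: updates slot 0 in place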
python-package/lightgbm/engine.py

@@ -11,8 +11,9 @@ import numpy as np
 
 from . import callback
 from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor,
-                    _LGBM_CategoricalFeatureConfiguration, _LGBM_CustomObjectiveFunction, _LGBM_EvalFunctionResultType,
-                    _LGBM_FeatureNameConfiguration, _log_warning)
+                    _LGBM_BoosterEvalMethodResultType, _LGBM_CategoricalFeatureConfiguration,
+                    _LGBM_CustomObjectiveFunction, _LGBM_EvalFunctionResultType, _LGBM_FeatureNameConfiguration,
+                    _log_warning)
 from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold
 
 __all__ = [
@@ -257,7 +258,7 @@ def train(
         booster.update(fobj=fobj)
 
-        evaluation_result_list = []
+        evaluation_result_list: List[_LGBM_BoosterEvalMethodResultType] = []
         # check evaluation result.
         if valid_sets is not None:
             if is_valid_contain_train:
                 evaluation_result_list.extend(booster.eval_train(feval))
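The new annotation matches what the Booster's eval methods return: a list of `(dataset_name, eval_name, value, is_higher_better)` tuples. A quick usage sketch (the synthetic data and parameters are illustrative):

import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
train_data = lgb.Dataset(rng.random((100, 5)), label=rng.random(100))

booster = lgb.train({"objective": "regression", "verbosity": -1}, train_data, num_boost_round=5)
# Each element is a (dataset_name, eval_name, value, is_higher_better) tuple,
# i.e. a _LGBM_BoosterEvalMethodResultType.
print(booster.eval_train())  # e.g. [('training', 'l2', 0.08..., False)]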