Unverified Commit 1a6e6ff9 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] fix mypy errors related to eval result tuples (#6097)

parent 921479b9
...@@ -54,6 +54,7 @@ _ctypes_float_array = Union[ ...@@ -54,6 +54,7 @@ _ctypes_float_array = Union[
_LGBM_EvalFunctionResultType = Tuple[str, float, bool] _LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]] _LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool] _LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
_LGBM_BoosterEvalMethodResultWithStandardDeviationType = Tuple[str, str, float, bool, float]
_LGBM_CategoricalFeatureConfiguration = Union[List[str], List[int], "Literal['auto']"] _LGBM_CategoricalFeatureConfiguration = Union[List[str], List[int], "Literal['auto']"]
_LGBM_FeatureNameConfiguration = Union[List[str], "Literal['auto']"] _LGBM_FeatureNameConfiguration = Union[List[str], "Literal['auto']"]
_LGBM_GroupType = Union[ _LGBM_GroupType = Union[
......
...@@ -3,9 +3,10 @@ ...@@ -3,9 +3,10 @@
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from .basic import Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning from .basic import (Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType,
_LGBM_BoosterEvalMethodResultWithStandardDeviationType, _log_info, _log_warning)
if TYPE_CHECKING: if TYPE_CHECKING:
from .engine import CVBooster from .engine import CVBooster
...@@ -20,11 +21,11 @@ __all__ = [ ...@@ -20,11 +21,11 @@ __all__ = [
_EvalResultDict = Dict[str, Dict[str, List[Any]]] _EvalResultDict = Dict[str, Dict[str, List[Any]]]
_EvalResultTuple = Union[ _EvalResultTuple = Union[
_LGBM_BoosterEvalMethodResultType, _LGBM_BoosterEvalMethodResultType,
Tuple[str, str, float, bool, float] _LGBM_BoosterEvalMethodResultWithStandardDeviationType
] ]
_ListOfEvalResultTuples = Union[ _ListOfEvalResultTuples = Union[
List[_LGBM_BoosterEvalMethodResultType], List[_LGBM_BoosterEvalMethodResultType],
List[Tuple[str, str, float, bool, float]] List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType]
] ]
...@@ -54,7 +55,7 @@ class CallbackEnv: ...@@ -54,7 +55,7 @@ class CallbackEnv:
iteration: int iteration: int
begin_iteration: int begin_iteration: int
end_iteration: int end_iteration: int
evaluation_result_list: Optional[List[_LGBM_BoosterEvalMethodResultType]] evaluation_result_list: Optional[_ListOfEvalResultTuples]
def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str: def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
......
...@@ -11,9 +11,9 @@ import numpy as np ...@@ -11,9 +11,9 @@ import numpy as np
from . import callback from . import callback
from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor, from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor,
_LGBM_BoosterEvalMethodResultType, _LGBM_CategoricalFeatureConfiguration, _LGBM_BoosterEvalMethodResultType, _LGBM_BoosterEvalMethodResultWithStandardDeviationType,
_LGBM_CustomObjectiveFunction, _LGBM_EvalFunctionResultType, _LGBM_FeatureNameConfiguration, _LGBM_CategoricalFeatureConfiguration, _LGBM_CustomObjectiveFunction, _LGBM_EvalFunctionResultType,
_log_warning) _LGBM_FeatureNameConfiguration, _log_warning)
from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold
__all__ = [ __all__ = [
...@@ -519,8 +519,8 @@ def _make_n_folds( ...@@ -519,8 +519,8 @@ def _make_n_folds(
def _agg_cv_result( def _agg_cv_result(
raw_results: List[List[Tuple[str, str, float, bool]]] raw_results: List[List[_LGBM_BoosterEvalMethodResultType]]
) -> List[Tuple[str, str, float, bool, float]]: ) -> List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType]:
"""Aggregate cross-validation results.""" """Aggregate cross-validation results."""
cvmap: Dict[str, List[float]] = OrderedDict() cvmap: Dict[str, List[float]] = OrderedDict()
metric_type: Dict[str, bool] = {} metric_type: Dict[str, bool] = {}
...@@ -530,7 +530,7 @@ def _agg_cv_result( ...@@ -530,7 +530,7 @@ def _agg_cv_result(
metric_type[key] = one_line[3] metric_type[key] = one_line[3]
cvmap.setdefault(key, []) cvmap.setdefault(key, [])
cvmap[key].append(one_line[2]) cvmap[key].append(one_line[2])
return [('cv_agg', k, np.mean(v), metric_type[k], np.std(v)) for k, v in cvmap.items()] return [('cv_agg', k, float(np.mean(v)), metric_type[k], float(np.std(v))) for k, v in cvmap.items()]
def cv( def cv(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment