Unverified Commit 1a6e6ff9 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] fix mypy errors related to eval result tuples (#6097)

parent 921479b9
......@@ -54,6 +54,7 @@ _ctypes_float_array = Union[
_LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
_LGBM_BoosterEvalMethodResultWithStandardDeviationType = Tuple[str, str, float, bool, float]
_LGBM_CategoricalFeatureConfiguration = Union[List[str], List[int], "Literal['auto']"]
_LGBM_FeatureNameConfiguration = Union[List[str], "Literal['auto']"]
_LGBM_GroupType = Union[
......
......@@ -3,9 +3,10 @@
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from .basic import Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
from .basic import (Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType,
_LGBM_BoosterEvalMethodResultWithStandardDeviationType, _log_info, _log_warning)
if TYPE_CHECKING:
from .engine import CVBooster
......@@ -20,11 +21,11 @@ __all__ = [
_EvalResultDict = Dict[str, Dict[str, List[Any]]]
_EvalResultTuple = Union[
_LGBM_BoosterEvalMethodResultType,
Tuple[str, str, float, bool, float]
_LGBM_BoosterEvalMethodResultWithStandardDeviationType
]
_ListOfEvalResultTuples = Union[
List[_LGBM_BoosterEvalMethodResultType],
List[Tuple[str, str, float, bool, float]]
List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType]
]
......@@ -54,7 +55,7 @@ class CallbackEnv:
iteration: int
begin_iteration: int
end_iteration: int
evaluation_result_list: Optional[List[_LGBM_BoosterEvalMethodResultType]]
evaluation_result_list: Optional[_ListOfEvalResultTuples]
def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
......
......@@ -11,9 +11,9 @@ import numpy as np
from . import callback
from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor,
_LGBM_BoosterEvalMethodResultType, _LGBM_CategoricalFeatureConfiguration,
_LGBM_CustomObjectiveFunction, _LGBM_EvalFunctionResultType, _LGBM_FeatureNameConfiguration,
_log_warning)
_LGBM_BoosterEvalMethodResultType, _LGBM_BoosterEvalMethodResultWithStandardDeviationType,
_LGBM_CategoricalFeatureConfiguration, _LGBM_CustomObjectiveFunction, _LGBM_EvalFunctionResultType,
_LGBM_FeatureNameConfiguration, _log_warning)
from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold
__all__ = [
......@@ -519,8 +519,8 @@ def _make_n_folds(
def _agg_cv_result(
raw_results: List[List[Tuple[str, str, float, bool]]]
) -> List[Tuple[str, str, float, bool, float]]:
raw_results: List[List[_LGBM_BoosterEvalMethodResultType]]
) -> List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType]:
"""Aggregate cross-validation results."""
cvmap: Dict[str, List[float]] = OrderedDict()
metric_type: Dict[str, bool] = {}
......@@ -530,7 +530,7 @@ def _agg_cv_result(
metric_type[key] = one_line[3]
cvmap.setdefault(key, [])
cvmap[key].append(one_line[2])
return [('cv_agg', k, np.mean(v), metric_type[k], np.std(v)) for k, v in cvmap.items()]
return [('cv_agg', k, float(np.mean(v)), metric_type[k], float(np.std(v))) for k, v in cvmap.items()]
def cv(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment