Unverified Commit 581d53c6 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add type hints on Booster eval methods (#5433)

* [python-package] add type hints on Booster eval methods

* remove unnecessary changes

* fix hints
parent 39eb041f
......@@ -20,6 +20,9 @@ import scipy.sparse
from .compat import PANDAS_INSTALLED, concat, dt_DataTable, pd_CategoricalDtype, pd_DataFrame, pd_Series
from .libpath import find_lib_path
_LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
ZERO_THRESHOLD = 1e-35
......@@ -2635,6 +2638,16 @@ _LGBM_CustomObjectiveFunction = Callable[
[np.ndarray, Dataset],
Tuple[np.ndarray, np.ndarray]
]
_LGBM_CustomEvalFunction = Union[
Callable[
[np.ndarray, Dataset],
_LGBM_EvalFunctionResultType
],
Callable[
[np.ndarray, Dataset],
List[_LGBM_EvalFunctionResultType]
]
]
class Booster:
......@@ -3273,7 +3286,12 @@ class Booster:
ctypes.byref(ret)))
return ret.value
def eval(self, data, name, feval=None):
def eval(
self,
data: Dataset,
name: str,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate for data.
Parameters
......@@ -3304,7 +3322,7 @@ class Booster:
Returns
-------
result : list
List with evaluation results.
List with (dataset_name, eval_name, eval_result, is_higher_better) tuples.
"""
if not isinstance(data, Dataset):
raise TypeError("Can only eval for Dataset instance")
......@@ -3323,7 +3341,10 @@ class Booster:
return self.__inner_eval(name, data_idx, feval)
def eval_train(self, feval=None):
def eval_train(
self,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate for training data.
Parameters
......@@ -3350,11 +3371,14 @@ class Booster:
Returns
-------
result : list
List with evaluation results.
List with (train_dataset_name, eval_name, eval_result, is_higher_better) tuples.
"""
return self.__inner_eval(self._train_data_name, 0, feval)
def eval_valid(self, feval=None):
def eval_valid(
self,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate for validation data.
Parameters
......@@ -3381,7 +3405,7 @@ class Booster:
Returns
-------
result : list
List with evaluation results.
List with (validation_dataset_name, eval_name, eval_result, is_higher_better) tuples.
"""
return [item for i in range(1, self.__num_dataset)
for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)]
......@@ -3987,7 +4011,12 @@ class Booster:
else:
return hist, bin_edges
def __inner_eval(self, data_name, data_idx, feval=None):
def __inner_eval(
self,
data_name: str,
data_idx: int,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate training or validation data."""
if data_idx >= self.__num_dataset:
raise ValueError("Data_idx should be smaller than number of dataset")
......
......@@ -4,10 +4,10 @@ import collections
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union
from .basic import _ConfigAliases, _log_info, _log_warning
from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
_EvalResultTuple = Union[
List[Tuple[str, str, float, bool]],
List[_LGBM_BoosterEvalMethodResultType],
List[Tuple[str, str, float, bool, float]]
]
......
......@@ -6,7 +6,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from .basic import Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _log_warning
from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _LGBM_EvalFunctionResultType,
_log_warning)
from .callback import record_evaluation
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
......@@ -14,8 +15,6 @@ from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite
dt_DataTable, pd_DataFrame)
from .engine import train
_EvalResultType = Tuple[str, float, bool]
_LGBM_ScikitCustomObjectiveFunction = Union[
Callable[
[np.ndarray, np.ndarray],
......@@ -33,15 +32,15 @@ _LGBM_ScikitCustomObjectiveFunction = Union[
_LGBM_ScikitCustomEvalFunction = Union[
Callable[
[np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
],
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment