Unverified Commit 581d53c6 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add type hints on Booster eval methods (#5433)

* [python-package] add type hints on Booster eval methods

* remove unnecessary changes

* fix hints
parent 39eb041f
...@@ -20,6 +20,9 @@ import scipy.sparse ...@@ -20,6 +20,9 @@ import scipy.sparse
from .compat import PANDAS_INSTALLED, concat, dt_DataTable, pd_CategoricalDtype, pd_DataFrame, pd_Series from .compat import PANDAS_INSTALLED, concat, dt_DataTable, pd_CategoricalDtype, pd_DataFrame, pd_Series
from .libpath import find_lib_path from .libpath import find_lib_path
_LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
ZERO_THRESHOLD = 1e-35 ZERO_THRESHOLD = 1e-35
...@@ -2635,6 +2638,16 @@ _LGBM_CustomObjectiveFunction = Callable[ ...@@ -2635,6 +2638,16 @@ _LGBM_CustomObjectiveFunction = Callable[
[np.ndarray, Dataset], [np.ndarray, Dataset],
Tuple[np.ndarray, np.ndarray] Tuple[np.ndarray, np.ndarray]
] ]
_LGBM_CustomEvalFunction = Union[
Callable[
[np.ndarray, Dataset],
_LGBM_EvalFunctionResultType
],
Callable[
[np.ndarray, Dataset],
List[_LGBM_EvalFunctionResultType]
]
]
class Booster: class Booster:
...@@ -3273,7 +3286,12 @@ class Booster: ...@@ -3273,7 +3286,12 @@ class Booster:
ctypes.byref(ret))) ctypes.byref(ret)))
return ret.value return ret.value
def eval(self, data, name, feval=None): def eval(
self,
data: Dataset,
name: str,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate for data. """Evaluate for data.
Parameters Parameters
...@@ -3304,7 +3322,7 @@ class Booster: ...@@ -3304,7 +3322,7 @@ class Booster:
Returns Returns
------- -------
result : list result : list
List with evaluation results. List with (dataset_name, eval_name, eval_result, is_higher_better) tuples.
""" """
if not isinstance(data, Dataset): if not isinstance(data, Dataset):
raise TypeError("Can only eval for Dataset instance") raise TypeError("Can only eval for Dataset instance")
...@@ -3323,7 +3341,10 @@ class Booster: ...@@ -3323,7 +3341,10 @@ class Booster:
return self.__inner_eval(name, data_idx, feval) return self.__inner_eval(name, data_idx, feval)
def eval_train(self, feval=None): def eval_train(
self,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate for training data. """Evaluate for training data.
Parameters Parameters
...@@ -3350,11 +3371,14 @@ class Booster: ...@@ -3350,11 +3371,14 @@ class Booster:
Returns Returns
------- -------
result : list result : list
List with evaluation results. List with (train_dataset_name, eval_name, eval_result, is_higher_better) tuples.
""" """
return self.__inner_eval(self._train_data_name, 0, feval) return self.__inner_eval(self._train_data_name, 0, feval)
def eval_valid(self, feval=None): def eval_valid(
self,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate for validation data. """Evaluate for validation data.
Parameters Parameters
...@@ -3381,7 +3405,7 @@ class Booster: ...@@ -3381,7 +3405,7 @@ class Booster:
Returns Returns
------- -------
result : list result : list
List with evaluation results. List with (validation_dataset_name, eval_name, eval_result, is_higher_better) tuples.
""" """
return [item for i in range(1, self.__num_dataset) return [item for i in range(1, self.__num_dataset)
for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)] for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)]
...@@ -3987,7 +4011,12 @@ class Booster: ...@@ -3987,7 +4011,12 @@ class Booster:
else: else:
return hist, bin_edges return hist, bin_edges
def __inner_eval(self, data_name, data_idx, feval=None): def __inner_eval(
self,
data_name: str,
data_idx: int,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate training or validation data.""" """Evaluate training or validation data."""
if data_idx >= self.__num_dataset: if data_idx >= self.__num_dataset:
raise ValueError("Data_idx should be smaller than number of dataset") raise ValueError("Data_idx should be smaller than number of dataset")
......
...@@ -4,10 +4,10 @@ import collections ...@@ -4,10 +4,10 @@ import collections
from functools import partial from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union from typing import Any, Callable, Dict, List, Tuple, Union
from .basic import _ConfigAliases, _log_info, _log_warning from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
_EvalResultTuple = Union[ _EvalResultTuple = Union[
List[Tuple[str, str, float, bool]], List[_LGBM_BoosterEvalMethodResultType],
List[Tuple[str, str, float, bool, float]] List[Tuple[str, str, float, bool, float]]
] ]
......
...@@ -6,7 +6,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union ...@@ -6,7 +6,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
from .basic import Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _log_warning from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _LGBM_EvalFunctionResultType,
_log_warning)
from .callback import record_evaluation from .callback import record_evaluation
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray, from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase, _LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
...@@ -14,8 +15,6 @@ from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite ...@@ -14,8 +15,6 @@ from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite
dt_DataTable, pd_DataFrame) dt_DataTable, pd_DataFrame)
from .engine import train from .engine import train
_EvalResultType = Tuple[str, float, bool]
_LGBM_ScikitCustomObjectiveFunction = Union[ _LGBM_ScikitCustomObjectiveFunction = Union[
Callable[ Callable[
[np.ndarray, np.ndarray], [np.ndarray, np.ndarray],
...@@ -33,15 +32,15 @@ _LGBM_ScikitCustomObjectiveFunction = Union[ ...@@ -33,15 +32,15 @@ _LGBM_ScikitCustomObjectiveFunction = Union[
_LGBM_ScikitCustomEvalFunction = Union[ _LGBM_ScikitCustomEvalFunction = Union[
Callable[ Callable[
[np.ndarray, np.ndarray], [np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]] Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
], ],
Callable[ Callable[
[np.ndarray, np.ndarray, np.ndarray], [np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]] Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
], ],
Callable[ Callable[
[np.ndarray, np.ndarray, np.ndarray, np.ndarray], [np.ndarray, np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]] Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
], ],
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment