Unverified Commit a5285985 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] fix mypy errors about custom eval and metric functions (#5790)

parent 9f035100
...@@ -11,7 +11,7 @@ import numpy as np ...@@ -11,7 +11,7 @@ import numpy as np
from . import callback from . import callback
from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor, from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor,
_LGBM_CategoricalFeatureConfiguration, _LGBM_CustomObjectiveFunction, _LGBM_CategoricalFeatureConfiguration, _LGBM_CustomObjectiveFunction, _LGBM_EvalFunctionResultType,
_LGBM_FeatureNameConfiguration, _log_warning) _LGBM_FeatureNameConfiguration, _log_warning)
from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold
...@@ -22,9 +22,15 @@ __all__ = [ ...@@ -22,9 +22,15 @@ __all__ = [
] ]
_LGBM_CustomMetricFunction = Callable[ _LGBM_CustomMetricFunction = Union[
Callable[
[np.ndarray, Dataset], [np.ndarray, Dataset],
Union[Tuple[str, float, bool], List[Tuple[str, float, bool]]] _LGBM_EvalFunctionResultType,
],
Callable[
[np.ndarray, Dataset],
List[_LGBM_EvalFunctionResultType]
],
] ]
_LGBM_PreprocFunction = Callable[ _LGBM_PreprocFunction = Callable[
......
...@@ -33,32 +33,50 @@ _LGBM_ScikitMatrixLike = Union[ ...@@ -33,32 +33,50 @@ _LGBM_ScikitMatrixLike = Union[
scipy.sparse.spmatrix scipy.sparse.spmatrix
] ]
_LGBM_ScikitCustomObjectiveFunction = Union[ _LGBM_ScikitCustomObjectiveFunction = Union[
# f(labels, preds)
Callable[ Callable[
[np.ndarray, np.ndarray], [Optional[np.ndarray], np.ndarray],
Tuple[np.ndarray, np.ndarray] Tuple[np.ndarray, np.ndarray]
], ],
# f(labels, preds, weights)
Callable[ Callable[
[np.ndarray, np.ndarray, np.ndarray], [Optional[np.ndarray], np.ndarray, Optional[np.ndarray]],
Tuple[np.ndarray, np.ndarray] Tuple[np.ndarray, np.ndarray]
], ],
# f(labels, preds, weights, group)
Callable[ Callable[
[np.ndarray, np.ndarray, np.ndarray, np.ndarray], [Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]],
Tuple[np.ndarray, np.ndarray] Tuple[np.ndarray, np.ndarray]
], ],
] ]
_LGBM_ScikitCustomEvalFunction = Union[ _LGBM_ScikitCustomEvalFunction = Union[
# f(labels, preds)
Callable[ Callable[
[np.ndarray, np.ndarray], [Optional[np.ndarray], np.ndarray],
Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]] _LGBM_EvalFunctionResultType
], ],
Callable[ Callable[
[np.ndarray, np.ndarray, np.ndarray], [Optional[np.ndarray], np.ndarray],
Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]] List[_LGBM_EvalFunctionResultType]
], ],
# f(labels, preds, weights)
Callable[ Callable[
[np.ndarray, np.ndarray, np.ndarray, np.ndarray], [Optional[np.ndarray], np.ndarray, Optional[np.ndarray]],
Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]] _LGBM_EvalFunctionResultType
], ],
Callable[
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray]],
List[_LGBM_EvalFunctionResultType]
],
# f(labels, preds, weights, group)
Callable[
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]],
_LGBM_EvalFunctionResultType
],
Callable[
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]],
List[_LGBM_EvalFunctionResultType]
]
] ]
_LGBM_ScikitEvalMetricType = Union[ _LGBM_ScikitEvalMetricType = Union[
str, str,
...@@ -135,11 +153,11 @@ class _ObjectiveFunctionWrapper: ...@@ -135,11 +153,11 @@ class _ObjectiveFunctionWrapper:
labels = dataset.get_label() labels = dataset.get_label()
argc = len(signature(self.func).parameters) argc = len(signature(self.func).parameters)
if argc == 2: if argc == 2:
grad, hess = self.func(labels, preds) grad, hess = self.func(labels, preds) # type: ignore[call-arg]
elif argc == 3: elif argc == 3:
grad, hess = self.func(labels, preds, dataset.get_weight()) grad, hess = self.func(labels, preds, dataset.get_weight()) # type: ignore[call-arg]
elif argc == 4: elif argc == 4:
grad, hess = self.func(labels, preds, dataset.get_weight(), dataset.get_group()) grad, hess = self.func(labels, preds, dataset.get_weight(), dataset.get_group()) # type: ignore [call-arg]
else: else:
raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}") raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}")
return grad, hess return grad, hess
...@@ -213,11 +231,11 @@ class _EvalFunctionWrapper: ...@@ -213,11 +231,11 @@ class _EvalFunctionWrapper:
labels = dataset.get_label() labels = dataset.get_label()
argc = len(signature(self.func).parameters) argc = len(signature(self.func).parameters)
if argc == 2: if argc == 2:
return self.func(labels, preds) return self.func(labels, preds) # type: ignore[call-arg]
elif argc == 3: elif argc == 3:
return self.func(labels, preds, dataset.get_weight()) return self.func(labels, preds, dataset.get_weight()) # type: ignore[call-arg]
elif argc == 4: elif argc == 4:
return self.func(labels, preds, dataset.get_weight(), dataset.get_group()) return self.func(labels, preds, dataset.get_weight(), dataset.get_group()) # type: ignore[call-arg]
else: else:
raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}") raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}")
...@@ -819,7 +837,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -819,7 +837,7 @@ class LGBMModel(_LGBMModelBase):
num_boost_round=self.n_estimators, num_boost_round=self.n_estimators,
valid_sets=valid_sets, valid_sets=valid_sets,
valid_names=eval_names, valid_names=eval_names,
feval=eval_metrics_callable, feval=eval_metrics_callable, # type: ignore[arg-type]
init_model=init_model, init_model=init_model,
feature_name=feature_name, feature_name=feature_name,
callbacks=callbacks callbacks=callbacks
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment