Unverified Commit a5285985 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] fix mypy errors about custom eval and metric functions (#5790)

parent 9f035100
......@@ -11,7 +11,7 @@ import numpy as np
from . import callback
from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor,
_LGBM_CategoricalFeatureConfiguration, _LGBM_CustomObjectiveFunction,
_LGBM_CategoricalFeatureConfiguration, _LGBM_CustomObjectiveFunction, _LGBM_EvalFunctionResultType,
_LGBM_FeatureNameConfiguration, _log_warning)
from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold
......@@ -22,9 +22,15 @@ __all__ = [
]
_LGBM_CustomMetricFunction = Callable[
_LGBM_CustomMetricFunction = Union[
Callable[
[np.ndarray, Dataset],
Union[Tuple[str, float, bool], List[Tuple[str, float, bool]]]
_LGBM_EvalFunctionResultType,
],
Callable[
[np.ndarray, Dataset],
List[_LGBM_EvalFunctionResultType]
],
]
_LGBM_PreprocFunction = Callable[
......
......@@ -33,32 +33,50 @@ _LGBM_ScikitMatrixLike = Union[
scipy.sparse.spmatrix
]
_LGBM_ScikitCustomObjectiveFunction = Union[
# f(labels, preds)
Callable[
[np.ndarray, np.ndarray],
[Optional[np.ndarray], np.ndarray],
Tuple[np.ndarray, np.ndarray]
],
# f(labels, preds, weights)
Callable[
[np.ndarray, np.ndarray, np.ndarray],
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray]],
Tuple[np.ndarray, np.ndarray]
],
# f(labels, preds, weights, group)
Callable[
[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]],
Tuple[np.ndarray, np.ndarray]
],
]
_LGBM_ScikitCustomEvalFunction = Union[
# f(labels, preds)
Callable[
[np.ndarray, np.ndarray],
Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
[Optional[np.ndarray], np.ndarray],
_LGBM_EvalFunctionResultType
],
Callable[
[np.ndarray, np.ndarray, np.ndarray],
Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
[Optional[np.ndarray], np.ndarray],
List[_LGBM_EvalFunctionResultType]
],
# f(labels, preds, weights)
Callable[
[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray]],
_LGBM_EvalFunctionResultType
],
Callable[
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray]],
List[_LGBM_EvalFunctionResultType]
],
# f(labels, preds, weights, group)
Callable[
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]],
_LGBM_EvalFunctionResultType
],
Callable[
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]],
List[_LGBM_EvalFunctionResultType]
]
]
_LGBM_ScikitEvalMetricType = Union[
str,
......@@ -135,11 +153,11 @@ class _ObjectiveFunctionWrapper:
labels = dataset.get_label()
argc = len(signature(self.func).parameters)
if argc == 2:
grad, hess = self.func(labels, preds)
grad, hess = self.func(labels, preds) # type: ignore[call-arg]
elif argc == 3:
grad, hess = self.func(labels, preds, dataset.get_weight())
grad, hess = self.func(labels, preds, dataset.get_weight()) # type: ignore[call-arg]
elif argc == 4:
grad, hess = self.func(labels, preds, dataset.get_weight(), dataset.get_group())
grad, hess = self.func(labels, preds, dataset.get_weight(), dataset.get_group()) # type: ignore [call-arg]
else:
raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}")
return grad, hess
......@@ -213,11 +231,11 @@ class _EvalFunctionWrapper:
labels = dataset.get_label()
argc = len(signature(self.func).parameters)
if argc == 2:
return self.func(labels, preds)
return self.func(labels, preds) # type: ignore[call-arg]
elif argc == 3:
return self.func(labels, preds, dataset.get_weight())
return self.func(labels, preds, dataset.get_weight()) # type: ignore[call-arg]
elif argc == 4:
return self.func(labels, preds, dataset.get_weight(), dataset.get_group())
return self.func(labels, preds, dataset.get_weight(), dataset.get_group()) # type: ignore[call-arg]
else:
raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}")
......@@ -819,7 +837,7 @@ class LGBMModel(_LGBMModelBase):
num_boost_round=self.n_estimators,
valid_sets=valid_sets,
valid_names=eval_names,
feval=eval_metrics_callable,
feval=eval_metrics_callable, # type: ignore[arg-type]
init_model=init_model,
feature_name=feature_name,
callbacks=callbacks
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment