Unverified Commit 843d380d authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python] add type hints for custom objective and metric functions in scikit-learn interface (#4547)



* [python] add type hints for custom objective and metric functions in scikit-learn interface

* update type hints

* remote unnecessary input

* Update python-package/lightgbm/sklearn.py
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* remove type hint on objective being callable
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent bfb346c1
...@@ -11,7 +11,7 @@ from collections import defaultdict, namedtuple ...@@ -11,7 +11,7 @@ from collections import defaultdict, namedtuple
from copy import deepcopy from copy import deepcopy
from enum import Enum, auto from enum import Enum, auto
from functools import partial from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np import numpy as np
...@@ -21,8 +21,8 @@ from .basic import _LIB, LightGBMError, _choose_param_value, _ConfigAliases, _lo ...@@ -21,8 +21,8 @@ from .basic import _LIB, LightGBMError, _choose_param_value, _ConfigAliases, _lo
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat, from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series, dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
default_client, delayed, pd_DataFrame, pd_Series, wait) default_client, delayed, pd_DataFrame, pd_Series, wait)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _lgbmmodel_doc_custom_eval_note, from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalFunction,
_lgbmmodel_doc_fit, _lgbmmodel_doc_predict) _lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit, _lgbmmodel_doc_predict)
_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series] _DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
_DaskMatrixLike = Union[dask_Array, dask_DataFrame] _DaskMatrixLike = Union[dask_Array, dask_DataFrame]
...@@ -400,7 +400,7 @@ def _train( ...@@ -400,7 +400,7 @@ def _train(
eval_class_weight: Optional[List[Union[dict, str]]] = None, eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None, eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None, eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Optional[Iterable[int]] = None, eval_at: Optional[Iterable[int]] = None,
**kwargs: Any **kwargs: Any
) -> LGBMModel: ) -> LGBMModel:
...@@ -1029,7 +1029,7 @@ class _DaskLGBMModel: ...@@ -1029,7 +1029,7 @@ class _DaskLGBMModel:
eval_class_weight: Optional[List[Union[dict, str]]] = None, eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None, eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None, eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Optional[Iterable[int]] = None, eval_at: Optional[Iterable[int]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
**kwargs: Any **kwargs: Any
...@@ -1096,7 +1096,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1096,7 +1096,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
learning_rate: float = 0.1, learning_rate: float = 0.1,
n_estimators: int = 100, n_estimators: int = 100,
subsample_for_bin: int = 200000, subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None, objective: Optional[str] = None,
class_weight: Optional[Union[dict, str]] = None, class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0., min_split_gain: float = 0.,
min_child_weight: float = 1e-3, min_child_weight: float = 1e-3,
...@@ -1165,7 +1165,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1165,7 +1165,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
eval_sample_weight: Optional[List[_DaskVectorLike]] = None, eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_class_weight: Optional[List[Union[dict, str]]] = None, eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None, eval_init_score: Optional[List[_DaskCollection]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
**kwargs: Any **kwargs: Any
) -> "DaskLGBMClassifier": ) -> "DaskLGBMClassifier":
...@@ -1281,7 +1281,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -1281,7 +1281,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
learning_rate: float = 0.1, learning_rate: float = 0.1,
n_estimators: int = 100, n_estimators: int = 100,
subsample_for_bin: int = 200000, subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None, objective: Optional[str] = None,
class_weight: Optional[Union[dict, str]] = None, class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0., min_split_gain: float = 0.,
min_child_weight: float = 1e-3, min_child_weight: float = 1e-3,
...@@ -1348,7 +1348,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -1348,7 +1348,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
eval_names: Optional[List[str]] = None, eval_names: Optional[List[str]] = None,
eval_sample_weight: Optional[List[_DaskVectorLike]] = None, eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_init_score: Optional[List[_DaskVectorLike]] = None, eval_init_score: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
**kwargs: Any **kwargs: Any
) -> "DaskLGBMRegressor": ) -> "DaskLGBMRegressor":
...@@ -1446,7 +1446,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -1446,7 +1446,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
learning_rate: float = 0.1, learning_rate: float = 0.1,
n_estimators: int = 100, n_estimators: int = 100,
subsample_for_bin: int = 200000, subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None, objective: Optional[str] = None,
class_weight: Optional[Union[dict, str]] = None, class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0., min_split_gain: float = 0.,
min_child_weight: float = 1e-3, min_child_weight: float = 1e-3,
...@@ -1516,7 +1516,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -1516,7 +1516,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
eval_sample_weight: Optional[List[_DaskVectorLike]] = None, eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_init_score: Optional[List[_DaskVectorLike]] = None, eval_init_score: Optional[List[_DaskVectorLike]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None, eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Iterable[int] = (1, 2, 3, 4, 5), eval_at: Iterable[int] = (1, 2, 3, 4, 5),
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
**kwargs: Any **kwargs: Any
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
"""Scikit-learn wrapper interface for LightGBM.""" """Scikit-learn wrapper interface for LightGBM."""
import copy import copy
from inspect import signature from inspect import signature
from typing import Callable, Dict, Optional, Union from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -11,14 +11,42 @@ from .callback import log_evaluation, record_evaluation ...@@ -11,14 +11,42 @@ from .callback import log_evaluation, record_evaluation
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray, from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase, _LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
_LGBMComputeSampleWeight, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, dt_DataTable, _LGBMComputeSampleWeight, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, dt_DataTable,
pd_DataFrame) pd_DataFrame, pd_Series)
from .engine import train from .engine import train
_ArrayLike = Union[List, np.ndarray, pd_Series]
_EvalResultType = Tuple[str, float, bool]
_LGBM_ScikitCustomObjectiveFunction = Union[
Callable[
[np.ndarray, np.ndarray],
Tuple[_ArrayLike, _ArrayLike]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray],
Tuple[_ArrayLike, _ArrayLike]
],
]
_LGBM_ScikitCustomEvalFunction = Union[
Callable[
[np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
]
class _ObjectiveFunctionWrapper: class _ObjectiveFunctionWrapper:
"""Proxy class for objective function.""" """Proxy class for objective function."""
def __init__(self, func): def __init__(self, func: _LGBM_ScikitCustomObjectiveFunction):
"""Construct a proxy class. """Construct a proxy class.
This class transforms objective function to match objective function with signature ``new_func(preds, dataset)`` This class transforms objective function to match objective function with signature ``new_func(preds, dataset)``
...@@ -107,7 +135,7 @@ class _ObjectiveFunctionWrapper: ...@@ -107,7 +135,7 @@ class _ObjectiveFunctionWrapper:
class _EvalFunctionWrapper: class _EvalFunctionWrapper:
"""Proxy class for evaluation function.""" """Proxy class for evaluation function."""
def __init__(self, func): def __init__(self, func: _LGBM_ScikitCustomEvalFunction):
"""Construct a proxy class. """Construct a proxy class.
This class transforms evaluation function to match evaluation function with signature ``new_func(preds, dataset)`` This class transforms evaluation function to match evaluation function with signature ``new_func(preds, dataset)``
...@@ -358,7 +386,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -358,7 +386,7 @@ class LGBMModel(_LGBMModelBase):
learning_rate: float = 0.1, learning_rate: float = 0.1,
n_estimators: int = 100, n_estimators: int = 100,
subsample_for_bin: int = 200000, subsample_for_bin: int = 200000,
objective: Optional[Union[str, Callable]] = None, objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[Dict, str]] = None, class_weight: Optional[Union[Dict, str]] = None,
min_split_gain: float = 0., min_split_gain: float = 0.,
min_child_weight: float = 1e-3, min_child_weight: float = 1e-3,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment