Unverified Commit 2f5baa3d authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add type hints on `cv()` (#5271)



* [python-package] add type hints on cv()

* remove inadvertent changes

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent f3ea1ad7
...@@ -75,9 +75,9 @@ try: ...@@ -75,9 +75,9 @@ try:
from sklearn.utils.validation import assert_all_finite, check_array, check_X_y from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
try: try:
from sklearn.exceptions import NotFittedError from sklearn.exceptions import NotFittedError
from sklearn.model_selection import GroupKFold, StratifiedKFold from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
except ImportError: except ImportError:
from sklearn.cross_validation import GroupKFold, StratifiedKFold from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
from sklearn.utils.validation import NotFittedError from sklearn.utils.validation import NotFittedError
try: try:
from sklearn.utils.validation import _check_sample_weight from sklearn.utils.validation import _check_sample_weight
...@@ -90,6 +90,7 @@ try: ...@@ -90,6 +90,7 @@ try:
return sample_weight return sample_weight
SKLEARN_INSTALLED = True SKLEARN_INSTALLED = True
_LGBMBaseCrossValidator = BaseCrossValidator
_LGBMModelBase = BaseEstimator _LGBMModelBase = BaseEstimator
_LGBMRegressorBase = RegressorMixin _LGBMRegressorBase = RegressorMixin
_LGBMClassifierBase = ClassifierMixin _LGBMClassifierBase = ClassifierMixin
......
...@@ -4,19 +4,24 @@ import collections ...@@ -4,19 +4,24 @@ import collections
import copy import copy
from operator import attrgetter from operator import attrgetter
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
from . import callback from . import callback
from .basic import Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor, _log_warning from .basic import Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor, _log_warning
from .compat import SKLEARN_INSTALLED, _LGBMGroupKFold, _LGBMStratifiedKFold from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold
_LGBM_CustomMetricFunction = Callable[ _LGBM_CustomMetricFunction = Callable[
[np.ndarray, Dataset], [np.ndarray, Dataset],
Tuple[str, float, bool] Tuple[str, float, bool]
] ]
_LGBM_PreprocFunction = Callable[
[Dataset, Dataset, Dict[str, Any]],
Tuple[Dataset, Dataset, Dict[str, Any]]
]
def train( def train(
params: Dict[str, Any], params: Dict[str, Any],
...@@ -373,12 +378,25 @@ def _agg_cv_result(raw_results): ...@@ -373,12 +378,25 @@ def _agg_cv_result(raw_results):
return [('cv_agg', k, np.mean(v), metric_type[k], np.std(v)) for k, v in cvmap.items()] return [('cv_agg', k, np.mean(v), metric_type[k], np.std(v)) for k, v in cvmap.items()]
def cv(params, train_set, num_boost_round=100, def cv(
folds=None, nfold=5, stratified=True, shuffle=True, params: Dict[str, Any],
metrics=None, feval=None, init_model=None, train_set: Dataset,
feature_name='auto', categorical_feature='auto', num_boost_round: int = 100,
fpreproc=None, seed=0, callbacks=None, eval_train_metric=False, folds: Optional[Union[Iterable[Tuple[np.ndarray, np.ndarray]], _LGBMBaseCrossValidator]] = None,
return_cvbooster=False): nfold: int = 5,
stratified: bool = True,
shuffle: bool = True,
metrics: Optional[Union[str, List[str]]] = None,
feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None,
init_model: Optional[Union[str, Path, Booster]] = None,
feature_name: Union[str, List[str]] = 'auto',
categorical_feature: Union[str, List[str], List[int]] = 'auto',
fpreproc: Optional[_LGBM_PreprocFunction] = None,
seed: int = 0,
callbacks: Optional[List[Callable]] = None,
eval_train_metric: bool = False,
return_cvbooster: bool = False
) -> Dict[str, Any]:
"""Perform the cross-validation with given parameters. """Perform the cross-validation with given parameters.
Parameters Parameters
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment