"include/vscode:/vscode.git/clone" did not exist on "64db95788b0064be25cf240d399224d953bf1e39"
Unverified Commit 2f5baa3d authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add type hints on `cv()` (#5271)



* [python-package] add type hints on cv()

* remove inadvertent changes

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent f3ea1ad7
......@@ -75,9 +75,9 @@ try:
from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
try:
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import GroupKFold, StratifiedKFold
from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
except ImportError:
from sklearn.cross_validation import GroupKFold, StratifiedKFold
from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
from sklearn.utils.validation import NotFittedError
try:
from sklearn.utils.validation import _check_sample_weight
......@@ -90,6 +90,7 @@ try:
return sample_weight
SKLEARN_INSTALLED = True
_LGBMBaseCrossValidator = BaseCrossValidator
_LGBMModelBase = BaseEstimator
_LGBMRegressorBase = RegressorMixin
_LGBMClassifierBase = ClassifierMixin
......
......@@ -4,19 +4,24 @@ import collections
import copy
from operator import attrgetter
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from . import callback
from .basic import Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor, _log_warning
from .compat import SKLEARN_INSTALLED, _LGBMGroupKFold, _LGBMStratifiedKFold
from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold
_LGBM_CustomMetricFunction = Callable[
[np.ndarray, Dataset],
Tuple[str, float, bool]
]
_LGBM_PreprocFunction = Callable[
[Dataset, Dataset, Dict[str, Any]],
Tuple[Dataset, Dataset, Dict[str, Any]]
]
def train(
params: Dict[str, Any],
......@@ -373,12 +378,25 @@ def _agg_cv_result(raw_results):
return [('cv_agg', k, np.mean(v), metric_type[k], np.std(v)) for k, v in cvmap.items()]
def cv(params, train_set, num_boost_round=100,
folds=None, nfold=5, stratified=True, shuffle=True,
metrics=None, feval=None, init_model=None,
feature_name='auto', categorical_feature='auto',
fpreproc=None, seed=0, callbacks=None, eval_train_metric=False,
return_cvbooster=False):
def cv(
params: Dict[str, Any],
train_set: Dataset,
num_boost_round: int = 100,
folds: Optional[Union[Iterable[Tuple[np.ndarray, np.ndarray]], _LGBMBaseCrossValidator]] = None,
nfold: int = 5,
stratified: bool = True,
shuffle: bool = True,
metrics: Optional[Union[str, List[str]]] = None,
feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None,
init_model: Optional[Union[str, Path, Booster]] = None,
feature_name: Union[str, List[str]] = 'auto',
categorical_feature: Union[str, List[str], List[int]] = 'auto',
fpreproc: Optional[_LGBM_PreprocFunction] = None,
seed: int = 0,
callbacks: Optional[List[Callable]] = None,
eval_train_metric: bool = False,
return_cvbooster: bool = False
) -> Dict[str, Any]:
"""Perform the cross-validation with given parameters.
Parameters
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment