"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "a39c848e6456d473d2043dff3f5159945a36b567"
Unverified Commit e6310868 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add more type hints in engine.py (#5301)



* [python-package] add more type hints in engine.py

* Update python-package/lightgbm/engine.py
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 1b43214f
...@@ -290,13 +290,13 @@ class CVBooster: ...@@ -290,13 +290,13 @@ class CVBooster:
self.boosters = [] self.boosters = []
self.best_iteration = -1 self.best_iteration = -1
def _append(self, booster): def _append(self, booster: Booster) -> None:
"""Add a booster to CVBooster.""" """Add a booster to CVBooster."""
self.boosters.append(booster) self.boosters.append(booster)
def __getattr__(self, name): def __getattr__(self, name: str) -> Callable[[Any, Any], List[Any]]:
"""Redirect methods call of CVBooster.""" """Redirect methods call of CVBooster."""
def handler_function(*args, **kwargs): def handler_function(*args: Any, **kwargs: Any) -> List[Any]:
"""Call methods with each booster, and concatenate their results.""" """Call methods with each booster, and concatenate their results."""
ret = [] ret = []
for booster in self.boosters: for booster in self.boosters:
...@@ -305,8 +305,17 @@ class CVBooster: ...@@ -305,8 +305,17 @@ class CVBooster:
return handler_function return handler_function
def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratified=True, def _make_n_folds(
shuffle=True, eval_train_metric=False): full_data: Dataset,
folds: Optional[Union[Iterable[Tuple[np.ndarray, np.ndarray]], _LGBMBaseCrossValidator]],
nfold: int,
params: Dict[str, Any],
seed: int,
fpreproc: Optional[_LGBM_PreprocFunction] = None,
stratified: bool = True,
shuffle: bool = True,
eval_train_metric: bool = False
) -> CVBooster:
"""Make a n-fold list of Booster from random indices.""" """Make a n-fold list of Booster from random indices."""
full_data = full_data.construct() full_data = full_data.construct()
num_data = full_data.num_data() num_data = full_data.num_data()
...@@ -365,7 +374,9 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi ...@@ -365,7 +374,9 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
return ret return ret
def _agg_cv_result(raw_results): def _agg_cv_result(
raw_results: List[List[Tuple[str, str, float, bool]]]
) -> List[Tuple[str, str, float, bool, float]]:
"""Aggregate cross-validation results.""" """Aggregate cross-validation results."""
cvmap = collections.OrderedDict() cvmap = collections.OrderedDict()
metric_type = {} metric_type = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment