Unverified Commit 74dfd905 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add more type hints in sklearn.py (#5710)

parent c676a7ea
......@@ -31,6 +31,7 @@ __all__ = [
_DatasetHandle = ctypes.c_void_p
_LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
_LGBM_LabelType = Union[
list,
......@@ -2857,7 +2858,7 @@ class Booster:
self._train_data_name = "training"
self.__set_objective_to_none = False
self.best_iteration = -1
self.best_score = {}
self.best_score: _LGBM_BoosterBestScoreType = {}
params = {} if params is None else deepcopy(params)
if train_set is not None:
# Training task
......@@ -4178,7 +4179,7 @@ class Booster:
result_array_like : numpy array or pandas DataFrame (if pandas is installed)
If ``xgboost_style=True``, the histogram of used splitting values for the specified feature.
"""
def add(root):
def add(root: Dict[str, Any]) -> None:
"""Recursively add thresholds."""
if 'split_index' in root: # non-leaf
if feature_names is not None and isinstance(feature, str):
......
......@@ -13,6 +13,7 @@ __all__ = [
'reset_parameter',
]
_EvalResultDict = Dict[str, Dict[str, List[Any]]]
_EvalResultTuple = Union[
List[_LGBM_BoosterEvalMethodResultType],
List[Tuple[str, str, float, bool, float]]
......@@ -106,7 +107,7 @@ def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCal
class _RecordEvaluationCallback:
"""Internal record evaluation callable class."""
def __init__(self, eval_result: Dict[str, Dict[str, List[Any]]]) -> None:
def __init__(self, eval_result: _EvalResultDict) -> None:
self.order = 20
self.before_iteration = False
......
......@@ -7,9 +7,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _LGBM_EvalFunctionResultType,
_log_warning)
from .callback import record_evaluation
from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _LGBM_BoosterBestScoreType,
_LGBM_EvalFunctionResultType, _log_warning)
from .callback import _EvalResultDict, record_evaluation
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
_LGBMComputeSampleWeight, _LGBMCpuCount, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase,
......@@ -519,18 +519,18 @@ class LGBMModel(_LGBMModelBase):
self.n_jobs = n_jobs
self.importance_type = importance_type
self._Booster: Optional[Booster] = None
self._evals_result = None
self._best_score = None
self._best_iteration = None
self._evals_result: _EvalResultDict = {}
self._best_score: _LGBM_BoosterBestScoreType = {}
self._best_iteration: Optional[int] = None
self._other_params: Dict[str, Any] = {}
self._objective = objective
self.class_weight = class_weight
self._class_weight = None
self._class_map = None
self._class_weight: Optional[Union[Dict, str]] = None
self._class_map: Optional[Dict[int, int]] = None
self._n_features = None
self._n_features_in = None
self._classes = None
self._n_classes = None
self._n_classes: Optional[int] = None
self.set_params(**kwargs)
def _more_tags(self) -> Dict[str, Any]:
......@@ -797,7 +797,7 @@ class LGBMModel(_LGBMModelBase):
else:
callbacks = copy.copy(callbacks) # don't use deepcopy here to allow non-serializable objects
evals_result = {}
evals_result: _EvalResultDict = {}
callbacks.append(record_evaluation(evals_result))
self._Booster = train(
......@@ -904,7 +904,7 @@ class LGBMModel(_LGBMModelBase):
return self._n_features_in
@property
def best_score_(self):
def best_score_(self) -> _LGBM_BoosterBestScoreType:
""":obj:`dict`: The best score of fitted model."""
if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No best_score found. Need to call fit beforehand.')
......@@ -954,7 +954,7 @@ class LGBMModel(_LGBMModelBase):
return self._Booster
@property
def evals_result_(self):
def evals_result_(self) -> _EvalResultDict:
""":obj:`dict`: The evaluation results if validation sets have been specified."""
if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No results found. Need to call fit with eval_set beforehand.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment