Unverified Commit 4971a066 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python] add more type hints on LGBMModel methods (#5239)

parent 9893867c
...@@ -505,7 +505,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -505,7 +505,7 @@ class LGBMModel(_LGBMModelBase):
self._n_classes = None self._n_classes = None
self.set_params(**kwargs) self.set_params(**kwargs)
def _more_tags(self): def _more_tags(self) -> Dict[str, Any]:
return { return {
'allow_nan': True, 'allow_nan': True,
'X_types': ['2darray', 'sparse', '1dlabels'], 'X_types': ['2darray', 'sparse', '1dlabels'],
...@@ -520,7 +520,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -520,7 +520,7 @@ class LGBMModel(_LGBMModelBase):
def __sklearn_is_fitted__(self) -> bool: def __sklearn_is_fitted__(self) -> bool:
return getattr(self, "fitted_", False) return getattr(self, "fitted_", False)
def get_params(self, deep=True): def get_params(self, deep: bool = True) -> Dict[str, Any]:
"""Get parameters for this estimator. """Get parameters for this estimator.
Parameters Parameters
...@@ -538,7 +538,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -538,7 +538,7 @@ class LGBMModel(_LGBMModelBase):
params.update(self._other_params) params.update(self._other_params)
return params return params
def set_params(self, **params): def set_params(self, **params: Any) -> "LGBMModel":
"""Set the parameters of this estimator. """Set the parameters of this estimator.
Parameters Parameters
...@@ -823,14 +823,14 @@ class LGBMModel(_LGBMModelBase): ...@@ -823,14 +823,14 @@ class LGBMModel(_LGBMModelBase):
) )
@property @property
def n_features_(self): def n_features_(self) -> int:
""":obj:`int`: The number of features of fitted model.""" """:obj:`int`: The number of features of fitted model."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No n_features found. Need to call fit beforehand.') raise LGBMNotFittedError('No n_features found. Need to call fit beforehand.')
return self._n_features return self._n_features
@property @property
def n_features_in_(self): def n_features_in_(self) -> int:
""":obj:`int`: The number of features of fitted model.""" """:obj:`int`: The number of features of fitted model."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No n_features_in found. Need to call fit beforehand.') raise LGBMNotFittedError('No n_features_in found. Need to call fit beforehand.')
...@@ -844,14 +844,14 @@ class LGBMModel(_LGBMModelBase): ...@@ -844,14 +844,14 @@ class LGBMModel(_LGBMModelBase):
return self._best_score return self._best_score
@property @property
def best_iteration_(self): def best_iteration_(self) -> int:
""":obj:`int`: The best iteration of fitted model if ``early_stopping()`` callback has been specified.""" """:obj:`int`: The best iteration of fitted model if ``early_stopping()`` callback has been specified."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No best_iteration found. Need to call fit with early_stopping callback beforehand.') raise LGBMNotFittedError('No best_iteration found. Need to call fit with early_stopping callback beforehand.')
return self._best_iteration return self._best_iteration
@property @property
def objective_(self): def objective_(self) -> Union[str, _LGBM_ScikitCustomObjectiveFunction]:
""":obj:`str` or :obj:`callable`: The concrete objective used while fitting this model.""" """:obj:`str` or :obj:`callable`: The concrete objective used while fitting this model."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No objective found. Need to call fit beforehand.') raise LGBMNotFittedError('No objective found. Need to call fit beforehand.')
...@@ -1088,7 +1088,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1088,7 +1088,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
return self._classes return self._classes
@property @property
def n_classes_(self): def n_classes_(self) -> int:
""":obj:`int`: The number of classes.""" """:obj:`int`: The number of classes."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No classes found. Need to call fit beforehand.') raise LGBMNotFittedError('No classes found. Need to call fit beforehand.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment