Unverified Commit 1e92ec9e authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add type hints on predict() methods (#5334)

* [python-package] add type hints on predict() methods

* formatting
parent df14e607
......@@ -756,8 +756,17 @@ class _InnerPredictor:
this.pop('handle', None)
return this
def predict(self, data, start_iteration=0, num_iteration=-1,
raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False, validate_features=False):
def predict(
self,
data,
start_iteration: int = 0,
num_iteration: int = -1,
raw_score: bool = False,
pred_leaf: bool = False,
pred_contrib: bool = False,
data_has_header: bool = False,
validate_features: bool = False
):
"""Predict logic.
Parameters
......@@ -3513,10 +3522,18 @@ class Booster:
default=json_default_with_numpy))
return ret
def predict(self, data, start_iteration=0, num_iteration=None,
raw_score=False, pred_leaf=False, pred_contrib=False,
data_has_header=False, validate_features=False,
**kwargs):
def predict(
self,
data,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
raw_score: bool = False,
pred_leaf: bool = False,
pred_contrib: bool = False,
data_has_header: bool = False,
validate_features: bool = False,
**kwargs: Any
):
"""Make a prediction.
Parameters
......
......@@ -826,8 +826,17 @@ class LGBMModel(_LGBMModelBase):
eval_group_shape="list of array, or None, optional (default=None)"
) + "\n\n" + _lgbmmodel_doc_custom_eval_note
def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
pred_leaf=False, pred_contrib=False, validate_features=False, **kwargs):
def predict(
self,
X,
raw_score: bool = False,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
pred_leaf: bool = False,
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any
):
"""Docstring is set after definition, using a template."""
if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.")
......@@ -1094,9 +1103,17 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
fit.__doc__ = (_base_doc[:_base_doc.find('eval_group :')]
+ _base_doc[_base_doc.find('eval_metric :'):])
def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
pred_leaf=False, pred_contrib=False, validate_features=False,
**kwargs):
def predict(
self,
X,
raw_score: bool = False,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
pred_leaf: bool = False,
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any
):
"""Docstring is inherited from the LGBMModel."""
result = self.predict_proba(X, raw_score, start_iteration, num_iteration,
pred_leaf, pred_contrib, validate_features,
......@@ -1109,8 +1126,17 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
predict.__doc__ = LGBMModel.predict.__doc__
def predict_proba(self, X, raw_score=False, start_iteration=0, num_iteration=None,
pred_leaf=False, pred_contrib=False, validate_features=False, **kwargs):
def predict_proba(
self,
X,
raw_score: bool = False,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
pred_leaf: bool = False,
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any
):
"""Docstring is set after definition, using a template."""
result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, validate_features, **kwargs)
if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment