"python-package/vscode:/vscode.git/clone" did not exist on "95519f36d6f55ebd96c9da71dae1133ea99b1c9f"
Unverified Commit 7af85cee authored by IdoKendo's avatar IdoKendo Committed by GitHub
Browse files

[python-package] Fix mypy errors for predict() method (#5678)

parent 3c3f79e7
......@@ -1221,13 +1221,29 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
{_lgbmmodel_doc_custom_eval_note}
"""
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
def predict(
self,
X: _DaskMatrixLike,
raw_score: bool = False,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
pred_leaf: bool = False,
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any
) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
return _predict(
model=self.to_local(),
data=X,
dtype=self.classes_.dtype,
client=_get_dask_client(self.client),
raw_score=raw_score,
start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**kwargs
)
......@@ -1394,12 +1410,28 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
{_lgbmmodel_doc_custom_eval_note}
"""
def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array:
def predict(
self,
X: _DaskMatrixLike,
raw_score: bool = False,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
pred_leaf: bool = False,
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any
) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
return _predict(
model=self.to_local(),
data=X,
client=_get_dask_client(self.client),
raw_score=raw_score,
start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**kwargs
)
......@@ -1552,12 +1584,28 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
{_lgbmmodel_doc_custom_eval_note}
"""
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
def predict(
self,
X: _DaskMatrixLike,
raw_score: bool = False,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
pred_leaf: bool = False,
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any
) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRanker.predict."""
return _predict(
model=self.to_local(),
data=X,
client=_get_dask_client(self.client),
raw_score=raw_score,
start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**kwargs
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment