"tests/vscode:/vscode.git/clone" did not exist on "cba824474897c8d7e71a3df261faaefa091a40c5"
Unverified Commit f975d3fa authored by IdoKendo's avatar IdoKendo Committed by GitHub
Browse files

[python-package] [dask] fix mypy errors regarding predict_proba (#5728)

parent f136de41
......@@ -1255,13 +1255,29 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]"
)
def predict_proba(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
def predict_proba(
self,
X: _DaskMatrixLike,
raw_score: bool = False,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
pred_leaf: bool = False,
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any
) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
return _predict(
model=self.to_local(),
data=X,
pred_proba=True,
client=_get_dask_client(self.client),
raw_score=raw_score,
start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**kwargs
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment