Unverified commit f836fe0c, authored by José Morales, committed by GitHub
Browse files

[dask] determine output shape of array in predict (fixes #4285) (#4351)

* call predict on one row of data to determine output shape

* make DaskLGBMRanker predict method equal to the others

* remove extra drop_axis
parent 7f9959fe
...@@ -112,14 +112,7 @@ if [[ $TASK == "swig" ]]; then ...@@ -112,14 +112,7 @@ if [[ $TASK == "swig" ]]; then
exit 0 exit 0
fi fi
# temporary fix for https://github.com/microsoft/LightGBM/issues/4285 conda install -q -y -n $CONDA_ENV cloudpickle dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy
if [[ $PYTHON_VERSION == "3.6" ]]; then
DASK_DEPENDENCIES="dask distributed"
else
DASK_DEPENDENCIES="dask=2021.4.0 distributed=2021.4.0"
fi
conda install -q -y -n $CONDA_ENV cloudpickle ${DASK_DEPENDENCIES} joblib matplotlib numpy pandas psutil pytest scikit-learn scipy
pip install graphviz # python-graphviz from Anaconda is not allowed to be installed with Python 3.9 pip install graphviz # python-graphviz from Anaconda is not allowed to be installed with Python 3.9
if [[ $OS_NAME == "macos" ]] && [[ $COMPILER == "clang" ]]; then if [[ $OS_NAME == "macos" ]] && [[ $COMPILER == "clang" ]]; then
......
...@@ -838,6 +838,7 @@ def _predict_part( ...@@ -838,6 +838,7 @@ def _predict_part(
def _predict( def _predict(
model: LGBMModel, model: LGBMModel,
data: _DaskMatrixLike, data: _DaskMatrixLike,
client: Client,
raw_score: bool = False, raw_score: bool = False,
pred_proba: bool = False, pred_proba: bool = False,
pred_leaf: bool = False, pred_leaf: bool = False,
...@@ -956,16 +957,29 @@ def _predict( ...@@ -956,16 +957,29 @@ def _predict(
return out return out
return data.map_blocks( data_row = client.compute(data[[0]]).result()
predict_fn = partial(
_predict_part, _predict_part,
model=model, model=model,
raw_score=raw_score, raw_score=raw_score,
pred_proba=pred_proba, pred_proba=pred_proba,
pred_leaf=pred_leaf, pred_leaf=pred_leaf,
pred_contrib=pred_contrib, pred_contrib=pred_contrib,
**kwargs,
)
pred_row = predict_fn(data_row)
chunks = (data.chunks[0],)
map_blocks_kwargs = {}
if len(pred_row.shape) > 1:
chunks += (pred_row.shape[1],)
else:
map_blocks_kwargs['drop_axis'] = 1
return data.map_blocks(
predict_fn,
chunks=chunks,
meta=pred_row,
dtype=dtype, dtype=dtype,
drop_axis=1, **map_blocks_kwargs,
**kwargs
) )
else: else:
raise TypeError(f'Data must be either Dask Array or Dask DataFrame. Got {type(data).__name__}.') raise TypeError(f'Data must be either Dask Array or Dask DataFrame. Got {type(data).__name__}.')
...@@ -1201,6 +1215,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1201,6 +1215,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
model=self.to_local(), model=self.to_local(),
data=X, data=X,
dtype=self.classes_.dtype, dtype=self.classes_.dtype,
client=_get_dask_client(self.client),
**kwargs **kwargs
) )
...@@ -1219,6 +1234,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1219,6 +1234,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
model=self.to_local(), model=self.to_local(),
data=X, data=X,
pred_proba=True, pred_proba=True,
client=_get_dask_client(self.client),
**kwargs **kwargs
) )
...@@ -1378,6 +1394,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -1378,6 +1394,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
return _predict( return _predict(
model=self.to_local(), model=self.to_local(),
data=X, data=X,
client=_get_dask_client(self.client),
**kwargs **kwargs
) )
...@@ -1537,7 +1554,12 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -1537,7 +1554,12 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRanker.predict.""" """Docstring is inherited from the lightgbm.LGBMRanker.predict."""
return _predict(self.to_local(), X, **kwargs) return _predict(
model=self.to_local(),
data=X,
client=_get_dask_client(self.client),
**kwargs
)
predict.__doc__ = _lgbmmodel_doc_predict.format( predict.__doc__ = _lgbmmodel_doc_predict.format(
description="Return the predicted value for each sample.", description="Return the predicted value for each sample.",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment