Unverified Commit c3ac77b5 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[dask] remove 'client' kwarg from fit() and predict() (fixes #3808) (#3883)



* starting on Dask client

* more docs stuff

* fix pickling

* just copy docstrings

* fit docs

* switch test order

* linting

* use client kwarg

* remove inner set_params()

* add type hints

* fix type hints

* remove commented code

* reorder

* fix tests, add client_ property

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* fix tests

* linting

* simplify
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 56fc036d
...@@ -95,7 +95,7 @@ if [[ $TASK == "swig" ]]; then ...@@ -95,7 +95,7 @@ if [[ $TASK == "swig" ]]; then
exit 0 exit 0
fi fi
conda install -q -y -n $CONDA_ENV dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy conda install -q -y -n $CONDA_ENV cloudpickle dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy
# graphviz must come from conda-forge to avoid this on some linux distros: # graphviz must come from conda-forge to avoid this on some linux distros:
# https://github.com/conda-forge/graphviz-feedstock/issues/18 # https://github.com/conda-forge/graphviz-feedstock/issues/18
......
...@@ -9,7 +9,7 @@ It is based on dask-lightgbm, which was based on dask-xgboost. ...@@ -9,7 +9,7 @@ It is based on dask-lightgbm, which was based on dask-xgboost.
import socket import socket
from collections import defaultdict from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Type, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np import numpy as np
...@@ -17,7 +17,7 @@ import scipy.sparse as ss ...@@ -17,7 +17,7 @@ import scipy.sparse as ss
from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat, from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat,
SKLEARN_INSTALLED, SKLEARN_INSTALLED, LGBMNotFittedError,
DASK_INSTALLED, dask_DataFrame, dask_Array, dask_Series, delayed, Client, default_client, get_worker, wait) DASK_INSTALLED, dask_DataFrame, dask_Array, dask_Series, delayed, Client, default_client, get_worker, wait)
from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker
...@@ -27,6 +27,25 @@ _DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix] ...@@ -27,6 +27,25 @@ _DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]] _PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]
def _get_dask_client(client: Optional[Client]) -> Client:
"""Choose a Dask client to use.
Parameters
----------
client : dask.distributed.Client or None
Dask client.
Returns
-------
client : dask.distributed.Client
A Dask client.
"""
if client is None:
return default_client()
else:
return client
def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int: def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
"""Find an open port. """Find an open port.
...@@ -434,6 +453,29 @@ def _predict( ...@@ -434,6 +453,29 @@ def _predict(
class _DaskLGBMModel: class _DaskLGBMModel:
@property
def client_(self) -> Client:
"""Dask client.
This property can be passed in the constructor or updated
with ``model.set_params(client=client)``.
"""
if not getattr(self, "fitted_", False):
raise LGBMNotFittedError('Cannot access property client_ before calling fit().')
return _get_dask_client(client=self.client)
def _lgb_getstate(self) -> Dict[Any, Any]:
"""Remove un-picklable attributes before serialization."""
client = self.__dict__.pop("client", None)
self.__dict__.pop("_client", None)
self._other_params.pop("client", None)
out = deepcopy(self.__dict__)
out.update({"_client": None, "client": None})
self._client = client
self.client = client
return out
def _fit( def _fit(
self, self,
model_factory: Type[LGBMModel], model_factory: Type[LGBMModel],
...@@ -441,18 +483,16 @@ class _DaskLGBMModel: ...@@ -441,18 +483,16 @@ class _DaskLGBMModel:
y: _DaskCollection, y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
group: Optional[_DaskCollection] = None, group: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any **kwargs: Any
) -> "_DaskLGBMModel": ) -> "_DaskLGBMModel":
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask') raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
if client is None:
client = default_client()
params = self.get_params(True) params = self.get_params(True)
params.pop("client", None)
model = _train( model = _train(
client=client, client=_get_dask_client(self.client),
data=X, data=X,
label=y, label=y,
params=params, params=params,
...@@ -468,8 +508,11 @@ class _DaskLGBMModel: ...@@ -468,8 +508,11 @@ class _DaskLGBMModel:
return self return self
def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel: def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
model = model_factory(**self.get_params()) params = self.get_params()
params.pop("client", None)
model = model_factory(**params)
self._copy_extra_params(self, model) self._copy_extra_params(self, model)
model._other_params.pop("client", None)
return model return model
@staticmethod @staticmethod
...@@ -478,18 +521,82 @@ class _DaskLGBMModel: ...@@ -478,18 +521,82 @@ class _DaskLGBMModel:
attributes = source.__dict__ attributes = source.__dict__
extra_param_names = set(attributes.keys()).difference(params.keys()) extra_param_names = set(attributes.keys()).difference(params.keys())
for name in extra_param_names: for name in extra_param_names:
if name != "_client":
setattr(dest, name, attributes[name]) setattr(dest, name, attributes[name])
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMClassifier.""" """Distributed version of lightgbm.LGBMClassifier."""
def __init__(
self,
boosting_type: str = 'gbdt',
num_leaves: int = 31,
max_depth: int = -1,
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
min_child_samples: int = 20,
subsample: float = 1.,
subsample_freq: int = 0,
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: int = -1,
silent: bool = True,
importance_type: str = 'split',
client: Optional[Client] = None,
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
max_depth=max_depth,
learning_rate=learning_rate,
n_estimators=n_estimators,
subsample_for_bin=subsample_for_bin,
objective=objective,
class_weight=class_weight,
min_split_gain=min_split_gain,
min_child_weight=min_child_weight,
min_child_samples=min_child_samples,
subsample=subsample,
subsample_freq=subsample_freq,
colsample_bytree=colsample_bytree,
reg_alpha=reg_alpha,
reg_lambda=reg_lambda,
random_state=random_state,
n_jobs=n_jobs,
silent=silent,
importance_type=importance_type,
**kwargs
)
_base_doc = LGBMClassifier.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
__init__.__doc__ = (
_before_kwargs
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
+ ' ' * 8 + _kwargs + _after_kwargs
)
def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate()
def fit( def fit(
self, self,
X: _DaskMatrixLike, X: _DaskMatrixLike,
y: _DaskCollection, y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any **kwargs: Any
) -> "DaskLGBMClassifier": ) -> "DaskLGBMClassifier":
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit.""" """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
...@@ -498,16 +605,10 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -498,16 +605,10 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
X=X, X=X,
y=y, y=y,
sample_weight=sample_weight, sample_weight=sample_weight,
client=client,
**kwargs **kwargs
) )
_base_doc = LGBMClassifier.fit.__doc__ fit.__doc__ = LGBMClassifier.fit.__doc__
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
fit.__doc__ = (_before_init_score
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _init_score + _after_init_score)
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" """Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
...@@ -545,6 +646,70 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -545,6 +646,70 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMRegressor.""" """Distributed version of lightgbm.LGBMRegressor."""
def __init__(
self,
boosting_type: str = 'gbdt',
num_leaves: int = 31,
max_depth: int = -1,
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
min_child_samples: int = 20,
subsample: float = 1.,
subsample_freq: int = 0,
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: int = -1,
silent: bool = True,
importance_type: str = 'split',
client: Optional[Client] = None,
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
max_depth=max_depth,
learning_rate=learning_rate,
n_estimators=n_estimators,
subsample_for_bin=subsample_for_bin,
objective=objective,
class_weight=class_weight,
min_split_gain=min_split_gain,
min_child_weight=min_child_weight,
min_child_samples=min_child_samples,
subsample=subsample,
subsample_freq=subsample_freq,
colsample_bytree=colsample_bytree,
reg_alpha=reg_alpha,
reg_lambda=reg_lambda,
random_state=random_state,
n_jobs=n_jobs,
silent=silent,
importance_type=importance_type,
**kwargs
)
_base_doc = LGBMRegressor.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
__init__.__doc__ = (
_before_kwargs
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
+ ' ' * 8 + _kwargs + _after_kwargs
)
def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate()
def fit( def fit(
self, self,
X: _DaskMatrixLike, X: _DaskMatrixLike,
...@@ -559,16 +724,10 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -559,16 +724,10 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
X=X, X=X,
y=y, y=y,
sample_weight=sample_weight, sample_weight=sample_weight,
client=client,
**kwargs **kwargs
) )
_base_doc = LGBMRegressor.fit.__doc__ fit.__doc__ = LGBMRegressor.fit.__doc__
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
fit.__doc__ = (_before_init_score
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _init_score + _after_init_score)
def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array: def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" """Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
...@@ -594,6 +753,70 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -594,6 +753,70 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMRanker.""" """Distributed version of lightgbm.LGBMRanker."""
def __init__(
self,
boosting_type: str = 'gbdt',
num_leaves: int = 31,
max_depth: int = -1,
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
min_child_samples: int = 20,
subsample: float = 1.,
subsample_freq: int = 0,
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: int = -1,
silent: bool = True,
importance_type: str = 'split',
client: Optional[Client] = None,
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
max_depth=max_depth,
learning_rate=learning_rate,
n_estimators=n_estimators,
subsample_for_bin=subsample_for_bin,
objective=objective,
class_weight=class_weight,
min_split_gain=min_split_gain,
min_child_weight=min_child_weight,
min_child_samples=min_child_samples,
subsample=subsample,
subsample_freq=subsample_freq,
colsample_bytree=colsample_bytree,
reg_alpha=reg_alpha,
reg_lambda=reg_lambda,
random_state=random_state,
n_jobs=n_jobs,
silent=silent,
importance_type=importance_type,
**kwargs
)
_base_doc = LGBMRanker.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
__init__.__doc__ = (
_before_kwargs
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
+ ' ' * 8 + _kwargs + _after_kwargs
)
def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate()
def fit( def fit(
self, self,
X: _DaskMatrixLike, X: _DaskMatrixLike,
...@@ -601,7 +824,6 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -601,7 +824,6 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
init_score: Optional[_DaskCollection] = None, init_score: Optional[_DaskCollection] = None,
group: Optional[_DaskCollection] = None, group: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any **kwargs: Any
) -> "DaskLGBMRanker": ) -> "DaskLGBMRanker":
"""Docstring is inherited from the lightgbm.LGBMRanker.fit.""" """Docstring is inherited from the lightgbm.LGBMRanker.fit."""
...@@ -614,16 +836,10 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -614,16 +836,10 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
y=y, y=y,
sample_weight=sample_weight, sample_weight=sample_weight,
group=group, group=group,
client=client,
**kwargs **kwargs
) )
_base_doc = LGBMRanker.fit.__doc__ fit.__doc__ = LGBMRanker.fit.__doc__
_before_eval_set, _eval_set, _after_eval_set = _base_doc.partition('eval_set :')
fit.__doc__ = (_before_eval_set
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _eval_set + _after_eval_set)
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRanker.predict.""" """Docstring is inherited from the lightgbm.LGBMRanker.predict."""
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment