Unverified Commit c3ac77b5 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[dask] remove 'client' kwarg from fit() and predict() (fixes #3808) (#3883)



* starting on Dask client

* more docs stuff

* fix pickling

* just copy docstrings

* fit docs

* switch test order

* linting

* use client kwarg

* remove inner set_params()

* add type hints

* fix type hints

* remove commented code

* reorder

* fix tests, add client_ property

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* fix tests

* linting

* simplify
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 56fc036d
...@@ -95,7 +95,7 @@ if [[ $TASK == "swig" ]]; then ...@@ -95,7 +95,7 @@ if [[ $TASK == "swig" ]]; then
exit 0 exit 0
fi fi
conda install -q -y -n $CONDA_ENV dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy conda install -q -y -n $CONDA_ENV cloudpickle dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy
# graphviz must come from conda-forge to avoid this on some linux distros: # graphviz must come from conda-forge to avoid this on some linux distros:
# https://github.com/conda-forge/graphviz-feedstock/issues/18 # https://github.com/conda-forge/graphviz-feedstock/issues/18
......
...@@ -9,7 +9,7 @@ It is based on dask-lightgbm, which was based on dask-xgboost. ...@@ -9,7 +9,7 @@ It is based on dask-lightgbm, which was based on dask-xgboost.
import socket import socket
from collections import defaultdict from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Type, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np import numpy as np
...@@ -17,7 +17,7 @@ import scipy.sparse as ss ...@@ -17,7 +17,7 @@ import scipy.sparse as ss
from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat, from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat,
SKLEARN_INSTALLED, SKLEARN_INSTALLED, LGBMNotFittedError,
DASK_INSTALLED, dask_DataFrame, dask_Array, dask_Series, delayed, Client, default_client, get_worker, wait) DASK_INSTALLED, dask_DataFrame, dask_Array, dask_Series, delayed, Client, default_client, get_worker, wait)
from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker
...@@ -27,6 +27,25 @@ _DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix] ...@@ -27,6 +27,25 @@ _DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]] _PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]
def _get_dask_client(client: Optional[Client]) -> Client:
"""Choose a Dask client to use.
Parameters
----------
client : dask.distributed.Client or None
Dask client.
Returns
-------
client : dask.distributed.Client
A Dask client.
"""
if client is None:
return default_client()
else:
return client
def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int: def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
"""Find an open port. """Find an open port.
...@@ -434,6 +453,29 @@ def _predict( ...@@ -434,6 +453,29 @@ def _predict(
class _DaskLGBMModel: class _DaskLGBMModel:
@property
def client_(self) -> Client:
"""Dask client.
This property can be passed in the constructor or updated
with ``model.set_params(client=client)``.
"""
if not getattr(self, "fitted_", False):
raise LGBMNotFittedError('Cannot access property client_ before calling fit().')
return _get_dask_client(client=self.client)
def _lgb_getstate(self) -> Dict[Any, Any]:
"""Remove un-picklable attributes before serialization."""
client = self.__dict__.pop("client", None)
self.__dict__.pop("_client", None)
self._other_params.pop("client", None)
out = deepcopy(self.__dict__)
out.update({"_client": None, "client": None})
self._client = client
self.client = client
return out
def _fit( def _fit(
self, self,
model_factory: Type[LGBMModel], model_factory: Type[LGBMModel],
...@@ -441,18 +483,16 @@ class _DaskLGBMModel: ...@@ -441,18 +483,16 @@ class _DaskLGBMModel:
y: _DaskCollection, y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
group: Optional[_DaskCollection] = None, group: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any **kwargs: Any
) -> "_DaskLGBMModel": ) -> "_DaskLGBMModel":
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask') raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
if client is None:
client = default_client()
params = self.get_params(True) params = self.get_params(True)
params.pop("client", None)
model = _train( model = _train(
client=client, client=_get_dask_client(self.client),
data=X, data=X,
label=y, label=y,
params=params, params=params,
...@@ -468,8 +508,11 @@ class _DaskLGBMModel: ...@@ -468,8 +508,11 @@ class _DaskLGBMModel:
return self return self
def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel: def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
model = model_factory(**self.get_params()) params = self.get_params()
params.pop("client", None)
model = model_factory(**params)
self._copy_extra_params(self, model) self._copy_extra_params(self, model)
model._other_params.pop("client", None)
return model return model
@staticmethod @staticmethod
...@@ -478,18 +521,82 @@ class _DaskLGBMModel: ...@@ -478,18 +521,82 @@ class _DaskLGBMModel:
attributes = source.__dict__ attributes = source.__dict__
extra_param_names = set(attributes.keys()).difference(params.keys()) extra_param_names = set(attributes.keys()).difference(params.keys())
for name in extra_param_names: for name in extra_param_names:
if name != "_client":
setattr(dest, name, attributes[name]) setattr(dest, name, attributes[name])
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMClassifier.""" """Distributed version of lightgbm.LGBMClassifier."""
def __init__(
self,
boosting_type: str = 'gbdt',
num_leaves: int = 31,
max_depth: int = -1,
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
min_child_samples: int = 20,
subsample: float = 1.,
subsample_freq: int = 0,
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: int = -1,
silent: bool = True,
importance_type: str = 'split',
client: Optional[Client] = None,
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
max_depth=max_depth,
learning_rate=learning_rate,
n_estimators=n_estimators,
subsample_for_bin=subsample_for_bin,
objective=objective,
class_weight=class_weight,
min_split_gain=min_split_gain,
min_child_weight=min_child_weight,
min_child_samples=min_child_samples,
subsample=subsample,
subsample_freq=subsample_freq,
colsample_bytree=colsample_bytree,
reg_alpha=reg_alpha,
reg_lambda=reg_lambda,
random_state=random_state,
n_jobs=n_jobs,
silent=silent,
importance_type=importance_type,
**kwargs
)
_base_doc = LGBMClassifier.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
__init__.__doc__ = (
_before_kwargs
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
+ ' ' * 8 + _kwargs + _after_kwargs
)
def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate()
def fit( def fit(
self, self,
X: _DaskMatrixLike, X: _DaskMatrixLike,
y: _DaskCollection, y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any **kwargs: Any
) -> "DaskLGBMClassifier": ) -> "DaskLGBMClassifier":
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit.""" """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
...@@ -498,16 +605,10 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -498,16 +605,10 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
X=X, X=X,
y=y, y=y,
sample_weight=sample_weight, sample_weight=sample_weight,
client=client,
**kwargs **kwargs
) )
_base_doc = LGBMClassifier.fit.__doc__ fit.__doc__ = LGBMClassifier.fit.__doc__
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
fit.__doc__ = (_before_init_score
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _init_score + _after_init_score)
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" """Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
...@@ -545,6 +646,70 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -545,6 +646,70 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMRegressor.""" """Distributed version of lightgbm.LGBMRegressor."""
def __init__(
self,
boosting_type: str = 'gbdt',
num_leaves: int = 31,
max_depth: int = -1,
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
min_child_samples: int = 20,
subsample: float = 1.,
subsample_freq: int = 0,
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: int = -1,
silent: bool = True,
importance_type: str = 'split',
client: Optional[Client] = None,
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
max_depth=max_depth,
learning_rate=learning_rate,
n_estimators=n_estimators,
subsample_for_bin=subsample_for_bin,
objective=objective,
class_weight=class_weight,
min_split_gain=min_split_gain,
min_child_weight=min_child_weight,
min_child_samples=min_child_samples,
subsample=subsample,
subsample_freq=subsample_freq,
colsample_bytree=colsample_bytree,
reg_alpha=reg_alpha,
reg_lambda=reg_lambda,
random_state=random_state,
n_jobs=n_jobs,
silent=silent,
importance_type=importance_type,
**kwargs
)
_base_doc = LGBMRegressor.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
__init__.__doc__ = (
_before_kwargs
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
+ ' ' * 8 + _kwargs + _after_kwargs
)
def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate()
def fit( def fit(
self, self,
X: _DaskMatrixLike, X: _DaskMatrixLike,
...@@ -559,16 +724,10 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -559,16 +724,10 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
X=X, X=X,
y=y, y=y,
sample_weight=sample_weight, sample_weight=sample_weight,
client=client,
**kwargs **kwargs
) )
_base_doc = LGBMRegressor.fit.__doc__ fit.__doc__ = LGBMRegressor.fit.__doc__
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
fit.__doc__ = (_before_init_score
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _init_score + _after_init_score)
def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array: def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" """Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
...@@ -594,6 +753,70 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -594,6 +753,70 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMRanker.""" """Distributed version of lightgbm.LGBMRanker."""
def __init__(
self,
boosting_type: str = 'gbdt',
num_leaves: int = 31,
max_depth: int = -1,
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
min_child_samples: int = 20,
subsample: float = 1.,
subsample_freq: int = 0,
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: int = -1,
silent: bool = True,
importance_type: str = 'split',
client: Optional[Client] = None,
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
max_depth=max_depth,
learning_rate=learning_rate,
n_estimators=n_estimators,
subsample_for_bin=subsample_for_bin,
objective=objective,
class_weight=class_weight,
min_split_gain=min_split_gain,
min_child_weight=min_child_weight,
min_child_samples=min_child_samples,
subsample=subsample,
subsample_freq=subsample_freq,
colsample_bytree=colsample_bytree,
reg_alpha=reg_alpha,
reg_lambda=reg_lambda,
random_state=random_state,
n_jobs=n_jobs,
silent=silent,
importance_type=importance_type,
**kwargs
)
_base_doc = LGBMRanker.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
__init__.__doc__ = (
_before_kwargs
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
+ ' ' * 8 + _kwargs + _after_kwargs
)
def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate()
def fit( def fit(
self, self,
X: _DaskMatrixLike, X: _DaskMatrixLike,
...@@ -601,7 +824,6 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -601,7 +824,6 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
sample_weight: Optional[_DaskCollection] = None, sample_weight: Optional[_DaskCollection] = None,
init_score: Optional[_DaskCollection] = None, init_score: Optional[_DaskCollection] = None,
group: Optional[_DaskCollection] = None, group: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any **kwargs: Any
) -> "DaskLGBMRanker": ) -> "DaskLGBMRanker":
"""Docstring is inherited from the lightgbm.LGBMRanker.fit.""" """Docstring is inherited from the lightgbm.LGBMRanker.fit."""
...@@ -614,16 +836,10 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -614,16 +836,10 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
y=y, y=y,
sample_weight=sample_weight, sample_weight=sample_weight,
group=group, group=group,
client=client,
**kwargs **kwargs
) )
_base_doc = LGBMRanker.fit.__doc__ fit.__doc__ = LGBMRanker.fit.__doc__
_before_eval_set, _eval_set, _after_eval_set = _base_doc.partition('eval_set :')
fit.__doc__ = (_before_eval_set
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _eval_set + _after_eval_set)
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRanker.predict.""" """Docstring is inherited from the lightgbm.LGBMRanker.predict."""
......
# coding: utf-8 # coding: utf-8
"""Tests for lightgbm.dask module""" """Tests for lightgbm.dask module"""
import inspect
import joblib
import pickle
import socket import socket
from itertools import groupby from itertools import groupby
from os import getenv from os import getenv
...@@ -13,13 +16,14 @@ if not platform.startswith('linux'): ...@@ -13,13 +16,14 @@ if not platform.startswith('linux'):
if not lgb.compat.DASK_INSTALLED: if not lgb.compat.DASK_INSTALLED:
pytest.skip('Dask is not installed', allow_module_level=True) pytest.skip('Dask is not installed', allow_module_level=True)
import cloudpickle
import dask.array as da import dask.array as da
import dask.dataframe as dd import dask.dataframe as dd
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from scipy.stats import spearmanr from scipy.stats import spearmanr
from dask.array.utils import assert_eq from dask.array.utils import assert_eq
from dask.distributed import wait from dask.distributed import default_client, Client, LocalCluster, wait
from distributed.utils_test import client, cluster_fixture, gen_cluster, loop from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
from scipy.sparse import csr_matrix from scipy.sparse import csr_matrix
from sklearn.datasets import make_blobs, make_regression from sklearn.datasets import make_blobs, make_regression
...@@ -137,6 +141,32 @@ def _accuracy_score(dy_true, dy_pred): ...@@ -137,6 +141,32 @@ def _accuracy_score(dy_true, dy_pred):
return da.average(dy_true == dy_pred).compute() return da.average(dy_true == dy_pred).compute()
def _pickle(obj, filepath, serializer):
if serializer == 'pickle':
with open(filepath, 'wb') as f:
pickle.dump(obj, f)
elif serializer == 'joblib':
joblib.dump(obj, filepath)
elif serializer == 'cloudpickle':
with open(filepath, 'wb') as f:
cloudpickle.dump(obj, f)
else:
raise ValueError(f'Unrecognized serializer type: {serializer}')
def _unpickle(filepath, serializer):
if serializer == 'pickle':
with open(filepath, 'rb') as f:
return pickle.load(f)
elif serializer == 'joblib':
return joblib.load(filepath)
elif serializer == 'cloudpickle':
with open(filepath, 'rb') as f:
return cloudpickle.load(f)
else:
raise ValueError(f'Unrecognized serializer type: {serializer}')
@pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers) @pytest.mark.parametrize('centers', data_centers)
def test_classifier(output, centers, client, listen_port): def test_classifier(output, centers, client, listen_port):
...@@ -151,11 +181,12 @@ def test_classifier(output, centers, client, listen_port): ...@@ -151,11 +181,12 @@ def test_classifier(output, centers, client, listen_port):
"num_leaves": 10 "num_leaves": 10
} }
dask_classifier = lgb.DaskLGBMClassifier( dask_classifier = lgb.DaskLGBMClassifier(
client=client,
time_out=5, time_out=5,
local_listen_port=listen_port, local_listen_port=listen_port,
**params **params
) )
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client) dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw)
p1 = dask_classifier.predict(dX) p1 = dask_classifier.predict(dX)
p1_proba = dask_classifier.predict_proba(dX).compute() p1_proba = dask_classifier.predict_proba(dX).compute()
p1_local = dask_classifier.to_local().predict(X) p1_local = dask_classifier.to_local().predict(X)
...@@ -193,12 +224,13 @@ def test_classifier_pred_contrib(output, centers, client, listen_port): ...@@ -193,12 +224,13 @@ def test_classifier_pred_contrib(output, centers, client, listen_port):
"num_leaves": 10 "num_leaves": 10
} }
dask_classifier = lgb.DaskLGBMClassifier( dask_classifier = lgb.DaskLGBMClassifier(
client=client,
time_out=5, time_out=5,
local_listen_port=listen_port, local_listen_port=listen_port,
tree_learner='data', tree_learner='data',
**params **params
) )
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client) dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw)
preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True).compute() preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True).compute()
local_classifier = lgb.LGBMClassifier(**params) local_classifier = lgb.LGBMClassifier(**params)
...@@ -241,6 +273,7 @@ def test_training_does_not_fail_on_port_conflicts(client): ...@@ -241,6 +273,7 @@ def test_training_does_not_fail_on_port_conflicts(client):
s.bind(('127.0.0.1', 12400)) s.bind(('127.0.0.1', 12400))
dask_classifier = lgb.DaskLGBMClassifier( dask_classifier = lgb.DaskLGBMClassifier(
client=client,
time_out=5, time_out=5,
local_listen_port=12400, local_listen_port=12400,
n_estimators=5, n_estimators=5,
...@@ -251,7 +284,6 @@ def test_training_does_not_fail_on_port_conflicts(client): ...@@ -251,7 +284,6 @@ def test_training_does_not_fail_on_port_conflicts(client):
X=dX, X=dX,
y=dy, y=dy,
sample_weight=dw, sample_weight=dw,
client=client
) )
assert dask_classifier.booster_ assert dask_classifier.booster_
...@@ -270,12 +302,13 @@ def test_regressor(output, client, listen_port): ...@@ -270,12 +302,13 @@ def test_regressor(output, client, listen_port):
"num_leaves": 10 "num_leaves": 10
} }
dask_regressor = lgb.DaskLGBMRegressor( dask_regressor = lgb.DaskLGBMRegressor(
client=client,
time_out=5, time_out=5,
local_listen_port=listen_port, local_listen_port=listen_port,
tree='data', tree='data',
**params **params
) )
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw) dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw)
p1 = dask_regressor.predict(dX) p1 = dask_regressor.predict(dX)
if output != 'dataframe': if output != 'dataframe':
s1 = _r2_score(dy, p1) s1 = _r2_score(dy, p1)
...@@ -313,12 +346,13 @@ def test_regressor_pred_contrib(output, client, listen_port): ...@@ -313,12 +346,13 @@ def test_regressor_pred_contrib(output, client, listen_port):
"num_leaves": 10 "num_leaves": 10
} }
dask_regressor = lgb.DaskLGBMRegressor( dask_regressor = lgb.DaskLGBMRegressor(
client=client,
time_out=5, time_out=5,
local_listen_port=listen_port, local_listen_port=listen_port,
tree_learner='data', tree_learner='data',
**params **params
) )
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client) dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw)
preds_with_contrib = dask_regressor.predict(dX, pred_contrib=True).compute() preds_with_contrib = dask_regressor.predict(dX, pred_contrib=True).compute()
local_regressor = lgb.LGBMRegressor(**params) local_regressor = lgb.LGBMRegressor(**params)
...@@ -353,11 +387,12 @@ def test_regressor_quantile(output, client, listen_port, alpha): ...@@ -353,11 +387,12 @@ def test_regressor_quantile(output, client, listen_port, alpha):
"num_leaves": 10 "num_leaves": 10
} }
dask_regressor = lgb.DaskLGBMRegressor( dask_regressor = lgb.DaskLGBMRegressor(
client=client,
local_listen_port=listen_port, local_listen_port=listen_port,
tree_learner_type='data_parallel', tree_learner_type='data_parallel',
**params **params
) )
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw) dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw)
p1 = dask_regressor.predict(dX).compute() p1 = dask_regressor.predict(dX).compute()
q1 = np.count_nonzero(y < p1) / y.shape[0] q1 = np.count_nonzero(y < p1) / y.shape[0]
...@@ -400,12 +435,13 @@ def test_ranker(output, client, listen_port, group): ...@@ -400,12 +435,13 @@ def test_ranker(output, client, listen_port, group):
"min_child_samples": 1 "min_child_samples": 1
} }
dask_ranker = lgb.DaskLGBMRanker( dask_ranker = lgb.DaskLGBMRanker(
client=client,
time_out=5, time_out=5,
local_listen_port=listen_port, local_listen_port=listen_port,
tree_learner_type='data_parallel', tree_learner_type='data_parallel',
**params **params
) )
dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg, client=client) dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg)
rnkvec_dask = dask_ranker.predict(dX) rnkvec_dask = dask_ranker.predict(dX)
rnkvec_dask = rnkvec_dask.compute() rnkvec_dask = rnkvec_dask.compute()
rnkvec_dask_local = dask_ranker.to_local().predict(X) rnkvec_dask_local = dask_ranker.to_local().predict(X)
...@@ -424,6 +460,288 @@ def test_ranker(output, client, listen_port, group): ...@@ -424,6 +460,288 @@ def test_ranker(output, client, listen_port, group):
client.close(timeout=CLIENT_CLOSE_TIMEOUT) client.close(timeout=CLIENT_CLOSE_TIMEOUT)
@pytest.mark.parametrize('task', ['classification', 'regression', 'ranking'])
def test_training_works_if_client_not_provided_or_set_after_construction(task, listen_port, client):
if task == 'ranking':
_, _, _, _, dX, dy, _, dg = _create_ranking_data(
output='array',
group=None
)
model_factory = lgb.DaskLGBMRanker
else:
_, _, _, dX, dy, _ = _create_data(
objective=task,
output='array',
)
dg = None
if task == 'classification':
model_factory = lgb.DaskLGBMClassifier
elif task == 'regression':
model_factory = lgb.DaskLGBMRegressor
params = {
"time_out": 5,
"local_listen_port": listen_port,
"n_estimators": 1,
"num_leaves": 2
}
# should be able to use the class without specifying a client
dask_model = model_factory(**params)
assert dask_model._client is None
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_
dask_model.fit(dX, dy, group=dg)
assert dask_model.fitted_
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == client
preds = dask_model.predict(dX)
assert isinstance(preds, da.Array)
assert dask_model.fitted_
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == client
local_model = dask_model.to_local()
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_
# should be able to set client after construction
dask_model = model_factory(**params)
dask_model.set_params(client=client)
assert dask_model._client == client
assert dask_model.client == client
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_
dask_model.fit(dX, dy, group=dg)
assert dask_model.fitted_
assert dask_model._client == client
assert dask_model.client == client
assert dask_model.client_ == client
preds = dask_model.predict(dX)
assert isinstance(preds, da.Array)
assert dask_model.fitted_
assert dask_model._client == client
assert dask_model.client == client
assert dask_model.client_ == client
local_model = dask_model.to_local()
assert getattr(local_model, "_client", None) is None
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_
client.close(timeout=CLIENT_CLOSE_TIMEOUT)
@pytest.mark.parametrize('serializer', ['pickle', 'joblib', 'cloudpickle'])
@pytest.mark.parametrize('task', ['classification', 'regression', 'ranking'])
@pytest.mark.parametrize('set_client', [True, False])
def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, tmp_path):
with LocalCluster(n_workers=2, threads_per_worker=1) as cluster1:
with Client(cluster1) as client1:
# data on cluster1
if task == 'ranking':
X_1, _, _, _, dX_1, dy_1, _, dg_1 = _create_ranking_data(
output='array',
group=None
)
else:
X_1, _, _, dX_1, dy_1, _ = _create_data(
objective=task,
output='array',
)
dg_1 = None
with LocalCluster(n_workers=2, threads_per_worker=1) as cluster2:
with Client(cluster2) as client2:
# create identical data on cluster2
if task == 'ranking':
X_2, _, _, _, dX_2, dy_2, _, dg_2 = _create_ranking_data(
output='array',
group=None
)
else:
X_2, _, _, dX_2, dy_2, _ = _create_data(
objective=task,
output='array',
)
dg_2 = None
if task == 'ranking':
model_factory = lgb.DaskLGBMRanker
elif task == 'classification':
model_factory = lgb.DaskLGBMClassifier
elif task == 'regression':
model_factory = lgb.DaskLGBMRegressor
params = {
"time_out": 5,
"local_listen_port": listen_port,
"n_estimators": 1,
"num_leaves": 2
}
# at this point, the result of default_client() is client2 since it was the most recently
# created. So setting client to client1 here to test that you can select a non-default client
assert default_client() == client2
if set_client:
params.update({"client": client1})
# unfitted model should survive pickling round trip, and pickling
# shouldn't have side effects on the model object
dask_model = model_factory(**params)
local_model = dask_model.to_local()
if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
else:
assert dask_model._client is None
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_
assert "client" not in local_model.get_params()
assert getattr(local_model, "client", None) is None
tmp_file = str(tmp_path / "model-1.pkl")
_pickle(
obj=dask_model,
filepath=tmp_file,
serializer=serializer
)
model_from_disk = _unpickle(
filepath=tmp_file,
serializer=serializer
)
local_tmp_file = str(tmp_path / "local-model-1.pkl")
_pickle(
obj=local_model,
filepath=local_tmp_file,
serializer=serializer
)
local_model_from_disk = _unpickle(
filepath=local_tmp_file,
serializer=serializer
)
assert model_from_disk._client is None
assert model_from_disk.client is None
if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
else:
assert dask_model._client is None
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_
# client will always be None after unpickling
if set_client:
from_disk_params = model_from_disk.get_params()
from_disk_params.pop("client", None)
dask_params = dask_model.get_params()
dask_params.pop("client", None)
assert from_disk_params == dask_params
else:
assert model_from_disk.get_params() == dask_model.get_params()
assert local_model_from_disk.get_params() == local_model.get_params()
# fitted model should survive pickling round trip, and pickling
# shouldn't have side effects on the model object
if set_client:
dask_model.fit(dX_1, dy_1, group=dg_1)
else:
dask_model.fit(dX_2, dy_2, group=dg_2)
local_model = dask_model.to_local()
assert "client" not in local_model.get_params()
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_
tmp_file2 = str(tmp_path / "model-2.pkl")
_pickle(
obj=dask_model,
filepath=tmp_file2,
serializer=serializer
)
fitted_model_from_disk = _unpickle(
filepath=tmp_file2,
serializer=serializer
)
local_tmp_file2 = str(tmp_path / "local-model-2.pkl")
_pickle(
obj=local_model,
filepath=local_tmp_file2,
serializer=serializer
)
local_fitted_model_from_disk = _unpickle(
filepath=local_tmp_file2,
serializer=serializer
)
if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
assert dask_model.client_ == client1
else:
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == default_client()
assert dask_model.client_ == client2
assert isinstance(fitted_model_from_disk, model_factory)
assert fitted_model_from_disk._client is None
assert fitted_model_from_disk.client is None
assert fitted_model_from_disk.client_ == default_client()
assert fitted_model_from_disk.client_ == client2
# client will always be None after unpickling
if set_client:
from_disk_params = fitted_model_from_disk.get_params()
from_disk_params.pop("client", None)
dask_params = dask_model.get_params()
dask_params.pop("client", None)
assert from_disk_params == dask_params
else:
assert fitted_model_from_disk.get_params() == dask_model.get_params()
assert local_fitted_model_from_disk.get_params() == local_model.get_params()
if set_client:
preds_orig = dask_model.predict(dX_1).compute()
preds_loaded_model = fitted_model_from_disk.predict(dX_1).compute()
preds_orig_local = local_model.predict(X_1)
preds_loaded_model_local = local_fitted_model_from_disk.predict(X_1)
else:
preds_orig = dask_model.predict(dX_2).compute()
preds_loaded_model = fitted_model_from_disk.predict(dX_2).compute()
preds_orig_local = local_model.predict(X_2)
preds_loaded_model_local = local_fitted_model_from_disk.predict(X_2)
assert_eq(preds_orig, preds_loaded_model)
assert_eq(preds_orig_local, preds_loaded_model_local)
def test_find_open_port_works(): def test_find_open_port_works():
worker_ip = '127.0.0.1' worker_ip = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
...@@ -451,6 +769,7 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client): ...@@ -451,6 +769,7 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client):
X = da.random.random((1e3, 10)) X = da.random.random((1e3, 10))
y = da.random.random((1e3, 1)) y = da.random.random((1e3, 1))
dask_regressor = lgb.DaskLGBMRegressor( dask_regressor = lgb.DaskLGBMRegressor(
client=client,
time_out=5, time_out=5,
local_listen_port=1234, local_listen_port=1234,
tree_learner='some-nonsense-value', tree_learner='some-nonsense-value',
...@@ -458,7 +777,7 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client): ...@@ -458,7 +777,7 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client):
num_leaves=2 num_leaves=2
) )
with pytest.warns(UserWarning, match='Parameter tree_learner set to some-nonsense-value'): with pytest.warns(UserWarning, match='Parameter tree_learner set to some-nonsense-value'):
dask_regressor = dask_regressor.fit(X, y, client=client) dask_regressor = dask_regressor.fit(X, y)
assert dask_regressor.fitted_ assert dask_regressor.fitted_
...@@ -470,6 +789,7 @@ def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client): ...@@ -470,6 +789,7 @@ def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client):
y = da.random.random((1e3, 1)) y = da.random.random((1e3, 1))
for tree_learner in ['feature_parallel', 'voting']: for tree_learner in ['feature_parallel', 'voting']:
dask_regressor = lgb.DaskLGBMRegressor( dask_regressor = lgb.DaskLGBMRegressor(
client=client,
time_out=5, time_out=5,
local_listen_port=1234, local_listen_port=1234,
tree_learner=tree_learner, tree_learner=tree_learner,
...@@ -477,7 +797,7 @@ def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client): ...@@ -477,7 +797,7 @@ def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client):
num_leaves=2 num_leaves=2
) )
with pytest.warns(UserWarning, match='Support for tree_learner %s in lightgbm' % tree_learner): with pytest.warns(UserWarning, match='Support for tree_learner %s in lightgbm' % tree_learner):
dask_regressor = dask_regressor.fit(X, y, client=client) dask_regressor = dask_regressor.fit(X, y)
assert dask_regressor.fitted_ assert dask_regressor.fitted_
assert dask_regressor.get_params()['tree_learner'] == tree_learner assert dask_regressor.get_params()['tree_learner'] == tree_learner
...@@ -501,3 +821,26 @@ def test_errors(c, s, a, b): ...@@ -501,3 +821,26 @@ def test_errors(c, s, a, b):
model_factory=lgb.LGBMClassifier model_factory=lgb.LGBMClassifier
) )
assert 'foo' in str(info.value) assert 'foo' in str(info.value)
@pytest.mark.parametrize(
"classes",
[
(lgb.DaskLGBMClassifier, lgb.LGBMClassifier),
(lgb.DaskLGBMRegressor, lgb.LGBMRegressor),
(lgb.DaskLGBMRanker, lgb.LGBMRanker)
]
)
def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(classes):
dask_spec = inspect.getfullargspec(classes[0])
sklearn_spec = inspect.getfullargspec(classes[1])
assert dask_spec.varargs == sklearn_spec.varargs
assert dask_spec.varkw == sklearn_spec.varkw
assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs
assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults
# "client" should be the only different, and the final argument
assert dask_spec.args[:-1] == sklearn_spec.args
assert dask_spec.defaults[:-1] == sklearn_spec.defaults
assert dask_spec.args[-1] == 'client'
assert dask_spec.defaults[-1] is None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment