Unverified Commit b8fc476e authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[dask] [python-package] use keyword args for internal function calls (#3755)



* [dask] use keyword args for internal function calls

* add missing comma

* Update python-package/lightgbm/dask.py
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* revert whitespace changes

* test style
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 477cbf37
# coding: utf-8 # coding: utf-8
"""Distributed training with LightGBM and Dask.distributed. """Distributed training with LightGBM and Dask.distributed.
This module enables you to perform distributed training with LightGBM on Dask.Array and Dask.DataFrame collections. This module enables you to perform distributed training with LightGBM on
It is based on dask-xgboost package. Dask.Array and Dask.DataFrame collections.
It is based on dask-lightgbm, which was based on dask-xgboost.
""" """
import logging import logging
import socket import socket
...@@ -145,10 +147,16 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re ...@@ -145,10 +147,16 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
if is_ranker: if is_ranker:
group = _concat(parts[-1]) group = _concat(parts[-1])
weight = _concat(parts[2]) if len(parts) == 4 else None if len(parts) == 4:
weight = _concat(parts[2])
else:
weight = None
model.fit(data, y=label, sample_weight=weight, group=group, **kwargs) model.fit(data, y=label, sample_weight=weight, group=group, **kwargs)
else: else:
weight = _concat(parts[2]) if len(parts) == 3 else None if len(parts) == 3:
weight = _concat(parts[2])
else:
weight = None
model.fit(data, y=label, sample_weight=weight, **kwargs) model.fit(data, y=label, sample_weight=weight, **kwargs)
finally: finally:
...@@ -160,7 +168,10 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re ...@@ -160,7 +168,10 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
def _split_to_parts(data, is_matrix): def _split_to_parts(data, is_matrix):
parts = data.to_delayed() parts = data.to_delayed()
if isinstance(parts, np.ndarray): if isinstance(parts, np.ndarray):
assert (parts.shape[1] == 1) if is_matrix else (parts.ndim == 1 or parts.shape[1] == 1) if is_matrix:
assert parts.shape[1] == 1
else:
assert parts.ndim == 1 or parts.shape[1] == 1
parts = parts.flatten().tolist() parts = parts.flatten().tolist()
return parts return parts
...@@ -189,10 +200,18 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group ...@@ -189,10 +200,18 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
params = deepcopy(params) params = deepcopy(params)
# Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality # Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
data_parts = _split_to_parts(data, is_matrix=True) data_parts = _split_to_parts(data=data, is_matrix=True)
label_parts = _split_to_parts(label, is_matrix=False) label_parts = _split_to_parts(data=label, is_matrix=False)
weight_parts = _split_to_parts(sample_weight, is_matrix=False) if sample_weight is not None else None
group_parts = _split_to_parts(group, is_matrix=False) if group is not None else None if sample_weight is not None:
weight_parts = _split_to_parts(data=sample_weight, is_matrix=False)
else:
weight_parts = None
if group is not None:
group_parts = _split_to_parts(data=group, is_matrix=False)
else:
group_parts = None
# choose between four options of (sample_weight, group) being (un)specified # choose between four options of (sample_weight, group) being (un)specified
if weight_parts is None and group_parts is None: if weight_parts is None and group_parts is None:
...@@ -265,15 +284,19 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group ...@@ -265,15 +284,19 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
params.pop(num_thread_alias, None) params.pop(num_thread_alias, None)
# Tell each worker to train on the parts that it has locally # Tell each worker to train on the parts that it has locally
futures_classifiers = [client.submit(_train_part, futures_classifiers = [
model_factory=model_factory, client.submit(
params={**params, 'num_threads': worker_ncores[worker]}, _train_part,
list_of_parts=list_of_parts, model_factory=model_factory,
worker_address_to_port=worker_address_to_port, params={**params, 'num_threads': worker_ncores[worker]},
time_out=params.get('time_out', 120), list_of_parts=list_of_parts,
return_model=(worker == master_worker), worker_address_to_port=worker_address_to_port,
**kwargs) time_out=params.get('time_out', 120),
for worker, list_of_parts in worker_map.items()] return_model=(worker == master_worker),
**kwargs
)
for worker, list_of_parts in worker_map.items()
]
results = client.gather(futures_classifiers) results = client.gather(futures_classifiers)
results = [v for v in results if v] results = [v for v in results if v]
...@@ -368,8 +391,17 @@ class _LGBMModel: ...@@ -368,8 +391,17 @@ class _LGBMModel:
client = default_client() client = default_client()
params = self.get_params(True) params = self.get_params(True)
model = _train(client, data=X, label=y, params=params, model_factory=model_factory,
sample_weight=sample_weight, group=group, **kwargs) model = _train(
client=client,
data=X,
label=y,
params=params,
model_factory=model_factory,
sample_weight=sample_weight,
group=group,
**kwargs
)
self.set_params(**model.get_params()) self.set_params(**model.get_params())
self._copy_extra_params(model, self) self._copy_extra_params(model, self)
...@@ -395,17 +427,37 @@ class DaskLGBMClassifier(_LGBMModel, LGBMClassifier): ...@@ -395,17 +427,37 @@ class DaskLGBMClassifier(_LGBMModel, LGBMClassifier):
def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit.""" """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
return self._fit(LGBMClassifier, X=X, y=y, sample_weight=sample_weight, client=client, **kwargs) return self._fit(
model_factory=LGBMClassifier,
X=X,
y=y,
sample_weight=sample_weight,
client=client,
**kwargs
)
fit.__doc__ = LGBMClassifier.fit.__doc__ fit.__doc__ = LGBMClassifier.fit.__doc__
def predict(self, X, **kwargs): def predict(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" """Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
return _predict(self.to_local(), X, dtype=self.classes_.dtype, **kwargs) return _predict(
model=self.to_local(),
data=X,
dtype=self.classes_.dtype,
**kwargs
)
predict.__doc__ = LGBMClassifier.predict.__doc__ predict.__doc__ = LGBMClassifier.predict.__doc__
def predict_proba(self, X, **kwargs): def predict_proba(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba.""" """Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
return _predict(self.to_local(), X, pred_proba=True, **kwargs) return _predict(
model=self.to_local(),
data=X,
pred_proba=True,
**kwargs
)
predict_proba.__doc__ = LGBMClassifier.predict_proba.__doc__ predict_proba.__doc__ = LGBMClassifier.predict_proba.__doc__
def to_local(self): def to_local(self):
...@@ -423,12 +475,25 @@ class DaskLGBMRegressor(_LGBMModel, LGBMRegressor): ...@@ -423,12 +475,25 @@ class DaskLGBMRegressor(_LGBMModel, LGBMRegressor):
def fit(self, X, y=None, sample_weight=None, client=None, **kwargs): def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit.""" """Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
return self._fit(LGBMRegressor, X=X, y=y, sample_weight=sample_weight, client=client, **kwargs) return self._fit(
model_factory=LGBMRegressor,
X=X,
y=y,
sample_weight=sample_weight,
client=client,
**kwargs
)
fit.__doc__ = LGBMRegressor.fit.__doc__ fit.__doc__ = LGBMRegressor.fit.__doc__
def predict(self, X, **kwargs): def predict(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" """Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
return _predict(self.to_local(), X, **kwargs) return _predict(
model=self.to_local(),
data=X,
**kwargs
)
predict.__doc__ = LGBMRegressor.predict.__doc__ predict.__doc__ = LGBMRegressor.predict.__doc__
def to_local(self): def to_local(self):
...@@ -449,12 +514,22 @@ class DaskLGBMRanker(_LGBMModel, LGBMRanker): ...@@ -449,12 +514,22 @@ class DaskLGBMRanker(_LGBMModel, LGBMRanker):
if init_score is not None: if init_score is not None:
raise RuntimeError('init_score is not currently supported in lightgbm.dask') raise RuntimeError('init_score is not currently supported in lightgbm.dask')
return self._fit(LGBMRanker, X=X, y=y, sample_weight=sample_weight, group=group, client=client, **kwargs) return self._fit(
model_factory=LGBMRanker,
X=X,
y=y,
sample_weight=sample_weight,
group=group,
client=client,
**kwargs
)
fit.__doc__ = LGBMRanker.fit.__doc__ fit.__doc__ = LGBMRanker.fit.__doc__
def predict(self, X, **kwargs): def predict(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRanker.predict.""" """Docstring is inherited from the lightgbm.LGBMRanker.predict."""
return _predict(self.to_local(), X, **kwargs) return _predict(self.to_local(), X, **kwargs)
predict.__doc__ = LGBMRanker.predict.__doc__ predict.__doc__ = LGBMRanker.predict.__doc__
def to_local(self): def to_local(self):
......
...@@ -206,7 +206,11 @@ def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size ...@@ -206,7 +206,11 @@ def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size
@pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers) @pytest.mark.parametrize('centers', data_centers)
def test_classifier(output, centers, client, listen_port): def test_classifier(output, centers, client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers) X, y, w, dX, dy, dw = _create_data(
objective='classification',
output=output,
centers=centers
)
dask_classifier = dlgbm.DaskLGBMClassifier( dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5, time_out=5,
...@@ -238,7 +242,11 @@ def test_classifier(output, centers, client, listen_port): ...@@ -238,7 +242,11 @@ def test_classifier(output, centers, client, listen_port):
@pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers) @pytest.mark.parametrize('centers', data_centers)
def test_classifier_pred_contrib(output, centers, client, listen_port): def test_classifier_pred_contrib(output, centers, client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers) X, y, w, dX, dy, dw = _create_data(
objective='classification',
output=output,
centers=centers
)
dask_classifier = dlgbm.DaskLGBMClassifier( dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5, time_out=5,
...@@ -309,7 +317,10 @@ def test_training_does_not_fail_on_port_conflicts(client): ...@@ -309,7 +317,10 @@ def test_training_does_not_fail_on_port_conflicts(client):
def test_classifier_local_predict(client, listen_port): def test_classifier_local_predict(client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output='array') X, y, w, dX, dy, dw = _create_data(
objective='classification',
output='array'
)
dask_classifier = dlgbm.DaskLGBMClassifier( dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5, time_out=5,
...@@ -333,7 +344,10 @@ def test_classifier_local_predict(client, listen_port): ...@@ -333,7 +344,10 @@ def test_classifier_local_predict(client, listen_port):
@pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('output', data_output)
def test_regressor(output, client, listen_port): def test_regressor(output, client, listen_port):
X, y, w, dX, dy, dw = _create_data('regression', output=output) X, y, w, dX, dy, dw = _create_data(
objective='regression',
output=output
)
dask_regressor = dlgbm.DaskLGBMRegressor( dask_regressor = dlgbm.DaskLGBMRegressor(
time_out=5, time_out=5,
...@@ -366,7 +380,10 @@ def test_regressor(output, client, listen_port): ...@@ -366,7 +380,10 @@ def test_regressor(output, client, listen_port):
@pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('output', data_output)
def test_regressor_pred_contrib(output, client, listen_port): def test_regressor_pred_contrib(output, client, listen_port):
X, y, w, dX, dy, dw = _create_data('regression', output=output) X, y, w, dX, dy, dw = _create_data(
objective='regression',
output=output
)
dask_regressor = dlgbm.DaskLGBMRegressor( dask_regressor = dlgbm.DaskLGBMRegressor(
time_out=5, time_out=5,
...@@ -398,7 +415,10 @@ def test_regressor_pred_contrib(output, client, listen_port): ...@@ -398,7 +415,10 @@ def test_regressor_pred_contrib(output, client, listen_port):
@pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('alpha', [.1, .5, .9]) @pytest.mark.parametrize('alpha', [.1, .5, .9])
def test_regressor_quantile(output, client, listen_port, alpha): def test_regressor_quantile(output, client, listen_port, alpha):
X, y, w, dX, dy, dw = _create_data('regression', output=output) X, y, w, dX, dy, dw = _create_data(
objective='regression',
output=output
)
dask_regressor = dlgbm.DaskLGBMRegressor( dask_regressor = dlgbm.DaskLGBMRegressor(
local_listen_port=listen_port, local_listen_port=listen_port,
...@@ -459,17 +479,32 @@ def test_regressor_local_predict(client, listen_port): ...@@ -459,17 +479,32 @@ def test_regressor_local_predict(client, listen_port):
@pytest.mark.parametrize('group', [None, group_sizes]) @pytest.mark.parametrize('group', [None, group_sizes])
def test_ranker(output, client, listen_port, group): def test_ranker(output, client, listen_port, group):
X, y, w, g, dX, dy, dw, dg = _create_ranking_data(output=output, group=group) X, y, w, g, dX, dy, dw, dg = _create_ranking_data(
output=output,
group=group
)
# use many trees + leaves to overfit, help ensure that dask data-parallel strategy matches that of # use many trees + leaves to overfit, help ensure that dask data-parallel strategy matches that of
# serial learner. See https://github.com/microsoft/LightGBM/issues/3292#issuecomment-671288210. # serial learner. See https://github.com/microsoft/LightGBM/issues/3292#issuecomment-671288210.
dask_ranker = dlgbm.DaskLGBMRanker(time_out=5, local_listen_port=listen_port, tree_learner_type='data_parallel', dask_ranker = dlgbm.DaskLGBMRanker(
n_estimators=50, num_leaves=20, seed=42, min_child_samples=1) time_out=5,
local_listen_port=listen_port,
tree_learner_type='data_parallel',
n_estimators=50,
num_leaves=20,
seed=42,
min_child_samples=1
)
dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg, client=client) dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg, client=client)
rnkvec_dask = dask_ranker.predict(dX) rnkvec_dask = dask_ranker.predict(dX)
rnkvec_dask = rnkvec_dask.compute() rnkvec_dask = rnkvec_dask.compute()
local_ranker = lightgbm.LGBMRanker(n_estimators=50, num_leaves=20, seed=42, min_child_samples=1) local_ranker = lightgbm.LGBMRanker(
n_estimators=50,
num_leaves=20,
seed=42,
min_child_samples=1
)
local_ranker.fit(X, y, sample_weight=w, group=g) local_ranker.fit(X, y, sample_weight=w, group=g)
rnkvec_local = local_ranker.predict(X) rnkvec_local = local_ranker.predict(X)
...@@ -486,10 +521,20 @@ def test_ranker(output, client, listen_port, group): ...@@ -486,10 +521,20 @@ def test_ranker(output, client, listen_port, group):
@pytest.mark.parametrize('group', [None, group_sizes]) @pytest.mark.parametrize('group', [None, group_sizes])
def test_ranker_local_predict(output, client, listen_port, group): def test_ranker_local_predict(output, client, listen_port, group):
X, y, w, g, dX, dy, dw, dg = _create_ranking_data(output=output, group=group) X, y, w, g, dX, dy, dw, dg = _create_ranking_data(
output=output,
group=group
)
dask_ranker = dlgbm.DaskLGBMRanker(time_out=5, local_listen_port=listen_port, tree_learner='data', dask_ranker = dlgbm.DaskLGBMRanker(
n_estimators=10, num_leaves=10, seed=42, min_child_samples=1) time_out=5,
local_listen_port=listen_port,
tree_learner='data',
n_estimators=10,
num_leaves=10,
seed=42,
min_child_samples=1
)
dask_ranker = dask_ranker.fit(dX, dy, group=dg, client=client) dask_ranker = dask_ranker.fit(dX, dy, group=dg, client=client)
rnkvec_dask = dask_ranker.predict(dX) rnkvec_dask = dask_ranker.predict(dX)
rnkvec_dask = rnkvec_dask.compute() rnkvec_dask = rnkvec_dask.compute()
...@@ -532,5 +577,11 @@ def test_errors(c, s, a, b): ...@@ -532,5 +577,11 @@ def test_errors(c, s, a, b):
df = dd.demo.make_timeseries() df = dd.demo.make_timeseries()
df = df.map_partitions(f, meta=df._meta) df = df.map_partitions(f, meta=df._meta)
with pytest.raises(Exception) as info: with pytest.raises(Exception) as info:
yield dlgbm._train(c, df, df.x, params={}, model_factory=lightgbm.LGBMClassifier) yield dlgbm._train(
client=c,
data=df,
label=df.x,
params={},
model_factory=lightgbm.LGBMClassifier
)
assert 'foo' in str(info.value) assert 'foo' in str(info.value)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment