Unverified commit b8fc476e, authored by James Lamb and committed by GitHub
Browse files

[dask] [python-package] use keyword args for internal function calls (#3755)



* [dask] use keyword args for internal function calls

* add missing comma

* Update python-package/lightgbm/dask.py
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* revert whitespace changes

* test style
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 477cbf37
# coding: utf-8
"""Distributed training with LightGBM and Dask.distributed.
This module enables you to perform distributed training with LightGBM on Dask.Array and Dask.DataFrame collections.
It is based on dask-xgboost package.
This module enables you to perform distributed training with LightGBM on
Dask.Array and Dask.DataFrame collections.
It is based on dask-lightgbm, which was based on dask-xgboost.
"""
import logging
import socket
......@@ -145,10 +147,16 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
if is_ranker:
group = _concat(parts[-1])
weight = _concat(parts[2]) if len(parts) == 4 else None
if len(parts) == 4:
weight = _concat(parts[2])
else:
weight = None
model.fit(data, y=label, sample_weight=weight, group=group, **kwargs)
else:
weight = _concat(parts[2]) if len(parts) == 3 else None
if len(parts) == 3:
weight = _concat(parts[2])
else:
weight = None
model.fit(data, y=label, sample_weight=weight, **kwargs)
finally:
......@@ -160,7 +168,10 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
def _split_to_parts(data, is_matrix):
    """Split a collection into a flat list of its delayed partitions.

    Parameters
    ----------
    data : object exposing ``to_delayed()``
        Collection to split. Callers in this module pass the training data,
        labels, sample weights and group sizes (presumably Dask Array /
        DataFrame / Series objects - confirm against the public ``fit`` API).
    is_matrix : bool
        ``True`` when ``data`` is two-dimensional (the feature matrix),
        ``False`` when it may be one-dimensional (labels, weights, groups).

    Returns
    -------
    parts : list
        One element per partition, in the order produced by ``to_delayed()``.

    Raises
    ------
    AssertionError
        If ``to_delayed()`` returns a NumPy grid of parts with more than one
        column, i.e. the collection is split along its second axis.
    """
    parts = data.to_delayed()
    # ``to_delayed()`` may return a NumPy array of parts instead of a list;
    # only that case needs validating and flattening.
    if isinstance(parts, np.ndarray):
        # The grid of parts must have a single column so that flattening it
        # yields exactly one part per block of rows.
        if is_matrix:
            assert parts.shape[1] == 1
        else:
            assert parts.ndim == 1 or parts.shape[1] == 1
        parts = parts.flatten().tolist()
    return parts
......@@ -189,10 +200,18 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
params = deepcopy(params)
# Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
data_parts = _split_to_parts(data, is_matrix=True)
label_parts = _split_to_parts(label, is_matrix=False)
weight_parts = _split_to_parts(sample_weight, is_matrix=False) if sample_weight is not None else None
group_parts = _split_to_parts(group, is_matrix=False) if group is not None else None
data_parts = _split_to_parts(data=data, is_matrix=True)
label_parts = _split_to_parts(data=label, is_matrix=False)
if sample_weight is not None:
weight_parts = _split_to_parts(data=sample_weight, is_matrix=False)
else:
weight_parts = None
if group is not None:
group_parts = _split_to_parts(data=group, is_matrix=False)
else:
group_parts = None
# choose between four options of (sample_weight, group) being (un)specified
if weight_parts is None and group_parts is None:
......@@ -265,15 +284,19 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
params.pop(num_thread_alias, None)
# Tell each worker to train on the parts that it has locally
futures_classifiers = [client.submit(_train_part,
futures_classifiers = [
client.submit(
_train_part,
model_factory=model_factory,
params={**params, 'num_threads': worker_ncores[worker]},
list_of_parts=list_of_parts,
worker_address_to_port=worker_address_to_port,
time_out=params.get('time_out', 120),
return_model=(worker == master_worker),
**kwargs)
for worker, list_of_parts in worker_map.items()]
**kwargs
)
for worker, list_of_parts in worker_map.items()
]
results = client.gather(futures_classifiers)
results = [v for v in results if v]
......@@ -368,8 +391,17 @@ class _LGBMModel:
client = default_client()
params = self.get_params(True)
model = _train(client, data=X, label=y, params=params, model_factory=model_factory,
sample_weight=sample_weight, group=group, **kwargs)
model = _train(
client=client,
data=X,
label=y,
params=params,
model_factory=model_factory,
sample_weight=sample_weight,
group=group,
**kwargs
)
self.set_params(**model.get_params())
self._copy_extra_params(model, self)
......@@ -395,17 +427,37 @@ class DaskLGBMClassifier(_LGBMModel, LGBMClassifier):
def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
    """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
    # Delegate to the shared ``_fit`` routine. Every argument is passed by
    # keyword so the call does not depend on the parameter order of ``_fit``.
    return self._fit(
        model_factory=LGBMClassifier,
        X=X,
        y=y,
        sample_weight=sample_weight,
        client=client,
        **kwargs
    )
# Replace the placeholder docstring above with the one from the local estimator.
fit.__doc__ = LGBMClassifier.fit.__doc__
def predict(self, X, **kwargs):
    """Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
    # Predict with the local (non-Dask) copy of the model; ``dtype`` is taken
    # from the fitted ``classes_`` attribute and forwarded to ``_predict``.
    return _predict(
        model=self.to_local(),
        data=X,
        dtype=self.classes_.dtype,
        **kwargs
    )
# Replace the placeholder docstring above with the one from the local estimator.
predict.__doc__ = LGBMClassifier.predict.__doc__
def predict_proba(self, X, **kwargs):
    """Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
    # Same path as ``predict`` but with ``pred_proba=True`` passed to ``_predict``.
    return _predict(
        model=self.to_local(),
        data=X,
        pred_proba=True,
        **kwargs
    )
# Replace the placeholder docstring above with the one from the local estimator.
predict_proba.__doc__ = LGBMClassifier.predict_proba.__doc__
def to_local(self):
......@@ -423,12 +475,25 @@ class DaskLGBMRegressor(_LGBMModel, LGBMRegressor):
def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
    """Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
    # Delegate to the shared ``_fit`` routine. Every argument is passed by
    # keyword so the call does not depend on the parameter order of ``_fit``.
    return self._fit(
        model_factory=LGBMRegressor,
        X=X,
        y=y,
        sample_weight=sample_weight,
        client=client,
        **kwargs
    )
# Replace the placeholder docstring above with the one from the local estimator.
fit.__doc__ = LGBMRegressor.fit.__doc__
def predict(self, X, **kwargs):
    """Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
    # Predict with the local (non-Dask) copy of the model, passing arguments
    # to ``_predict`` by keyword.
    return _predict(
        model=self.to_local(),
        data=X,
        **kwargs
    )
# Replace the placeholder docstring above with the one from the local estimator.
predict.__doc__ = LGBMRegressor.predict.__doc__
def to_local(self):
......@@ -449,12 +514,22 @@ class DaskLGBMRanker(_LGBMModel, LGBMRanker):
if init_score is not None:
raise RuntimeError('init_score is not currently supported in lightgbm.dask')
return self._fit(LGBMRanker, X=X, y=y, sample_weight=sample_weight, group=group, client=client, **kwargs)
return self._fit(
model_factory=LGBMRanker,
X=X,
y=y,
sample_weight=sample_weight,
group=group,
client=client,
**kwargs
)
fit.__doc__ = LGBMRanker.fit.__doc__
def predict(self, X, **kwargs):
    """Docstring is inherited from the lightgbm.LGBMRanker.predict."""
    # Predict with the local (non-Dask) copy of the model. Arguments are
    # passed by keyword, matching the classifier and regressor ``predict``
    # methods, so the call does not depend on ``_predict``'s parameter order.
    return _predict(
        model=self.to_local(),
        data=X,
        **kwargs
    )
# Replace the placeholder docstring above with the one from the local estimator.
predict.__doc__ = LGBMRanker.predict.__doc__
def to_local(self):
......
......@@ -206,7 +206,11 @@ def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size
@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers)
def test_classifier(output, centers, client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers)
X, y, w, dX, dy, dw = _create_data(
objective='classification',
output=output,
centers=centers
)
dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
......@@ -238,7 +242,11 @@ def test_classifier(output, centers, client, listen_port):
@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers)
def test_classifier_pred_contrib(output, centers, client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers)
X, y, w, dX, dy, dw = _create_data(
objective='classification',
output=output,
centers=centers
)
dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
......@@ -309,7 +317,10 @@ def test_training_does_not_fail_on_port_conflicts(client):
def test_classifier_local_predict(client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output='array')
X, y, w, dX, dy, dw = _create_data(
objective='classification',
output='array'
)
dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
......@@ -333,7 +344,10 @@ def test_classifier_local_predict(client, listen_port):
@pytest.mark.parametrize('output', data_output)
def test_regressor(output, client, listen_port):
X, y, w, dX, dy, dw = _create_data('regression', output=output)
X, y, w, dX, dy, dw = _create_data(
objective='regression',
output=output
)
dask_regressor = dlgbm.DaskLGBMRegressor(
time_out=5,
......@@ -366,7 +380,10 @@ def test_regressor(output, client, listen_port):
@pytest.mark.parametrize('output', data_output)
def test_regressor_pred_contrib(output, client, listen_port):
X, y, w, dX, dy, dw = _create_data('regression', output=output)
X, y, w, dX, dy, dw = _create_data(
objective='regression',
output=output
)
dask_regressor = dlgbm.DaskLGBMRegressor(
time_out=5,
......@@ -398,7 +415,10 @@ def test_regressor_pred_contrib(output, client, listen_port):
@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('alpha', [.1, .5, .9])
def test_regressor_quantile(output, client, listen_port, alpha):
X, y, w, dX, dy, dw = _create_data('regression', output=output)
X, y, w, dX, dy, dw = _create_data(
objective='regression',
output=output
)
dask_regressor = dlgbm.DaskLGBMRegressor(
local_listen_port=listen_port,
......@@ -459,17 +479,32 @@ def test_regressor_local_predict(client, listen_port):
@pytest.mark.parametrize('group', [None, group_sizes])
def test_ranker(output, client, listen_port, group):
X, y, w, g, dX, dy, dw, dg = _create_ranking_data(output=output, group=group)
X, y, w, g, dX, dy, dw, dg = _create_ranking_data(
output=output,
group=group
)
# use many trees + leaves to overfit, help ensure that dask data-parallel strategy matches that of
# serial learner. See https://github.com/microsoft/LightGBM/issues/3292#issuecomment-671288210.
dask_ranker = dlgbm.DaskLGBMRanker(time_out=5, local_listen_port=listen_port, tree_learner_type='data_parallel',
n_estimators=50, num_leaves=20, seed=42, min_child_samples=1)
dask_ranker = dlgbm.DaskLGBMRanker(
time_out=5,
local_listen_port=listen_port,
tree_learner_type='data_parallel',
n_estimators=50,
num_leaves=20,
seed=42,
min_child_samples=1
)
dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg, client=client)
rnkvec_dask = dask_ranker.predict(dX)
rnkvec_dask = rnkvec_dask.compute()
local_ranker = lightgbm.LGBMRanker(n_estimators=50, num_leaves=20, seed=42, min_child_samples=1)
local_ranker = lightgbm.LGBMRanker(
n_estimators=50,
num_leaves=20,
seed=42,
min_child_samples=1
)
local_ranker.fit(X, y, sample_weight=w, group=g)
rnkvec_local = local_ranker.predict(X)
......@@ -486,10 +521,20 @@ def test_ranker(output, client, listen_port, group):
@pytest.mark.parametrize('group', [None, group_sizes])
def test_ranker_local_predict(output, client, listen_port, group):
X, y, w, g, dX, dy, dw, dg = _create_ranking_data(output=output, group=group)
X, y, w, g, dX, dy, dw, dg = _create_ranking_data(
output=output,
group=group
)
dask_ranker = dlgbm.DaskLGBMRanker(time_out=5, local_listen_port=listen_port, tree_learner='data',
n_estimators=10, num_leaves=10, seed=42, min_child_samples=1)
dask_ranker = dlgbm.DaskLGBMRanker(
time_out=5,
local_listen_port=listen_port,
tree_learner='data',
n_estimators=10,
num_leaves=10,
seed=42,
min_child_samples=1
)
dask_ranker = dask_ranker.fit(dX, dy, group=dg, client=client)
rnkvec_dask = dask_ranker.predict(dX)
rnkvec_dask = rnkvec_dask.compute()
......@@ -532,5 +577,11 @@ def test_errors(c, s, a, b):
df = dd.demo.make_timeseries()
df = df.map_partitions(f, meta=df._meta)
with pytest.raises(Exception) as info:
yield dlgbm._train(c, df, df.x, params={}, model_factory=lightgbm.LGBMClassifier)
yield dlgbm._train(
client=c,
data=df,
label=df.x,
params={},
model_factory=lightgbm.LGBMClassifier
)
assert 'foo' in str(info.value)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment