Unverified Commit 7b47ab8f authored by jmoralez's avatar jmoralez Committed by GitHub
Browse files

[dask] test training when a worker has no data (#3897)

* include test for training when a worker has no data

* test single partition against local model for all tasks and outputs

* remove futures_of

* include james' comments

* remove product import
parent d4b6257c
......@@ -35,6 +35,7 @@ from .utils import make_ranking
# see https://distributed.dask.org/en/latest/api.html#distributed.Client.close
# Timeout handed to ``client.close()`` at the end of the tests in this module
# (presumably seconds -- confirm against the linked documentation).
CLIENT_CLOSE_TIMEOUT = 120
# Task names that tests are parametrized over via ``pytest.mark.parametrize``.
tasks = ['classification', 'regression', 'ranking']
# Names of the input data containers that tests are parametrized over.
data_output = ['array', 'scipy_csr_matrix', 'dataframe', 'dataframe-with-categorical']
# NOTE(review): presumably cluster centers used when generating synthetic data -- confirm.
data_centers = [[[-4, -4], [4, 4]], [[-4, -4], [4, 4], [-4, 4]]]
# NOTE(review): presumably query-group sizes used when generating ranking data -- confirm.
group_sizes = [5, 5, 5, 10, 10, 10, 20, 20, 20, 50, 50]
......@@ -647,7 +648,7 @@ def test_ranker(output, client, listen_port, group):
client.close(timeout=CLIENT_CLOSE_TIMEOUT)
@pytest.mark.parametrize('task', ['classification', 'regression', 'ranking'])
@pytest.mark.parametrize('task', tasks)
def test_training_works_if_client_not_provided_or_set_after_construction(task, listen_port, client):
if task == 'ranking':
_, _, _, _, dX, dy, _, dg = _create_ranking_data(
......@@ -723,7 +724,7 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l
@pytest.mark.parametrize('serializer', ['pickle', 'joblib', 'cloudpickle'])
@pytest.mark.parametrize('task', ['classification', 'regression', 'ranking'])
@pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('set_client', [True, False])
def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, tmp_path):
......@@ -992,6 +993,72 @@ def test_errors(c, s, a, b):
assert 'foo' in str(info.value)
@pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output)
def test_training_succeeds_even_if_some_workers_do_not_have_any_data(client, task, output):
    """Check that Dask training works when the data lives in a single partition.

    Every Dask collection is collapsed into one partition while the cluster
    is asserted to have more than one worker, so the training data cannot be
    spread over all workers. The Dask model's predictions are then compared
    with those of the equivalent local model trained on the same data.
    """
    if task == 'ranking' and output == 'scipy_csr_matrix':
        pytest.skip('LGBMRanker is not currently tested on sparse matrices')

    def collection_to_single_partition(collection):
        """Merge the parts of a Dask collection into a single partition."""
        if collection is None:
            return
        if isinstance(collection, da.Array):
            # Pass the shape as ONE tuple so every axis becomes a single
            # chunk. Unpacking it (``rechunk(*shape)``) would bind the
            # second dimension to ``threshold`` and use only the first
            # dimension as the chunk size for all axes.
            return collection.rechunk(collection.shape)
        return collection.repartition(npartitions=1)

    if task == 'ranking':
        X, y, w, g, dX, dy, dw, dg = _create_ranking_data(
            output=output,
            group=None
        )
        dask_model_factory = lgb.DaskLGBMRanker
        local_model_factory = lgb.LGBMRanker
    else:
        X, y, w, dX, dy, dw = _create_data(
            objective=task,
            output=output
        )
        # Only the ranking task uses group information.
        g = None
        dg = None
        if task == 'classification':
            dask_model_factory = lgb.DaskLGBMClassifier
            local_model_factory = lgb.LGBMClassifier
        elif task == 'regression':
            dask_model_factory = lgb.DaskLGBMRegressor
            local_model_factory = lgb.LGBMRegressor

    dX = collection_to_single_partition(dX)
    dy = collection_to_single_partition(dy)
    dw = collection_to_single_partition(dw)
    dg = collection_to_single_partition(dg)

    # Preconditions of the scenario: several workers, but only one partition.
    n_workers = len(client.scheduler_info()['workers'])
    assert n_workers > 1
    assert dX.npartitions == 1

    params = {
        'time_out': 5,
        'random_state': 42,
        'num_leaves': 10
    }

    # NOTE(review): 'tree' is presumably LightGBM's alias for 'tree_learner' -- confirm.
    dask_model = dask_model_factory(tree='data', client=client, **params)
    dask_model.fit(dX, dy, group=dg, sample_weight=dw)
    dask_preds = dask_model.predict(dX).compute()

    local_model = local_model_factory(**params)
    if task == 'ranking':
        local_model.fit(X, y, group=g, sample_weight=w)
    else:
        local_model.fit(X, y, sample_weight=w)
    local_preds = local_model.predict(X)

    assert assert_eq(dask_preds, local_preds)

    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
@pytest.mark.parametrize(
"classes",
[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment