Unverified Commit 93f2da43 authored by José Morales's avatar José Morales Committed by GitHub
Browse files

[tests][dask] fix workers without data test (fixes #5537) (#5544)

parent 2d4654a1
......@@ -11,6 +11,7 @@ from sys import platform
from urllib.parse import urlparse
import pytest
from sklearn.metrics import accuracy_score, r2_score
import lightgbm as lgb
......@@ -75,6 +76,13 @@ def cluster2():
dask_cluster.close()
@pytest.fixture(scope='module')
def cluster_three_workers():
    """Yield a module-scoped local Dask cluster with three single-threaded workers.

    Used by tests that need to place data on a strict subset of workers
    (e.g. training when some workers hold no partitions). The cluster is
    torn down once every test in the module has finished.
    """
    local_cluster = LocalCluster(n_workers=3, threads_per_worker=1, dashboard_address=None)
    yield local_cluster
    local_cluster.close()
@pytest.fixture()
def listen_port():
listen_port.port += 10
......@@ -1503,56 +1511,54 @@ def test_errors(cluster):
@pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output)
def test_training_succeeds_even_if_some_workers_do_not_have_any_data(task, output, cluster_three_workers):
    """Train on a 3-worker cluster where the third worker holds no data.

    Regression test: distributed training must not hang or fail when one
    worker participates in the cluster but owns no partitions of the
    training collections (fixes #5537). Data is pinned to the first two
    workers via ``client.persist(..., workers=...)`` and the placement is
    verified with ``client.who_has`` before fitting.
    """
    if task == 'ranking' and output == 'scipy_csr_matrix':
        pytest.skip('LGBMRanker is not currently tested on sparse matrices')

    with Client(cluster_three_workers) as client:
        # Only the local `y` labels are needed for scoring; the distributed
        # collections (dX, dy, dw, dg) are what training consumes.
        _, y, _, _, dX, dy, dw, dg = _create_data(
            objective=task,
            output=output,
            group=None,
            n_samples=1_000,
            chunk_size=200,
        )

        dask_model_factory = task_to_dask_factory[task]

        workers = list(client.scheduler_info()['workers'].keys())
        assert len(workers) == 3
        first_two_workers = workers[:2]

        # Pin every partition to the first two workers so the third worker
        # joins training with no local data.
        dX = client.persist(dX, workers=first_two_workers)
        dy = client.persist(dy, workers=first_two_workers)
        dw = client.persist(dw, workers=first_two_workers)
        wait([dX, dy, dw])

        # Confirm the scheduler actually placed the data as requested:
        # exactly two workers hold partitions, and the third holds none.
        workers_with_data = set()
        for coll in (dX, dy, dw):
            for with_data in client.who_has(coll).values():
                workers_with_data.update(with_data)
                assert workers[2] not in with_data
        assert len(workers_with_data) == 2

        params = {
            'time_out': 5,
            'random_state': 42,
            'num_leaves': 10,
            'n_estimators': 20,
        }
        dask_model = dask_model_factory(tree='data', client=client, **params)
        dask_model.fit(dX, dy, group=dg, sample_weight=dw)
        dask_preds = dask_model.predict(dX).compute()

        # Score against the local labels with a task-appropriate metric
        # rather than comparing to a local model fit.
        if task == 'regression':
            score = r2_score(y, dask_preds)
        elif task.endswith('classification'):
            score = accuracy_score(y, dask_preds)
        else:
            score = spearmanr(dask_preds, y).correlation
        assert score > 0.9
@pytest.mark.parametrize('task', tasks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment