Unverified Commit 93f2da43 authored by José Morales's avatar José Morales Committed by GitHub
Browse files

[tests][dask] fix workers without data test (fixes #5537) (#5544)

parent 2d4654a1
...@@ -11,6 +11,7 @@ from sys import platform ...@@ -11,6 +11,7 @@ from sys import platform
from urllib.parse import urlparse from urllib.parse import urlparse
import pytest import pytest
from sklearn.metrics import accuracy_score, r2_score
import lightgbm as lgb import lightgbm as lgb
...@@ -75,6 +76,13 @@ def cluster2(): ...@@ -75,6 +76,13 @@ def cluster2():
dask_cluster.close() dask_cluster.close()
@pytest.fixture(scope='module')
def cluster_three_workers():
dask_cluster = LocalCluster(n_workers=3, threads_per_worker=1, dashboard_address=None)
yield dask_cluster
dask_cluster.close()
@pytest.fixture() @pytest.fixture()
def listen_port(): def listen_port():
listen_port.port += 10 listen_port.port += 10
...@@ -1503,56 +1511,54 @@ def test_errors(cluster): ...@@ -1503,56 +1511,54 @@ def test_errors(cluster):
@pytest.mark.parametrize('task', tasks) @pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('output', data_output)
def test_training_succeeds_even_if_some_workers_do_not_have_any_data(task, output, cluster): def test_training_succeeds_even_if_some_workers_do_not_have_any_data(task, output, cluster_three_workers):
pytest.skip("skipping due to timeout issues discussed in https://github.com/microsoft/LightGBM/pull/5510")
if task == 'ranking' and output == 'scipy_csr_matrix': if task == 'ranking' and output == 'scipy_csr_matrix':
pytest.skip('LGBMRanker is not currently tested on sparse matrices') pytest.skip('LGBMRanker is not currently tested on sparse matrices')
with Client(cluster) as client: with Client(cluster_three_workers) as client:
def collection_to_single_partition(collection): _, y, _, _, dX, dy, dw, dg = _create_data(
"""Merge the parts of a Dask collection into a single partition."""
if collection is None:
return
if isinstance(collection, da.Array):
return collection.rechunk(*collection.shape)
return collection.repartition(npartitions=1)
X, y, w, g, dX, dy, dw, dg = _create_data(
objective=task, objective=task,
output=output, output=output,
group=None group=None,
n_samples=1_000,
chunk_size=200,
) )
dask_model_factory = task_to_dask_factory[task] dask_model_factory = task_to_dask_factory[task]
local_model_factory = task_to_local_factory[task]
dX = collection_to_single_partition(dX) workers = list(client.scheduler_info()['workers'].keys())
dy = collection_to_single_partition(dy) assert len(workers) == 3
dw = collection_to_single_partition(dw) first_two_workers = workers[:2]
dg = collection_to_single_partition(dg)
n_workers = len(client.scheduler_info()['workers']) dX = client.persist(dX, workers=first_two_workers)
assert n_workers > 1 dy = client.persist(dy, workers=first_two_workers)
assert dX.npartitions == 1 dw = client.persist(dw, workers=first_two_workers)
wait([dX, dy, dw])
workers_with_data = set()
for coll in (dX, dy, dw):
for with_data in client.who_has(coll).values():
workers_with_data.update(with_data)
assert workers[2] not in with_data
assert len(workers_with_data) == 2
params = { params = {
'time_out': 5, 'time_out': 5,
'random_state': 42, 'random_state': 42,
'num_leaves': 10 'num_leaves': 10,
'n_estimators': 20,
} }
dask_model = dask_model_factory(tree='data', client=client, **params) dask_model = dask_model_factory(tree='data', client=client, **params)
dask_model.fit(dX, dy, group=dg, sample_weight=dw) dask_model.fit(dX, dy, group=dg, sample_weight=dw)
dask_preds = dask_model.predict(dX).compute() dask_preds = dask_model.predict(dX).compute()
if task == 'regression':
local_model = local_model_factory(**params) score = r2_score(y, dask_preds)
if task == 'ranking': elif task.endswith('classification'):
local_model.fit(X, y, group=g, sample_weight=w) score = accuracy_score(y, dask_preds)
else: else:
local_model.fit(X, y, sample_weight=w) score = spearmanr(dask_preds, y).correlation
local_preds = local_model.predict(X) assert score > 0.9
assert assert_eq(dask_preds, local_preds)
@pytest.mark.parametrize('task', tasks) @pytest.mark.parametrize('task', tasks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment