Unverified Commit d32ee23a authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[ci] remove output parametrization from two Dask tests (#4123)

* Update test_dask.py

* Update test_dask.py
parent e98da99d
......@@ -994,16 +994,12 @@ def test_training_succeeds_even_if_some_workers_do_not_have_any_data(client, tas
@pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output)
def test_network_params_not_required_but_respected_if_given(client, task, output, listen_port):
if task == 'ranking' and output == 'scipy_csr_matrix':
pytest.skip('LGBMRanker is not currently tested on sparse matrices')
def test_network_params_not_required_but_respected_if_given(client, task, listen_port):
client.wait_for_workers(2)
_, _, _, _, dX, dy, _, dg = _create_data(
objective=task,
output=output,
output='array',
chunk_size=10,
group=None
)
......@@ -1011,8 +1007,7 @@ def test_network_params_not_required_but_respected_if_given(client, task, output
dask_model_factory = task_to_dask_factory[task]
# rebalance data to be sure that each worker has a piece of the data
if output == 'array':
client.rebalance()
client.rebalance()
# model 1 - no network parameters given
dask_model1 = dask_model_factory(
......@@ -1059,15 +1054,11 @@ def test_network_params_not_required_but_respected_if_given(client, task, output
@pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output)
def test_machines_should_be_used_if_provided(task, output):
if task == 'ranking' and output == 'scipy_csr_matrix':
pytest.skip('LGBMRanker is not currently tested on sparse matrices')
def test_machines_should_be_used_if_provided(task):
with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
_, _, _, _, dX, dy, _, dg = _create_data(
objective=task,
output=output,
output='array',
chunk_size=10,
group=None
)
......@@ -1075,8 +1066,7 @@ def test_machines_should_be_used_if_provided(task, output):
dask_model_factory = task_to_dask_factory[task]
# rebalance data to be sure that each worker has a piece of the data
if output == 'array':
client.rebalance()
client.rebalance()
n_workers = len(client.scheduler_info()['workers'])
assert n_workers > 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment