Unverified Commit d32ee23a authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[ci] remove output parametrization from two Dask tests (#4123)

* Update test_dask.py

* Update test_dask.py
parent e98da99d
...@@ -994,16 +994,12 @@ def test_training_succeeds_even_if_some_workers_do_not_have_any_data(client, tas ...@@ -994,16 +994,12 @@ def test_training_succeeds_even_if_some_workers_do_not_have_any_data(client, tas
@pytest.mark.parametrize('task', tasks) @pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output) def test_network_params_not_required_but_respected_if_given(client, task, listen_port):
def test_network_params_not_required_but_respected_if_given(client, task, output, listen_port):
if task == 'ranking' and output == 'scipy_csr_matrix':
pytest.skip('LGBMRanker is not currently tested on sparse matrices')
client.wait_for_workers(2) client.wait_for_workers(2)
_, _, _, _, dX, dy, _, dg = _create_data( _, _, _, _, dX, dy, _, dg = _create_data(
objective=task, objective=task,
output=output, output='array',
chunk_size=10, chunk_size=10,
group=None group=None
) )
...@@ -1011,7 +1007,6 @@ def test_network_params_not_required_but_respected_if_given(client, task, output ...@@ -1011,7 +1007,6 @@ def test_network_params_not_required_but_respected_if_given(client, task, output
dask_model_factory = task_to_dask_factory[task] dask_model_factory = task_to_dask_factory[task]
# rebalance data to be sure that each worker has a piece of the data # rebalance data to be sure that each worker has a piece of the data
if output == 'array':
client.rebalance() client.rebalance()
# model 1 - no network parameters given # model 1 - no network parameters given
...@@ -1059,15 +1054,11 @@ def test_network_params_not_required_but_respected_if_given(client, task, output ...@@ -1059,15 +1054,11 @@ def test_network_params_not_required_but_respected_if_given(client, task, output
@pytest.mark.parametrize('task', tasks) @pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output) def test_machines_should_be_used_if_provided(task):
def test_machines_should_be_used_if_provided(task, output):
if task == 'ranking' and output == 'scipy_csr_matrix':
pytest.skip('LGBMRanker is not currently tested on sparse matrices')
with LocalCluster(n_workers=2) as cluster, Client(cluster) as client: with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
_, _, _, _, dX, dy, _, dg = _create_data( _, _, _, _, dX, dy, _, dg = _create_data(
objective=task, objective=task,
output=output, output='array',
chunk_size=10, chunk_size=10,
group=None group=None
) )
...@@ -1075,7 +1066,6 @@ def test_machines_should_be_used_if_provided(task, output): ...@@ -1075,7 +1066,6 @@ def test_machines_should_be_used_if_provided(task, output):
dask_model_factory = task_to_dask_factory[task] dask_model_factory = task_to_dask_factory[task]
# rebalance data to be sure that each worker has a piece of the data # rebalance data to be sure that each worker has a piece of the data
if output == 'array':
client.rebalance() client.rebalance()
n_workers = len(client.scheduler_info()['workers']) n_workers = len(client.scheduler_info()['workers'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment