Unverified Commit 3ab6bbf9 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[tests][dask] simplify fit calls in Dask tests (#4018)

* simplify fit calls in Dask tests

* Update .vsts-ci.yml

* Update .vsts-ci.yml
parent 5dacd603
...@@ -1046,10 +1046,7 @@ def test_network_params_not_required_but_respected_if_given(client, task, output ...@@ -1046,10 +1046,7 @@ def test_network_params_not_required_but_respected_if_given(client, task, output
n_estimators=5, n_estimators=5,
num_leaves=5, num_leaves=5,
) )
if task == 'ranking': dask_model1.fit(dX, dy, group=dg)
dask_model1.fit(dX, dy, group=dg)
else:
dask_model1.fit(dX, dy)
assert dask_model1.fitted_ assert dask_model1.fitted_
params = dask_model1.get_params() params = dask_model1.get_params()
assert 'local_listen_port' not in params assert 'local_listen_port' not in params
...@@ -1067,10 +1064,7 @@ def test_network_params_not_required_but_respected_if_given(client, task, output ...@@ -1067,10 +1064,7 @@ def test_network_params_not_required_but_respected_if_given(client, task, output
]), ]),
) )
if task == 'ranking': dask_model2.fit(dX, dy, group=dg)
dask_model2.fit(dX, dy, group=dg)
else:
dask_model2.fit(dX, dy)
assert dask_model2.fitted_ assert dask_model2.fitted_
params = dask_model2.get_params() params = dask_model2.get_params()
assert 'local_listen_port' not in params assert 'local_listen_port' not in params
...@@ -1086,10 +1080,7 @@ def test_network_params_not_required_but_respected_if_given(client, task, output ...@@ -1086,10 +1080,7 @@ def test_network_params_not_required_but_respected_if_given(client, task, output
) )
error_msg = "has multiple Dask worker processes running on it" error_msg = "has multiple Dask worker processes running on it"
with pytest.raises(lgb.basic.LightGBMError, match=error_msg): with pytest.raises(lgb.basic.LightGBMError, match=error_msg):
if task == 'ranking': dask_model3.fit(dX, dy, group=dg)
dask_model3.fit(dX, dy, group=dg)
else:
dask_model3.fit(dX, dy)
client.close(timeout=CLIENT_CLOSE_TIMEOUT) client.close(timeout=CLIENT_CLOSE_TIMEOUT)
...@@ -1141,10 +1132,7 @@ def test_machines_should_be_used_if_provided(task, output): ...@@ -1141,10 +1132,7 @@ def test_machines_should_be_used_if_provided(task, output):
with pytest.raises(lgb.basic.LightGBMError, match=error_msg): with pytest.raises(lgb.basic.LightGBMError, match=error_msg):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', open_ports[0])) s.bind(('127.0.0.1', open_ports[0]))
if task == 'ranking': dask_model.fit(dX, dy, group=dg)
dask_model.fit(dX, dy, group=dg)
else:
dask_model.fit(dX, dy)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -1232,6 +1220,7 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array( ...@@ -1232,6 +1220,7 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(
model = model_factory(**params) model = model_factory(**params)
model.fit(dX, dy_col_array, sample_weight=dw, group=dg) model.fit(dX, dy_col_array, sample_weight=dw, group=dg)
assert model.fitted_ assert model.fitted_
client.close(timeout=CLIENT_CLOSE_TIMEOUT) client.close(timeout=CLIENT_CLOSE_TIMEOUT)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment