Unverified Commit 1f4a0842 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[tests][dask] simplify code in Dask tests (#4075)

* simplify Dask tests code

* enable CI

* disable CI
parent 39c85dd9
......@@ -131,7 +131,7 @@ def _create_ranking_data(n_samples=100, output='array', chunk_size=50, **kwargs)
return X, y, w, g_rle, dX, dy, dw, dg
def _create_data(objective, n_samples=100, output='array', chunk_size=50):
def _create_data(objective, n_samples=100, output='array', chunk_size=50, **kwargs):
if objective.endswith('classification'):
if objective == 'binary-classification':
centers = [[-4, -4], [4, 4]]
......@@ -142,6 +142,13 @@ def _create_data(objective, n_samples=100, output='array', chunk_size=50):
X, y = make_blobs(n_samples=n_samples, centers=centers, random_state=42)
elif objective == 'regression':
X, y = make_regression(n_samples=n_samples, random_state=42)
elif objective == 'ranking':
return _create_ranking_data(
n_samples=n_samples,
output=output,
chunk_size=chunk_size,
**kwargs
)
else:
raise ValueError("Unknown objective '%s'" % objective)
rnd = np.random.RandomState(42)
......@@ -183,7 +190,7 @@ def _create_data(objective, n_samples=100, output='array', chunk_size=50):
else:
raise ValueError("Unknown output type '%s'" % output)
return X, y, weights, dX, dy, dw
return X, y, weights, None, dX, dy, dw, None
def _r2_score(dy_true, dy_pred):
......@@ -225,7 +232,7 @@ def _unpickle(filepath, serializer):
@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
def test_classifier(output, task, client):
X, y, w, dX, dy, dw = _create_data(
X, y, w, _, dX, dy, dw, _ = _create_data(
objective=task,
output=output
)
......@@ -291,7 +298,7 @@ def test_classifier(output, task, client):
@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
def test_classifier_pred_contrib(output, task, client):
X, y, w, dX, dy, dw = _create_data(
X, y, w, _, dX, dy, dw, _ = _create_data(
objective=task,
output=output
)
......@@ -369,7 +376,7 @@ def test_find_random_open_port(client):
def test_training_does_not_fail_on_port_conflicts(client):
_, _, _, dX, dy, dw = _create_data('binary-classification', output='array')
_, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')
lightgbm_default_port = 12400
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
......@@ -393,7 +400,7 @@ def test_training_does_not_fail_on_port_conflicts(client):
@pytest.mark.parametrize('output', data_output)
def test_regressor(output, client):
X, y, w, dX, dy, dw = _create_data(
X, y, w, _, dX, dy, dw, _ = _create_data(
objective='regression',
output=output
)
......@@ -468,7 +475,7 @@ def test_regressor(output, client):
@pytest.mark.parametrize('output', data_output)
def test_regressor_pred_contrib(output, client):
X, y, w, dX, dy, dw = _create_data(
X, y, w, _, dX, dy, dw, _ = _create_data(
objective='regression',
output=output
)
......@@ -518,7 +525,7 @@ def test_regressor_pred_contrib(output, client):
@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('alpha', [.1, .5, .9])
def test_regressor_quantile(output, client, alpha):
X, y, w, dX, dy, dw = _create_data(
X, y, w, _, dX, dy, dw, _ = _create_data(
objective='regression',
output=output
)
......@@ -567,18 +574,19 @@ def test_regressor_quantile(output, client, alpha):
@pytest.mark.parametrize('output', ['array', 'dataframe', 'dataframe-with-categorical'])
@pytest.mark.parametrize('group', [None, group_sizes])
def test_ranker(output, client, group):
if output == 'dataframe-with-categorical':
X, y, w, g, dX, dy, dw, dg = _create_ranking_data(
X, y, w, g, dX, dy, dw, dg = _create_data(
objective='ranking',
output=output,
group=group,
n_features=1,
n_informative=1
)
else:
X, y, w, g, dX, dy, dw, dg = _create_ranking_data(
X, y, w, g, dX, dy, dw, dg = _create_data(
objective='ranking',
output=output,
group=group,
group=group
)
# rebalance small dask.Array dataset for better performance.
......@@ -650,17 +658,11 @@ def test_ranker(output, client, group):
@pytest.mark.parametrize('task', tasks)
def test_training_works_if_client_not_provided_or_set_after_construction(task, client):
if task == 'ranking':
_, _, _, _, dX, dy, _, dg = _create_ranking_data(
output='array',
group=None
)
else:
_, _, _, dX, dy, _ = _create_data(
objective=task,
output='array',
)
dg = None
_, _, _, _, dX, dy, _, dg = _create_data(
objective=task,
output='array',
group=None
)
model_factory = task_to_dask_factory[task]
params = {
......@@ -723,182 +725,166 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, c
@pytest.mark.parametrize('set_client', [True, False])
def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, tmp_path):
with LocalCluster(n_workers=2, threads_per_worker=1) as cluster1:
with Client(cluster1) as client1:
with LocalCluster(n_workers=2, threads_per_worker=1) as cluster1, Client(cluster1) as client1:
# data on cluster1
X_1, _, _, _, dX_1, dy_1, _, dg_1 = _create_data(
objective=task,
output='array',
group=None
)
with LocalCluster(n_workers=2, threads_per_worker=1) as cluster2, Client(cluster2) as client2:
# create identical data on cluster2
X_2, _, _, _, dX_2, dy_2, _, dg_2 = _create_data(
objective=task,
output='array',
group=None
)
# data on cluster1
if task == 'ranking':
X_1, _, _, _, dX_1, dy_1, _, dg_1 = _create_ranking_data(
output='array',
group=None
)
model_factory = task_to_dask_factory[task]
params = {
"time_out": 5,
"n_estimators": 1,
"num_leaves": 2
}
# at this point, the result of default_client() is client2 since it was the most recently
# created. So setting client to client1 here to test that you can select a non-default client
assert default_client() == client2
if set_client:
params.update({"client": client1})
# unfitted model should survive pickling round trip, and pickling
# shouldn't have side effects on the model object
dask_model = model_factory(**params)
local_model = dask_model.to_local()
if set_client:
assert dask_model.client == client1
else:
X_1, _, _, dX_1, dy_1, _ = _create_data(
objective=task,
output='array',
)
dg_1 = None
with LocalCluster(n_workers=2, threads_per_worker=1) as cluster2:
with Client(cluster2) as client2:
# create identical data on cluster2
if task == 'ranking':
X_2, _, _, _, dX_2, dy_2, _, dg_2 = _create_ranking_data(
output='array',
group=None
)
else:
X_2, _, _, dX_2, dy_2, _ = _create_data(
objective=task,
output='array',
)
dg_2 = None
model_factory = task_to_dask_factory[task]
params = {
"time_out": 5,
"n_estimators": 1,
"num_leaves": 2
}
# at this point, the result of default_client() is client2 since it was the most recently
# created. So setting client to client1 here to test that you can select a non-default client
assert default_client() == client2
if set_client:
params.update({"client": client1})
# unfitted model should survive pickling round trip, and pickling
# shouldn't have side effects on the model object
dask_model = model_factory(**params)
local_model = dask_model.to_local()
if set_client:
assert dask_model.client == client1
else:
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_
assert "client" not in local_model.get_params()
assert getattr(local_model, "client", None) is None
tmp_file = str(tmp_path / "model-1.pkl")
_pickle(
obj=dask_model,
filepath=tmp_file,
serializer=serializer
)
model_from_disk = _unpickle(
filepath=tmp_file,
serializer=serializer
)
local_tmp_file = str(tmp_path / "local-model-1.pkl")
_pickle(
obj=local_model,
filepath=local_tmp_file,
serializer=serializer
)
local_model_from_disk = _unpickle(
filepath=local_tmp_file,
serializer=serializer
)
assert model_from_disk.client is None
if set_client:
assert dask_model.client == client1
else:
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_
# client will always be None after unpickling
if set_client:
from_disk_params = model_from_disk.get_params()
from_disk_params.pop("client", None)
dask_params = dask_model.get_params()
dask_params.pop("client", None)
assert from_disk_params == dask_params
else:
assert model_from_disk.get_params() == dask_model.get_params()
assert local_model_from_disk.get_params() == local_model.get_params()
# fitted model should survive pickling round trip, and pickling
# shouldn't have side effects on the model object
if set_client:
dask_model.fit(dX_1, dy_1, group=dg_1)
else:
dask_model.fit(dX_2, dy_2, group=dg_2)
local_model = dask_model.to_local()
assert "client" not in local_model.get_params()
with pytest.raises(AttributeError):
local_model.client
local_model.client_
tmp_file2 = str(tmp_path / "model-2.pkl")
_pickle(
obj=dask_model,
filepath=tmp_file2,
serializer=serializer
)
fitted_model_from_disk = _unpickle(
filepath=tmp_file2,
serializer=serializer
)
local_tmp_file2 = str(tmp_path / "local-model-2.pkl")
_pickle(
obj=local_model,
filepath=local_tmp_file2,
serializer=serializer
)
local_fitted_model_from_disk = _unpickle(
filepath=local_tmp_file2,
serializer=serializer
)
if set_client:
assert dask_model.client == client1
assert dask_model.client_ == client1
else:
assert dask_model.client is None
assert dask_model.client_ == default_client()
assert dask_model.client_ == client2
assert isinstance(fitted_model_from_disk, model_factory)
assert fitted_model_from_disk.client is None
assert fitted_model_from_disk.client_ == default_client()
assert fitted_model_from_disk.client_ == client2
# client will always be None after unpickling
if set_client:
from_disk_params = fitted_model_from_disk.get_params()
from_disk_params.pop("client", None)
dask_params = dask_model.get_params()
dask_params.pop("client", None)
assert from_disk_params == dask_params
else:
assert fitted_model_from_disk.get_params() == dask_model.get_params()
assert local_fitted_model_from_disk.get_params() == local_model.get_params()
if set_client:
preds_orig = dask_model.predict(dX_1).compute()
preds_loaded_model = fitted_model_from_disk.predict(dX_1).compute()
preds_orig_local = local_model.predict(X_1)
preds_loaded_model_local = local_fitted_model_from_disk.predict(X_1)
else:
preds_orig = dask_model.predict(dX_2).compute()
preds_loaded_model = fitted_model_from_disk.predict(dX_2).compute()
preds_orig_local = local_model.predict(X_2)
preds_loaded_model_local = local_fitted_model_from_disk.predict(X_2)
assert_eq(preds_orig, preds_loaded_model)
assert_eq(preds_orig_local, preds_loaded_model_local)
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_
assert "client" not in local_model.get_params()
assert getattr(local_model, "client", None) is None
tmp_file = str(tmp_path / "model-1.pkl")
_pickle(
obj=dask_model,
filepath=tmp_file,
serializer=serializer
)
model_from_disk = _unpickle(
filepath=tmp_file,
serializer=serializer
)
local_tmp_file = str(tmp_path / "local-model-1.pkl")
_pickle(
obj=local_model,
filepath=local_tmp_file,
serializer=serializer
)
local_model_from_disk = _unpickle(
filepath=local_tmp_file,
serializer=serializer
)
assert model_from_disk.client is None
if set_client:
assert dask_model.client == client1
else:
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_
# client will always be None after unpickling
if set_client:
from_disk_params = model_from_disk.get_params()
from_disk_params.pop("client", None)
dask_params = dask_model.get_params()
dask_params.pop("client", None)
assert from_disk_params == dask_params
else:
assert model_from_disk.get_params() == dask_model.get_params()
assert local_model_from_disk.get_params() == local_model.get_params()
# fitted model should survive pickling round trip, and pickling
# shouldn't have side effects on the model object
if set_client:
dask_model.fit(dX_1, dy_1, group=dg_1)
else:
dask_model.fit(dX_2, dy_2, group=dg_2)
local_model = dask_model.to_local()
assert "client" not in local_model.get_params()
with pytest.raises(AttributeError):
local_model.client
local_model.client_
tmp_file2 = str(tmp_path / "model-2.pkl")
_pickle(
obj=dask_model,
filepath=tmp_file2,
serializer=serializer
)
fitted_model_from_disk = _unpickle(
filepath=tmp_file2,
serializer=serializer
)
local_tmp_file2 = str(tmp_path / "local-model-2.pkl")
_pickle(
obj=local_model,
filepath=local_tmp_file2,
serializer=serializer
)
local_fitted_model_from_disk = _unpickle(
filepath=local_tmp_file2,
serializer=serializer
)
if set_client:
assert dask_model.client == client1
assert dask_model.client_ == client1
else:
assert dask_model.client is None
assert dask_model.client_ == default_client()
assert dask_model.client_ == client2
assert isinstance(fitted_model_from_disk, model_factory)
assert fitted_model_from_disk.client is None
assert fitted_model_from_disk.client_ == default_client()
assert fitted_model_from_disk.client_ == client2
# client will always be None after unpickling
if set_client:
from_disk_params = fitted_model_from_disk.get_params()
from_disk_params.pop("client", None)
dask_params = dask_model.get_params()
dask_params.pop("client", None)
assert from_disk_params == dask_params
else:
assert fitted_model_from_disk.get_params() == dask_model.get_params()
assert local_fitted_model_from_disk.get_params() == local_model.get_params()
if set_client:
preds_orig = dask_model.predict(dX_1).compute()
preds_loaded_model = fitted_model_from_disk.predict(dX_1).compute()
preds_orig_local = local_model.predict(X_1)
preds_loaded_model_local = local_fitted_model_from_disk.predict(X_1)
else:
preds_orig = dask_model.predict(dX_2).compute()
preds_loaded_model = fitted_model_from_disk.predict(dX_2).compute()
preds_orig_local = local_model.predict(X_2)
preds_loaded_model_local = local_fitted_model_from_disk.predict(X_2)
assert_eq(preds_orig, preds_loaded_model)
assert_eq(preds_orig_local, preds_loaded_model_local)
def test_warns_and_continues_on_unrecognized_tree_learner(client):
......@@ -971,18 +957,11 @@ def test_training_succeeds_even_if_some_workers_do_not_have_any_data(client, tas
return collection.rechunk(*collection.shape)
return collection.repartition(npartitions=1)
if task == 'ranking':
X, y, w, g, dX, dy, dw, dg = _create_ranking_data(
output=output,
group=None
)
else:
X, y, w, dX, dy, dw = _create_data(
objective=task,
output=output
)
g = None
dg = None
X, y, w, g, dX, dy, dw, dg = _create_data(
objective=task,
output=output,
group=None
)
dask_model_factory = task_to_dask_factory[task]
local_model_factory = task_to_local_factory[task]
......@@ -1026,19 +1005,12 @@ def test_network_params_not_required_but_respected_if_given(client, task, output
client.wait_for_workers(2)
if task == 'ranking':
_, _, _, _, dX, dy, _, dg = _create_ranking_data(
output=output,
group=None,
chunk_size=10,
)
else:
_, _, _, dX, dy, _ = _create_data(
objective=task,
output=output,
chunk_size=10,
)
dg = None
_, _, _, _, dX, dy, _, dg = _create_data(
objective=task,
output=output,
chunk_size=10,
group=None
)
dask_model_factory = task_to_dask_factory[task]
......@@ -1097,19 +1069,12 @@ def test_machines_should_be_used_if_provided(task, output):
pytest.skip('LGBMRanker is not currently tested on sparse matrices')
with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
if task == 'ranking':
_, _, _, _, dX, dy, _, dg = _create_ranking_data(
output=output,
group=None,
chunk_size=10,
)
else:
_, _, _, dX, dy, _ = _create_data(
objective=task,
output=output,
chunk_size=10,
)
dg = None
_, _, _, _, dX, dy, _, dg = _create_data(
objective=task,
output=output,
chunk_size=10,
group=None
)
dask_model_factory = task_to_dask_factory[task]
......@@ -1205,17 +1170,11 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(
task,
client,
):
if task == 'ranking':
_, _, _, _, dX, dy, dw, dg = _create_ranking_data(
output='dataframe',
group=None
)
else:
_, _, _, dX, dy, dw = _create_data(
objective=task,
output='dataframe',
)
dg = None
_, _, _, _, dX, dy, dw, dg = _create_data(
objective=task,
output='dataframe',
group=None
)
model_factory = task_to_dask_factory[task]
......@@ -1242,17 +1201,11 @@ def test_init_score(task, output, client):
if task == 'ranking' and output == 'scipy_csr_matrix':
pytest.skip('LGBMRanker is not currently tested on sparse matrices')
if task == 'ranking':
_, _, _, _, dX, dy, dw, dg = _create_ranking_data(
output=output,
group=None
)
else:
_, _, _, dX, dy, dw = _create_data(
objective=task,
output=output,
)
dg = None
_, _, _, _, dX, dy, dw, dg = _create_data(
objective=task,
output=output,
group=None
)
model_factory = task_to_dask_factory[task]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment