Unverified Commit 1f4a0842 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[tests][dask] simplify code in Dask tests (#4075)

* simplify Dask tests code

* enable CI

* disable CI
parent 39c85dd9
...@@ -131,7 +131,7 @@ def _create_ranking_data(n_samples=100, output='array', chunk_size=50, **kwargs) ...@@ -131,7 +131,7 @@ def _create_ranking_data(n_samples=100, output='array', chunk_size=50, **kwargs)
return X, y, w, g_rle, dX, dy, dw, dg return X, y, w, g_rle, dX, dy, dw, dg
def _create_data(objective, n_samples=100, output='array', chunk_size=50): def _create_data(objective, n_samples=100, output='array', chunk_size=50, **kwargs):
if objective.endswith('classification'): if objective.endswith('classification'):
if objective == 'binary-classification': if objective == 'binary-classification':
centers = [[-4, -4], [4, 4]] centers = [[-4, -4], [4, 4]]
...@@ -142,6 +142,13 @@ def _create_data(objective, n_samples=100, output='array', chunk_size=50): ...@@ -142,6 +142,13 @@ def _create_data(objective, n_samples=100, output='array', chunk_size=50):
X, y = make_blobs(n_samples=n_samples, centers=centers, random_state=42) X, y = make_blobs(n_samples=n_samples, centers=centers, random_state=42)
elif objective == 'regression': elif objective == 'regression':
X, y = make_regression(n_samples=n_samples, random_state=42) X, y = make_regression(n_samples=n_samples, random_state=42)
elif objective == 'ranking':
return _create_ranking_data(
n_samples=n_samples,
output=output,
chunk_size=chunk_size,
**kwargs
)
else: else:
raise ValueError("Unknown objective '%s'" % objective) raise ValueError("Unknown objective '%s'" % objective)
rnd = np.random.RandomState(42) rnd = np.random.RandomState(42)
...@@ -183,7 +190,7 @@ def _create_data(objective, n_samples=100, output='array', chunk_size=50): ...@@ -183,7 +190,7 @@ def _create_data(objective, n_samples=100, output='array', chunk_size=50):
else: else:
raise ValueError("Unknown output type '%s'" % output) raise ValueError("Unknown output type '%s'" % output)
return X, y, weights, dX, dy, dw return X, y, weights, None, dX, dy, dw, None
def _r2_score(dy_true, dy_pred): def _r2_score(dy_true, dy_pred):
...@@ -225,7 +232,7 @@ def _unpickle(filepath, serializer): ...@@ -225,7 +232,7 @@ def _unpickle(filepath, serializer):
@pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification']) @pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
def test_classifier(output, task, client): def test_classifier(output, task, client):
X, y, w, dX, dy, dw = _create_data( X, y, w, _, dX, dy, dw, _ = _create_data(
objective=task, objective=task,
output=output output=output
) )
...@@ -291,7 +298,7 @@ def test_classifier(output, task, client): ...@@ -291,7 +298,7 @@ def test_classifier(output, task, client):
@pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification']) @pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
def test_classifier_pred_contrib(output, task, client): def test_classifier_pred_contrib(output, task, client):
X, y, w, dX, dy, dw = _create_data( X, y, w, _, dX, dy, dw, _ = _create_data(
objective=task, objective=task,
output=output output=output
) )
...@@ -369,7 +376,7 @@ def test_find_random_open_port(client): ...@@ -369,7 +376,7 @@ def test_find_random_open_port(client):
def test_training_does_not_fail_on_port_conflicts(client): def test_training_does_not_fail_on_port_conflicts(client):
_, _, _, dX, dy, dw = _create_data('binary-classification', output='array') _, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')
lightgbm_default_port = 12400 lightgbm_default_port = 12400
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
...@@ -393,7 +400,7 @@ def test_training_does_not_fail_on_port_conflicts(client): ...@@ -393,7 +400,7 @@ def test_training_does_not_fail_on_port_conflicts(client):
@pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('output', data_output)
def test_regressor(output, client): def test_regressor(output, client):
X, y, w, dX, dy, dw = _create_data( X, y, w, _, dX, dy, dw, _ = _create_data(
objective='regression', objective='regression',
output=output output=output
) )
...@@ -468,7 +475,7 @@ def test_regressor(output, client): ...@@ -468,7 +475,7 @@ def test_regressor(output, client):
@pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('output', data_output)
def test_regressor_pred_contrib(output, client): def test_regressor_pred_contrib(output, client):
X, y, w, dX, dy, dw = _create_data( X, y, w, _, dX, dy, dw, _ = _create_data(
objective='regression', objective='regression',
output=output output=output
) )
...@@ -518,7 +525,7 @@ def test_regressor_pred_contrib(output, client): ...@@ -518,7 +525,7 @@ def test_regressor_pred_contrib(output, client):
@pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('alpha', [.1, .5, .9]) @pytest.mark.parametrize('alpha', [.1, .5, .9])
def test_regressor_quantile(output, client, alpha): def test_regressor_quantile(output, client, alpha):
X, y, w, dX, dy, dw = _create_data( X, y, w, _, dX, dy, dw, _ = _create_data(
objective='regression', objective='regression',
output=output output=output
) )
...@@ -567,18 +574,19 @@ def test_regressor_quantile(output, client, alpha): ...@@ -567,18 +574,19 @@ def test_regressor_quantile(output, client, alpha):
@pytest.mark.parametrize('output', ['array', 'dataframe', 'dataframe-with-categorical']) @pytest.mark.parametrize('output', ['array', 'dataframe', 'dataframe-with-categorical'])
@pytest.mark.parametrize('group', [None, group_sizes]) @pytest.mark.parametrize('group', [None, group_sizes])
def test_ranker(output, client, group): def test_ranker(output, client, group):
if output == 'dataframe-with-categorical': if output == 'dataframe-with-categorical':
X, y, w, g, dX, dy, dw, dg = _create_ranking_data( X, y, w, g, dX, dy, dw, dg = _create_data(
objective='ranking',
output=output, output=output,
group=group, group=group,
n_features=1, n_features=1,
n_informative=1 n_informative=1
) )
else: else:
X, y, w, g, dX, dy, dw, dg = _create_ranking_data( X, y, w, g, dX, dy, dw, dg = _create_data(
objective='ranking',
output=output, output=output,
group=group, group=group
) )
# rebalance small dask.Array dataset for better performance. # rebalance small dask.Array dataset for better performance.
...@@ -650,17 +658,11 @@ def test_ranker(output, client, group): ...@@ -650,17 +658,11 @@ def test_ranker(output, client, group):
@pytest.mark.parametrize('task', tasks) @pytest.mark.parametrize('task', tasks)
def test_training_works_if_client_not_provided_or_set_after_construction(task, client): def test_training_works_if_client_not_provided_or_set_after_construction(task, client):
if task == 'ranking': _, _, _, _, dX, dy, _, dg = _create_data(
_, _, _, _, dX, dy, _, dg = _create_ranking_data( objective=task,
output='array', output='array',
group=None group=None
) )
else:
_, _, _, dX, dy, _ = _create_data(
objective=task,
output='array',
)
dg = None
model_factory = task_to_dask_factory[task] model_factory = task_to_dask_factory[task]
params = { params = {
...@@ -723,182 +725,166 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, c ...@@ -723,182 +725,166 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, c
@pytest.mark.parametrize('set_client', [True, False]) @pytest.mark.parametrize('set_client', [True, False])
def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, tmp_path): def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, tmp_path):
with LocalCluster(n_workers=2, threads_per_worker=1) as cluster1: with LocalCluster(n_workers=2, threads_per_worker=1) as cluster1, Client(cluster1) as client1:
with Client(cluster1) as client1: # data on cluster1
X_1, _, _, _, dX_1, dy_1, _, dg_1 = _create_data(
objective=task,
output='array',
group=None
)
with LocalCluster(n_workers=2, threads_per_worker=1) as cluster2, Client(cluster2) as client2:
# create identical data on cluster2
X_2, _, _, _, dX_2, dy_2, _, dg_2 = _create_data(
objective=task,
output='array',
group=None
)
# data on cluster1 model_factory = task_to_dask_factory[task]
if task == 'ranking':
X_1, _, _, _, dX_1, dy_1, _, dg_1 = _create_ranking_data( params = {
output='array', "time_out": 5,
group=None "n_estimators": 1,
) "num_leaves": 2
}
# at this point, the result of default_client() is client2 since it was the most recently
# created. So setting client to client1 here to test that you can select a non-default client
assert default_client() == client2
if set_client:
params.update({"client": client1})
# unfitted model should survive pickling round trip, and pickling
# shouldn't have side effects on the model object
dask_model = model_factory(**params)
local_model = dask_model.to_local()
if set_client:
assert dask_model.client == client1
else: else:
X_1, _, _, dX_1, dy_1, _ = _create_data( assert dask_model.client is None
objective=task,
output='array', with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
) dask_model.client_
dg_1 = None
assert "client" not in local_model.get_params()
with LocalCluster(n_workers=2, threads_per_worker=1) as cluster2: assert getattr(local_model, "client", None) is None
with Client(cluster2) as client2:
tmp_file = str(tmp_path / "model-1.pkl")
# create identical data on cluster2 _pickle(
if task == 'ranking': obj=dask_model,
X_2, _, _, _, dX_2, dy_2, _, dg_2 = _create_ranking_data( filepath=tmp_file,
output='array', serializer=serializer
group=None )
) model_from_disk = _unpickle(
else: filepath=tmp_file,
X_2, _, _, dX_2, dy_2, _ = _create_data( serializer=serializer
objective=task, )
output='array',
) local_tmp_file = str(tmp_path / "local-model-1.pkl")
dg_2 = None _pickle(
obj=local_model,
model_factory = task_to_dask_factory[task] filepath=local_tmp_file,
serializer=serializer
params = { )
"time_out": 5, local_model_from_disk = _unpickle(
"n_estimators": 1, filepath=local_tmp_file,
"num_leaves": 2 serializer=serializer
} )
# at this point, the result of default_client() is client2 since it was the most recently assert model_from_disk.client is None
# created. So setting client to client1 here to test that you can select a non-default client
assert default_client() == client2 if set_client:
if set_client: assert dask_model.client == client1
params.update({"client": client1}) else:
assert dask_model.client is None
# unfitted model should survive pickling round trip, and pickling
# shouldn't have side effects on the model object with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model = model_factory(**params) dask_model.client_
local_model = dask_model.to_local()
if set_client: # client will always be None after unpickling
assert dask_model.client == client1 if set_client:
else: from_disk_params = model_from_disk.get_params()
assert dask_model.client is None from_disk_params.pop("client", None)
dask_params = dask_model.get_params()
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): dask_params.pop("client", None)
dask_model.client_ assert from_disk_params == dask_params
else:
assert "client" not in local_model.get_params() assert model_from_disk.get_params() == dask_model.get_params()
assert getattr(local_model, "client", None) is None assert local_model_from_disk.get_params() == local_model.get_params()
tmp_file = str(tmp_path / "model-1.pkl") # fitted model should survive pickling round trip, and pickling
_pickle( # shouldn't have side effects on the model object
obj=dask_model, if set_client:
filepath=tmp_file, dask_model.fit(dX_1, dy_1, group=dg_1)
serializer=serializer else:
) dask_model.fit(dX_2, dy_2, group=dg_2)
model_from_disk = _unpickle( local_model = dask_model.to_local()
filepath=tmp_file,
serializer=serializer assert "client" not in local_model.get_params()
) with pytest.raises(AttributeError):
local_model.client
local_tmp_file = str(tmp_path / "local-model-1.pkl") local_model.client_
_pickle(
obj=local_model, tmp_file2 = str(tmp_path / "model-2.pkl")
filepath=local_tmp_file, _pickle(
serializer=serializer obj=dask_model,
) filepath=tmp_file2,
local_model_from_disk = _unpickle( serializer=serializer
filepath=local_tmp_file, )
serializer=serializer fitted_model_from_disk = _unpickle(
) filepath=tmp_file2,
serializer=serializer
assert model_from_disk.client is None )
if set_client: local_tmp_file2 = str(tmp_path / "local-model-2.pkl")
assert dask_model.client == client1 _pickle(
else: obj=local_model,
assert dask_model.client is None filepath=local_tmp_file2,
serializer=serializer
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): )
dask_model.client_ local_fitted_model_from_disk = _unpickle(
filepath=local_tmp_file2,
# client will always be None after unpickling serializer=serializer
if set_client: )
from_disk_params = model_from_disk.get_params()
from_disk_params.pop("client", None) if set_client:
dask_params = dask_model.get_params() assert dask_model.client == client1
dask_params.pop("client", None) assert dask_model.client_ == client1
assert from_disk_params == dask_params else:
else: assert dask_model.client is None
assert model_from_disk.get_params() == dask_model.get_params() assert dask_model.client_ == default_client()
assert local_model_from_disk.get_params() == local_model.get_params() assert dask_model.client_ == client2
# fitted model should survive pickling round trip, and pickling assert isinstance(fitted_model_from_disk, model_factory)
# shouldn't have side effects on the model object assert fitted_model_from_disk.client is None
if set_client: assert fitted_model_from_disk.client_ == default_client()
dask_model.fit(dX_1, dy_1, group=dg_1) assert fitted_model_from_disk.client_ == client2
else:
dask_model.fit(dX_2, dy_2, group=dg_2) # client will always be None after unpickling
local_model = dask_model.to_local() if set_client:
from_disk_params = fitted_model_from_disk.get_params()
assert "client" not in local_model.get_params() from_disk_params.pop("client", None)
with pytest.raises(AttributeError): dask_params = dask_model.get_params()
local_model.client dask_params.pop("client", None)
local_model.client_ assert from_disk_params == dask_params
else:
tmp_file2 = str(tmp_path / "model-2.pkl") assert fitted_model_from_disk.get_params() == dask_model.get_params()
_pickle( assert local_fitted_model_from_disk.get_params() == local_model.get_params()
obj=dask_model,
filepath=tmp_file2, if set_client:
serializer=serializer preds_orig = dask_model.predict(dX_1).compute()
) preds_loaded_model = fitted_model_from_disk.predict(dX_1).compute()
fitted_model_from_disk = _unpickle( preds_orig_local = local_model.predict(X_1)
filepath=tmp_file2, preds_loaded_model_local = local_fitted_model_from_disk.predict(X_1)
serializer=serializer else:
) preds_orig = dask_model.predict(dX_2).compute()
preds_loaded_model = fitted_model_from_disk.predict(dX_2).compute()
local_tmp_file2 = str(tmp_path / "local-model-2.pkl") preds_orig_local = local_model.predict(X_2)
_pickle( preds_loaded_model_local = local_fitted_model_from_disk.predict(X_2)
obj=local_model,
filepath=local_tmp_file2, assert_eq(preds_orig, preds_loaded_model)
serializer=serializer assert_eq(preds_orig_local, preds_loaded_model_local)
)
local_fitted_model_from_disk = _unpickle(
filepath=local_tmp_file2,
serializer=serializer
)
if set_client:
assert dask_model.client == client1
assert dask_model.client_ == client1
else:
assert dask_model.client is None
assert dask_model.client_ == default_client()
assert dask_model.client_ == client2
assert isinstance(fitted_model_from_disk, model_factory)
assert fitted_model_from_disk.client is None
assert fitted_model_from_disk.client_ == default_client()
assert fitted_model_from_disk.client_ == client2
# client will always be None after unpickling
if set_client:
from_disk_params = fitted_model_from_disk.get_params()
from_disk_params.pop("client", None)
dask_params = dask_model.get_params()
dask_params.pop("client", None)
assert from_disk_params == dask_params
else:
assert fitted_model_from_disk.get_params() == dask_model.get_params()
assert local_fitted_model_from_disk.get_params() == local_model.get_params()
if set_client:
preds_orig = dask_model.predict(dX_1).compute()
preds_loaded_model = fitted_model_from_disk.predict(dX_1).compute()
preds_orig_local = local_model.predict(X_1)
preds_loaded_model_local = local_fitted_model_from_disk.predict(X_1)
else:
preds_orig = dask_model.predict(dX_2).compute()
preds_loaded_model = fitted_model_from_disk.predict(dX_2).compute()
preds_orig_local = local_model.predict(X_2)
preds_loaded_model_local = local_fitted_model_from_disk.predict(X_2)
assert_eq(preds_orig, preds_loaded_model)
assert_eq(preds_orig_local, preds_loaded_model_local)
def test_warns_and_continues_on_unrecognized_tree_learner(client): def test_warns_and_continues_on_unrecognized_tree_learner(client):
...@@ -971,18 +957,11 @@ def test_training_succeeds_even_if_some_workers_do_not_have_any_data(client, tas ...@@ -971,18 +957,11 @@ def test_training_succeeds_even_if_some_workers_do_not_have_any_data(client, tas
return collection.rechunk(*collection.shape) return collection.rechunk(*collection.shape)
return collection.repartition(npartitions=1) return collection.repartition(npartitions=1)
if task == 'ranking': X, y, w, g, dX, dy, dw, dg = _create_data(
X, y, w, g, dX, dy, dw, dg = _create_ranking_data( objective=task,
output=output, output=output,
group=None group=None
) )
else:
X, y, w, dX, dy, dw = _create_data(
objective=task,
output=output
)
g = None
dg = None
dask_model_factory = task_to_dask_factory[task] dask_model_factory = task_to_dask_factory[task]
local_model_factory = task_to_local_factory[task] local_model_factory = task_to_local_factory[task]
...@@ -1026,19 +1005,12 @@ def test_network_params_not_required_but_respected_if_given(client, task, output ...@@ -1026,19 +1005,12 @@ def test_network_params_not_required_but_respected_if_given(client, task, output
client.wait_for_workers(2) client.wait_for_workers(2)
if task == 'ranking': _, _, _, _, dX, dy, _, dg = _create_data(
_, _, _, _, dX, dy, _, dg = _create_ranking_data( objective=task,
output=output, output=output,
group=None, chunk_size=10,
chunk_size=10, group=None
) )
else:
_, _, _, dX, dy, _ = _create_data(
objective=task,
output=output,
chunk_size=10,
)
dg = None
dask_model_factory = task_to_dask_factory[task] dask_model_factory = task_to_dask_factory[task]
...@@ -1097,19 +1069,12 @@ def test_machines_should_be_used_if_provided(task, output): ...@@ -1097,19 +1069,12 @@ def test_machines_should_be_used_if_provided(task, output):
pytest.skip('LGBMRanker is not currently tested on sparse matrices') pytest.skip('LGBMRanker is not currently tested on sparse matrices')
with LocalCluster(n_workers=2) as cluster, Client(cluster) as client: with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
if task == 'ranking': _, _, _, _, dX, dy, _, dg = _create_data(
_, _, _, _, dX, dy, _, dg = _create_ranking_data( objective=task,
output=output, output=output,
group=None, chunk_size=10,
chunk_size=10, group=None
) )
else:
_, _, _, dX, dy, _ = _create_data(
objective=task,
output=output,
chunk_size=10,
)
dg = None
dask_model_factory = task_to_dask_factory[task] dask_model_factory = task_to_dask_factory[task]
...@@ -1205,17 +1170,11 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array( ...@@ -1205,17 +1170,11 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(
task, task,
client, client,
): ):
if task == 'ranking': _, _, _, _, dX, dy, dw, dg = _create_data(
_, _, _, _, dX, dy, dw, dg = _create_ranking_data( objective=task,
output='dataframe', output='dataframe',
group=None group=None
) )
else:
_, _, _, dX, dy, dw = _create_data(
objective=task,
output='dataframe',
)
dg = None
model_factory = task_to_dask_factory[task] model_factory = task_to_dask_factory[task]
...@@ -1242,17 +1201,11 @@ def test_init_score(task, output, client): ...@@ -1242,17 +1201,11 @@ def test_init_score(task, output, client):
if task == 'ranking' and output == 'scipy_csr_matrix': if task == 'ranking' and output == 'scipy_csr_matrix':
pytest.skip('LGBMRanker is not currently tested on sparse matrices') pytest.skip('LGBMRanker is not currently tested on sparse matrices')
if task == 'ranking': _, _, _, _, dX, dy, dw, dg = _create_data(
_, _, _, _, dX, dy, dw, dg = _create_ranking_data( objective=task,
output=output, output=output,
group=None group=None
) )
else:
_, _, _, dX, dy, dw = _create_data(
objective=task,
output=output,
)
dg = None
model_factory = task_to_dask_factory[task] model_factory = task_to_dask_factory[task]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment