Unverified Commit 965b9fc9 authored by jmoralez, committed by GitHub

[tests][dask] replace client fixture with cluster fixture (#4159)

* replace client fixture with cluster fixture (see the first sketch below)

* wait on persist before rebalance (see the second sketch below)
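In brief, the first change: instead of importing distributed's per-test `client` fixture, the test module now defines a module-scoped `LocalCluster` fixture, and each test opens a short-lived `Client` on it as a context manager, which also removes the explicit `client.close(timeout=CLIENT_CLOSE_TIMEOUT)` teardown calls. A minimal sketch of the pattern (`test_example` is a hypothetical test, for illustration only; the fixture mirrors the diff below):

```python
import pytest
from dask.distributed import Client, LocalCluster


@pytest.fixture(scope='module')
def cluster():
    # one cluster is created per test module and shared by all of its tests
    dask_cluster = LocalCluster(n_workers=2, threads_per_worker=2, dashboard_address=None)
    yield dask_cluster
    dask_cluster.close()


def test_example(cluster):  # hypothetical test, for illustration only
    # each test connects with its own short-lived Client; leaving the
    # `with` block closes it, so no close(timeout=...) call is needed
    with Client(cluster) as client:
        assert client.status == 'running'
```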
parent b2d73dee
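The second change touches `test_error_on_feature_parallel_tree_learner`: the test data is now persisted and explicitly waited on before `client.rebalance()` is called, since `rebalance()` operates on data already in worker memory and is not reliable while tasks are still in flight. A standalone sketch of that ordering, assuming any small Dask arrays:

```python
import dask.array as da
from dask.distributed import Client, LocalCluster, wait

with LocalCluster(n_workers=2, threads_per_worker=2, dashboard_address=None) as cluster:
    with Client(cluster) as client:
        X = da.random.random((100, 10), chunks=(50, 10))
        y = da.random.random(100, chunks=50)
        # persist() only *starts* computing the chunks on the workers...
        X, y = client.persist([X, y])
        # ...so block until they are actually in worker memory before
        # asking the scheduler to move them around
        _ = wait([X, y])
        client.rebalance()
```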
@@ -28,7 +28,6 @@ import pandas as pd
 import sklearn.utils.estimator_checks as sklearn_checks
 from dask.array.utils import assert_eq
 from dask.distributed import Client, LocalCluster, default_client, wait
-from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
 from pkg_resources import parse_version
 from scipy.sparse import csr_matrix
 from scipy.stats import spearmanr
@@ -39,10 +38,6 @@ from .utils import make_ranking
 sk_version = parse_version(sk_version)
-# time, in seconds, to wait for the Dask client to close. Used to avoid teardown errors
-# see https://distributed.dask.org/en/latest/api.html#distributed.Client.close
-CLIENT_CLOSE_TIMEOUT = 120
 tasks = ['binary-classification', 'multiclass-classification', 'regression', 'ranking']
 distributed_training_algorithms = ['data', 'voting']
 data_output = ['array', 'scipy_csr_matrix', 'dataframe', 'dataframe-with-categorical']
@@ -68,6 +63,20 @@ pytestmark = [
 ]
 
+
+@pytest.fixture(scope='module')
+def cluster():
+    dask_cluster = LocalCluster(n_workers=2, threads_per_worker=2, dashboard_address=None)
+    yield dask_cluster
+    dask_cluster.close()
+
+
+@pytest.fixture(scope='module')
+def cluster2():
+    dask_cluster = LocalCluster(n_workers=2, threads_per_worker=2, dashboard_address=None)
+    yield dask_cluster
+    dask_cluster.close()
+
 
 @pytest.fixture()
 def listen_port():
     listen_port.port += 10
@@ -237,7 +246,8 @@ def _unpickle(filepath, serializer):
 @pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
 @pytest.mark.parametrize('boosting_type', boosting_types)
 @pytest.mark.parametrize('tree_learner', distributed_training_algorithms)
-def test_classifier(output, task, boosting_type, tree_learner, client):
+def test_classifier(output, task, boosting_type, tree_learner, cluster):
+    with Client(cluster) as client:
         X, y, w, _, dX, dy, dw, _ = _create_data(
             objective=task,
             output=output
@@ -312,12 +322,11 @@ def test_classifier(output, task, boosting_type, tree_learner, client):
         assert node_uses_cat_col.sum() > 0
         assert tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '=='
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 @pytest.mark.parametrize('output', data_output)
 @pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
-def test_classifier_pred_contrib(output, task, client):
+def test_classifier_pred_contrib(output, task, cluster):
+    with Client(cluster) as client:
         X, y, w, _, dX, dy, dw, _ = _create_data(
             objective=task,
             output=output
@@ -379,10 +388,9 @@ def test_classifier_pred_contrib(output, task, client):
             base_value_col = num_features * (i + 1) + i
             assert len(np.unique(preds_with_contrib[:, base_value_col]) == 1)
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
-def test_find_random_open_port(client):
+def test_find_random_open_port(cluster):
+    with Client(cluster) as client:
         for _ in range(5):
             worker_address_to_port = client.run(lgb.dask._find_random_open_port)
             found_ports = worker_address_to_port.values()
@@ -392,11 +400,10 @@ def test_find_random_open_port(client):
         for port in found_ports:
             with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                 s.bind(('', port))
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
-def test_possibly_fix_worker_map(capsys, client):
-    client.wait_for_workers(2)
+def test_possibly_fix_worker_map(capsys, cluster):
+    with Client(cluster) as client:
         worker_addresses = list(client.scheduler_info()["workers"].keys())
         retry_msg = 'Searching for a LightGBM training port for worker'
@@ -426,7 +433,8 @@ def test_possibly_fix_worker_map(capsys, client):
         assert len(set(patched_map.values())) == len(worker_addresses)
-def test_training_does_not_fail_on_port_conflicts(client):
+def test_training_does_not_fail_on_port_conflicts(cluster):
+    with Client(cluster) as client:
         _, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')
         lightgbm_default_port = 12400
@@ -446,13 +454,12 @@ def test_training_does_not_fail_on_port_conflicts(client):
             )
             assert dask_classifier.booster_
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 @pytest.mark.parametrize('output', data_output)
 @pytest.mark.parametrize('boosting_type', boosting_types)
 @pytest.mark.parametrize('tree_learner', distributed_training_algorithms)
-def test_regressor(output, boosting_type, tree_learner, client):
+def test_regressor(output, boosting_type, tree_learner, cluster):
+    with Client(cluster) as client:
         X, y, w, _, dX, dy, dw, _ = _create_data(
             objective='regression',
             output=output
@@ -523,11 +530,10 @@ def test_regressor(output, boosting_type, tree_learner, client):
         assert node_uses_cat_col.sum() > 0
         assert tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '=='
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 @pytest.mark.parametrize('output', data_output)
-def test_regressor_pred_contrib(output, client):
+def test_regressor_pred_contrib(output, cluster):
+    with Client(cluster) as client:
         X, y, w, _, dX, dy, dw, _ = _create_data(
             objective='regression',
             output=output
@@ -572,12 +578,11 @@ def test_regressor_pred_contrib(output, client):
         assert node_uses_cat_col.sum() > 0
         assert tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '=='
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 @pytest.mark.parametrize('output', data_output)
 @pytest.mark.parametrize('alpha', [.1, .5, .9])
-def test_regressor_quantile(output, client, alpha):
+def test_regressor_quantile(output, alpha, cluster):
+    with Client(cluster) as client:
         X, y, w, _, dX, dy, dw, _ = _create_data(
             objective='regression',
             output=output
@@ -621,14 +626,13 @@ def test_regressor_quantile(output, client, alpha):
         assert node_uses_cat_col.sum() > 0
         assert tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '=='
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 @pytest.mark.parametrize('output', ['array', 'dataframe', 'dataframe-with-categorical'])
 @pytest.mark.parametrize('group', [None, group_sizes])
 @pytest.mark.parametrize('boosting_type', boosting_types)
 @pytest.mark.parametrize('tree_learner', distributed_training_algorithms)
-def test_ranker(output, group, boosting_type, tree_learner, client):
+def test_ranker(output, group, boosting_type, tree_learner, cluster):
+    with Client(cluster) as client:
         if output == 'dataframe-with-categorical':
             X, y, w, g, dX, dy, dw, dg = _create_data(
                 objective='ranking',
@@ -714,11 +718,10 @@ def test_ranker(output, group, boosting_type, tree_learner, client):
         assert node_uses_cat_col.sum() > 0
         assert tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '=='
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 @pytest.mark.parametrize('task', tasks)
-def test_training_works_if_client_not_provided_or_set_after_construction(task, client):
+def test_training_works_if_client_not_provided_or_set_after_construction(task, cluster):
+    with Client(cluster) as client:
         _, _, _, _, dX, dy, _, dg = _create_data(
             objective=task,
             output='array',
@@ -778,15 +781,13 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, c
             local_model.client
             local_model.client_
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 @pytest.mark.parametrize('serializer', ['pickle', 'joblib', 'cloudpickle'])
 @pytest.mark.parametrize('task', tasks)
 @pytest.mark.parametrize('set_client', [True, False])
-def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, tmp_path):
-    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster1, Client(cluster1) as client1:
+def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, tmp_path, cluster, cluster2):
+    with Client(cluster) as client1:
         # data on cluster1
         X_1, _, _, _, dX_1, dy_1, _, dg_1 = _create_data(
             objective=task,
@@ -794,7 +795,7 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
             group=None
         )
-        with LocalCluster(n_workers=2, threads_per_worker=1) as cluster2, Client(cluster2) as client2:
+        with Client(cluster2) as client2:
             # create identical data on cluster2
             X_2, _, _, _, dX_2, dy_2, _, dg_2 = _create_data(
                 objective=task,
@@ -948,7 +949,8 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
         assert_eq(preds_orig_local, preds_loaded_model_local)
-def test_warns_and_continues_on_unrecognized_tree_learner(client):
+def test_warns_and_continues_on_unrecognized_tree_learner(cluster):
+    with Client(cluster) as client:
         X = da.random.random((1e3, 10))
         y = da.random.random((1e3, 1))
         dask_regressor = lgb.DaskLGBMRegressor(
@@ -963,11 +965,10 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client):
         assert dask_regressor.fitted_
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 @pytest.mark.parametrize('tree_learner', ['data_parallel', 'voting_parallel'])
-def test_training_respects_tree_learner_aliases(tree_learner, client):
+def test_training_respects_tree_learner_aliases(tree_learner, cluster):
+    with Client(cluster) as client:
         task = 'regression'
         _, _, _, _, dX, dy, dw, dg = _create_data(objective=task, output='array')
         dask_factory = task_to_dask_factory[task]
@@ -984,9 +985,13 @@ def test_training_respects_tree_learner_aliases(tree_learner, client):
         assert dask_model.get_params()['tree_learner'] == tree_learner
-def test_error_on_feature_parallel_tree_learner(client):
+def test_error_on_feature_parallel_tree_learner(cluster):
+    with Client(cluster) as client:
         X = da.random.random((100, 10), chunks=(50, 10))
         y = da.random.random(100, chunks=50)
+        X, y = client.persist([X, y])
+        _ = wait([X, y])
+        client.rebalance()
         dask_regressor = lgb.DaskLGBMRegressor(
             client=client,
             time_out=5,
@@ -997,19 +1002,17 @@ def test_error_on_feature_parallel_tree_learner(client):
         with pytest.raises(lgb.basic.LightGBMError, match='Do not support feature parallel in c api'):
             dask_regressor = dask_regressor.fit(X, y)
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
-@gen_cluster(client=True, timeout=None)
-def test_errors(c, s, a, b):
+def test_errors(cluster):
+    with Client(cluster) as client:
         def f(part):
             raise Exception('foo')
         df = dd.demo.make_timeseries()
         df = df.map_partitions(f, meta=df._meta)
         with pytest.raises(Exception) as info:
-        yield lgb.dask._train(
-            client=c,
+            lgb.dask._train(
+                client=client,
                 data=df,
                 label=df.x,
                 params={},
@@ -1020,10 +1023,11 @@ def test_errors(c, s, a, b):
 @pytest.mark.parametrize('task', tasks)
 @pytest.mark.parametrize('output', data_output)
-def test_training_succeeds_even_if_some_workers_do_not_have_any_data(client, task, output):
+def test_training_succeeds_even_if_some_workers_do_not_have_any_data(task, output, cluster):
     if task == 'ranking' and output == 'scipy_csr_matrix':
         pytest.skip('LGBMRanker is not currently tested on sparse matrices')
+    with Client(cluster) as client:
         def collection_to_single_partition(collection):
             """Merge the parts of a Dask collection into a single partition."""
             if collection is None:
@@ -1069,13 +1073,10 @@ def test_training_succeeds_even_if_some_workers_do_not_have_any_data(client, tas
         assert assert_eq(dask_preds, local_preds)
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 @pytest.mark.parametrize('task', tasks)
-def test_network_params_not_required_but_respected_if_given(client, task, listen_port):
-    client.wait_for_workers(2)
+def test_network_params_not_required_but_respected_if_given(task, listen_port, cluster):
+    with Client(cluster) as client:
         _, _, _, _, dX, dy, _, dg = _create_data(
             objective=task,
             output='array',
@@ -1129,12 +1130,10 @@ def test_network_params_not_required_but_respected_if_given(client, task, listen
         with pytest.raises(lgb.basic.LightGBMError, match=error_msg):
             dask_model3.fit(dX, dy, group=dg)
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 @pytest.mark.parametrize('task', tasks)
-def test_machines_should_be_used_if_provided(task):
-    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
+def test_machines_should_be_used_if_provided(task, cluster):
+    with Client(cluster) as client:
         _, _, _, _, dX, dy, _, dg = _create_data(
             objective=task,
             output='array',
@@ -1167,6 +1166,9 @@ def test_machines_should_be_used_if_provided(task):
             s.bind(('127.0.0.1', open_ports[0]))
             dask_model.fit(dX, dy, group=dg)
+        # The above error leaves a worker waiting
+        client.restart()
+
         # an informative error should be raised if "machines" has duplicates
         one_open_port = lgb.dask._find_random_open_port()
         dask_model.set_params(
@@ -1231,10 +1233,8 @@ def test_dask_methods_and_sklearn_equivalents_have_similar_signatures(methods):
 @pytest.mark.parametrize('task', tasks)
-def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(
-    task,
-    client,
-):
+def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task, cluster):
+    with Client(cluster) as client:
         _, _, _, _, dX, dy, dw, dg = _create_data(
             objective=task,
             output='dataframe',
@@ -1257,15 +1257,14 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(
         model.fit(dX, dy_col_array, sample_weight=dw, group=dg)
         assert model.fitted_
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 @pytest.mark.parametrize('task', tasks)
 @pytest.mark.parametrize('output', data_output)
-def test_init_score(task, output, client):
+def test_init_score(task, output, cluster):
     if task == 'ranking' and output == 'scipy_csr_matrix':
         pytest.skip('LGBMRanker is not currently tested on sparse matrices')
+    with Client(cluster) as client:
         _, _, _, _, dX, dy, dw, dg = _create_data(
             objective=task,
             output=output,
@@ -1296,8 +1295,6 @@ def test_init_score(task, output, client):
         # value of the root node is 0 when init_score is set
         assert model.booster_.trees_to_dataframe()['value'][0] == 0
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 def sklearn_checks_to_run():
     check_names = [
@@ -1318,11 +1315,11 @@ def _tested_estimators():
 @pytest.mark.parametrize("estimator", _tested_estimators())
 @pytest.mark.parametrize("check", sklearn_checks_to_run())
-def test_sklearn_integration(estimator, check, client):
+def test_sklearn_integration(estimator, check, cluster):
+    with Client(cluster) as client:
         estimator.set_params(local_listen_port=18000, time_out=5)
         name = type(estimator).__name__
         check(name, estimator)
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 # this test is separate because it takes a not-yet-constructed estimator
@@ -1338,10 +1335,11 @@ def test_parameters_default_constructible(estimator):
 @pytest.mark.parametrize('task', tasks)
 @pytest.mark.parametrize('output', data_output)
-def test_predict_with_raw_score(task, output, client):
+def test_predict_with_raw_score(task, output, cluster):
     if task == 'ranking' and output == 'scipy_csr_matrix':
         pytest.skip('LGBMRanker is not currently tested on sparse matrices')
+    with Client(cluster) as client:
         _, _, _, _, dX, dy, _, dg = _create_data(
             objective=task,
             output=output,
@@ -1372,5 +1370,3 @@ def test_predict_with_raw_score(task, output, client):
         if task.endswith('classification'):
             pred_proba_raw = model.predict_proba(dX, raw_score=True).compute()
             assert_eq(raw_predictions, pred_proba_raw)
-    client.close(timeout=CLIENT_CLOSE_TIMEOUT)