Unverified Commit c871496a authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[dask] reduce test times (#3786)

* speed up tests

* [dask] reduce test times
parent d2c55454
......@@ -72,22 +72,29 @@ def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size
def test_classifier(output, centers, client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers)
dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port)
dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
local_listen_port=listen_port,
n_estimators=10,
num_leaves=10
)
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
p1 = dask_classifier.predict(dX)
p1_proba = dask_classifier.predict_proba(dX).compute()
s1 = accuracy_score(dy, p1)
p1 = p1.compute()
local_classifier = lightgbm.LGBMClassifier()
local_classifier = lightgbm.LGBMClassifier(n_estimators=10, num_leaves=10)
local_classifier.fit(X, y, sample_weight=w)
p2 = local_classifier.predict(X)
p2_proba = local_classifier.predict_proba(X)
s2 = local_classifier.score(X, y)
assert_eq(s1, s2)
assert_eq(p1, p2)
assert_eq(y, p1)
assert_eq(y, p2)
assert_eq(p1_proba, p2_proba, atol=0.3)
def test_training_does_not_fail_on_port_conflicts(client):
......@@ -98,7 +105,9 @@ def test_training_does_not_fail_on_port_conflicts(client):
dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
local_listen_port=12400
local_listen_port=12400,
n_estimators=5,
num_leaves=5
)
for i in range(5):
dask_classifier.fit(
......@@ -110,31 +119,19 @@ def test_training_does_not_fail_on_port_conflicts(client):
assert dask_classifier.booster_
@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers)
def test_classifier_proba(output, centers, client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers)
dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port)
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
p1 = dask_classifier.predict_proba(dX)
p1 = p1.compute()
local_classifier = lightgbm.LGBMClassifier()
local_classifier.fit(X, y, sample_weight=w)
p2 = local_classifier.predict_proba(X)
assert_eq(p1, p2, atol=0.3)
def test_classifier_local_predict(client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output='array')
dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port)
dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
local_listen_port=listen_port,
n_estimators=10,
num_leaves=10
)
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
p1 = dask_classifier.to_local().predict(dX)
local_classifier = lightgbm.LGBMClassifier()
local_classifier = lightgbm.LGBMClassifier(n_estimators=10, num_leaves=10)
local_classifier.fit(X, y, sample_weight=w)
p2 = local_classifier.predict(X)
......@@ -147,14 +144,19 @@ def test_classifier_local_predict(client, listen_port):
def test_regressor(output, client, listen_port):
X, y, w, dX, dy, dw = _create_data('regression', output=output)
dask_regressor = dlgbm.DaskLGBMRegressor(time_out=5, local_listen_port=listen_port, seed=42)
dask_regressor = dlgbm.DaskLGBMRegressor(
time_out=5,
local_listen_port=listen_port,
seed=42,
num_leaves=10
)
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
p1 = dask_regressor.predict(dX)
if output != 'dataframe':
s1 = r2_score(dy, p1)
p1 = p1.compute()
local_regressor = lightgbm.LGBMRegressor(seed=42)
local_regressor = lightgbm.LGBMRegressor(seed=42, num_leaves=10)
local_regressor.fit(X, y, sample_weight=w)
s2 = local_regressor.score(X, y)
p2 = local_regressor.predict(X)
......@@ -173,12 +175,25 @@ def test_regressor(output, client, listen_port):
def test_regressor_quantile(output, client, listen_port, alpha):
X, y, w, dX, dy, dw = _create_data('regression', output=output)
dask_regressor = dlgbm.DaskLGBMRegressor(local_listen_port=listen_port, seed=42, objective='quantile', alpha=alpha)
dask_regressor = dlgbm.DaskLGBMRegressor(
local_listen_port=listen_port,
seed=42,
objective='quantile',
alpha=alpha,
n_estimators=10,
num_leaves=10
)
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
p1 = dask_regressor.predict(dX).compute()
q1 = np.count_nonzero(y < p1) / y.shape[0]
local_regressor = lightgbm.LGBMRegressor(seed=42, objective='quantile', alpha=alpha)
local_regressor = lightgbm.LGBMRegressor(
seed=42,
objective='quantile',
alpha=alpha,
n_estimatores=10,
num_leaves=10
)
local_regressor.fit(X, y, sample_weight=w)
p2 = local_regressor.predict(X)
q2 = np.count_nonzero(y < p2) / y.shape[0]
......@@ -191,7 +206,12 @@ def test_regressor_quantile(output, client, listen_port, alpha):
def test_regressor_local_predict(client, listen_port):
X, y, w, dX, dy, dw = _create_data('regression', output='array')
dask_regressor = dlgbm.DaskLGBMRegressor(local_listen_port=listen_port, seed=42)
dask_regressor = dlgbm.DaskLGBMRegressor(
local_listen_port=listen_port,
seed=42,
n_estimators=10,
num_leaves=10
)
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
p1 = dask_regressor.predict(dX)
p2 = dask_regressor.to_local().predict(X)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment