Unverified Commit d4658fbb authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[dask] add tests on warnings, fix incorrect variable in log (#3865)



* [dask] add tests on warnings, fix incorrect variable in log

* Update tests/python_package_test/test_dask.py
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 59153b28
...@@ -222,7 +222,7 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group ...@@ -222,7 +222,7 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
'voting_parallel' 'voting_parallel'
} }
if params["tree_learner"] not in allowed_tree_learners: if params["tree_learner"] not in allowed_tree_learners:
_log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % tree_learner) _log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % params['tree_learner'])
params['tree_learner'] = 'data' params['tree_learner'] = 'data'
if params['tree_learner'] not in {'data', 'data_parallel'}: if params['tree_learner'] not in {'data', 'data_parallel'}:
......
...@@ -430,6 +430,40 @@ def test_find_open_port_works(): ...@@ -430,6 +430,40 @@ def test_find_open_port_works():
assert new_port == 12402 assert new_port == 12402
def test_warns_and_continues_on_unrecognized_tree_learner(client):
X = da.random.random((1e3, 10))
y = da.random.random((1e3, 1))
dask_regressor = lgb.DaskLGBMRegressor(
time_out=5,
local_listen_port=1234,
tree_learner='some-nonsense-value',
n_estimators=1,
num_leaves=2
)
with pytest.warns(UserWarning, match='Parameter tree_learner set to some-nonsense-value'):
dask_regressor = dask_regressor.fit(X, y, client=client)
assert dask_regressor.fitted_
def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client):
X = da.random.random((1e3, 10))
y = da.random.random((1e3, 1))
for tree_learner in ['feature_parallel', 'voting']:
dask_regressor = lgb.DaskLGBMRegressor(
time_out=5,
local_listen_port=1234,
tree_learner=tree_learner,
n_estimators=1,
num_leaves=2
)
with pytest.warns(UserWarning, match='Support for tree_learner %s in lightgbm' % tree_learner):
dask_regressor = dask_regressor.fit(X, y, client=client)
assert dask_regressor.fitted_
assert dask_regressor.get_params()['tree_learner'] == tree_learner
@gen_cluster(client=True, timeout=None) @gen_cluster(client=True, timeout=None)
def test_errors(c, s, a, b): def test_errors(c, s, a, b):
def f(part): def f(part):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment