Unverified commit 8116d880, authored by James Lamb, committed by GitHub
Browse files

[dask] pass additional predict() parameters through when input is a Dask Array (#4399)



* [dask] pass predict() kwargs through when input is a Dask Array

* add tests

* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* add prediction early stopping params
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 558e4a40
...@@ -570,7 +570,8 @@ def _predict( ...@@ -570,7 +570,8 @@ def _predict(
pred_leaf=pred_leaf, pred_leaf=pred_leaf,
pred_contrib=pred_contrib, pred_contrib=pred_contrib,
dtype=dtype, dtype=dtype,
drop_axis=1 drop_axis=1,
**kwargs
) )
else: else:
raise TypeError(f'Data must be either Dask Array or Dask DataFrame. Got {type(data)}.') raise TypeError(f'Data must be either Dask Array or Dask DataFrame. Got {type(data)}.')
......
...@@ -272,6 +272,15 @@ def test_classifier(output, task, boosting_type, tree_learner, cluster): ...@@ -272,6 +272,15 @@ def test_classifier(output, task, boosting_type, tree_learner, cluster):
) )
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw) dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw)
p1 = dask_classifier.predict(dX) p1 = dask_classifier.predict(dX)
p1_raw = dask_classifier.predict(dX, raw_score=True).compute()
p1_first_iter_raw = dask_classifier.predict(dX, start_iteration=0, num_iteration=1, raw_score=True).compute()
p1_early_stop_raw = dask_classifier.predict(
dX,
pred_early_stop=True,
pred_early_stop_margin=1.0,
pred_early_stop_freq=2,
raw_score=True
)
p1_proba = dask_classifier.predict_proba(dX).compute() p1_proba = dask_classifier.predict_proba(dX).compute()
p1_pred_leaf = dask_classifier.predict(dX, pred_leaf=True) p1_pred_leaf = dask_classifier.predict(dX, pred_leaf=True)
p1_local = dask_classifier.to_local().predict(X) p1_local = dask_classifier.to_local().predict(X)
...@@ -297,6 +306,13 @@ def test_classifier(output, task, boosting_type, tree_learner, cluster): ...@@ -297,6 +306,13 @@ def test_classifier(output, task, boosting_type, tree_learner, cluster):
assert_eq(p1_local, p2) assert_eq(p1_local, p2)
assert_eq(p1_local, y) assert_eq(p1_local, y)
# extra predict() parameters should be passed through correctly
with pytest.raises(AssertionError):
assert_eq(p1_raw, p1_first_iter_raw)
with pytest.raises(AssertionError):
assert_eq(p1_raw, p1_early_stop_raw)
# pred_leaf values should have the right shape # pred_leaf values should have the right shape
# and values that look like valid tree nodes # and values that look like valid tree nodes
pred_leaf_vals = p1_pred_leaf.compute() pred_leaf_vals = p1_pred_leaf.compute()
...@@ -487,6 +503,8 @@ def test_regressor(output, boosting_type, tree_learner, cluster): ...@@ -487,6 +503,8 @@ def test_regressor(output, boosting_type, tree_learner, cluster):
s1 = _r2_score(dy, p1) s1 = _r2_score(dy, p1)
p1 = p1.compute() p1 = p1.compute()
p1_raw = dask_regressor.predict(dX, raw_score=True).compute()
p1_first_iter_raw = dask_regressor.predict(dX, start_iteration=0, num_iteration=1, raw_score=True).compute()
p1_local = dask_regressor.to_local().predict(X) p1_local = dask_regressor.to_local().predict(X)
s1_local = dask_regressor.to_local().score(X, y) s1_local = dask_regressor.to_local().score(X, y)
...@@ -516,6 +534,10 @@ def test_regressor(output, boosting_type, tree_learner, cluster): ...@@ -516,6 +534,10 @@ def test_regressor(output, boosting_type, tree_learner, cluster):
assert_eq(p1, y, rtol=0.5, atol=50.) assert_eq(p1, y, rtol=0.5, atol=50.)
assert_eq(p2, y, rtol=0.5, atol=50.) assert_eq(p2, y, rtol=0.5, atol=50.)
# extra predict() parameters should be passed through correctly
with pytest.raises(AssertionError):
assert_eq(p1_raw, p1_first_iter_raw)
# be sure LightGBM actually used at least one categorical column, # be sure LightGBM actually used at least one categorical column,
# and that it was correctly treated as a categorical feature # and that it was correctly treated as a categorical feature
if output == 'dataframe-with-categorical': if output == 'dataframe-with-categorical':
...@@ -680,6 +702,8 @@ def test_ranker(output, group, boosting_type, tree_learner, cluster): ...@@ -680,6 +702,8 @@ def test_ranker(output, group, boosting_type, tree_learner, cluster):
rnkvec_dask = dask_ranker.predict(dX) rnkvec_dask = dask_ranker.predict(dX)
rnkvec_dask = rnkvec_dask.compute() rnkvec_dask = rnkvec_dask.compute()
p1_pred_leaf = dask_ranker.predict(dX, pred_leaf=True) p1_pred_leaf = dask_ranker.predict(dX, pred_leaf=True)
p1_raw = dask_ranker.predict(dX, raw_score=True).compute()
p1_first_iter_raw = dask_ranker.predict(dX, start_iteration=0, num_iteration=1, raw_score=True).compute()
rnkvec_dask_local = dask_ranker.to_local().predict(X) rnkvec_dask_local = dask_ranker.to_local().predict(X)
local_ranker = lgb.LGBMRanker(**params) local_ranker = lgb.LGBMRanker(**params)
...@@ -693,6 +717,10 @@ def test_ranker(output, group, boosting_type, tree_learner, cluster): ...@@ -693,6 +717,10 @@ def test_ranker(output, group, boosting_type, tree_learner, cluster):
assert spearmanr(rnkvec_dask, rnkvec_local).correlation > 0.8 assert spearmanr(rnkvec_dask, rnkvec_local).correlation > 0.8
assert_eq(rnkvec_dask, rnkvec_dask_local) assert_eq(rnkvec_dask, rnkvec_dask_local)
# extra predict() parameters should be passed through correctly
with pytest.raises(AssertionError):
assert_eq(p1_raw, p1_first_iter_raw)
# pred_leaf values should have the right shape # pred_leaf values should have the right shape
# and values that look like valid tree nodes # and values that look like valid tree nodes
pred_leaf_vals = p1_pred_leaf.compute() pred_leaf_vals = p1_pred_leaf.compute()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment