Unverified Commit d05f5470 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[tests][python] added tests for early stop in prediction in ranking task (#4457)

parent 0d1d12fb
...@@ -747,6 +747,13 @@ def test_ranker(output, group, boosting_type, tree_learner, cluster): ...@@ -747,6 +747,13 @@ def test_ranker(output, group, boosting_type, tree_learner, cluster):
p1_pred_leaf = dask_ranker.predict(dX, pred_leaf=True) p1_pred_leaf = dask_ranker.predict(dX, pred_leaf=True)
p1_raw = dask_ranker.predict(dX, raw_score=True).compute() p1_raw = dask_ranker.predict(dX, raw_score=True).compute()
p1_first_iter_raw = dask_ranker.predict(dX, start_iteration=0, num_iteration=1, raw_score=True).compute() p1_first_iter_raw = dask_ranker.predict(dX, start_iteration=0, num_iteration=1, raw_score=True).compute()
p1_early_stop_raw = dask_ranker.predict(
dX,
pred_early_stop=True,
pred_early_stop_margin=1.0,
pred_early_stop_freq=2,
raw_score=True
).compute()
rnkvec_dask_local = dask_ranker.to_local().predict(X) rnkvec_dask_local = dask_ranker.to_local().predict(X)
local_ranker = lgb.LGBMRanker(**params) local_ranker = lgb.LGBMRanker(**params)
...@@ -764,6 +771,9 @@ def test_ranker(output, group, boosting_type, tree_learner, cluster): ...@@ -764,6 +771,9 @@ def test_ranker(output, group, boosting_type, tree_learner, cluster):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
assert_eq(p1_raw, p1_first_iter_raw) assert_eq(p1_raw, p1_first_iter_raw)
with pytest.raises(AssertionError):
assert_eq(p1_raw, p1_early_stop_raw)
# pref_leaf values should have the right shape # pref_leaf values should have the right shape
# and values that look like valid tree nodes # and values that look like valid tree nodes
pred_leaf_vals = p1_pred_leaf.compute() pred_leaf_vals = p1_pred_leaf.compute()
......
...@@ -455,9 +455,7 @@ def test_multiclass_prediction_early_stopping(): ...@@ -455,9 +455,7 @@ def test_multiclass_prediction_early_stopping():
assert ret < 0.8 assert ret < 0.8
assert ret > 0.6 # loss will be higher than when evaluating the full model assert ret > 0.6 # loss will be higher than when evaluating the full model
pred_parameter = {"pred_early_stop": True, pred_parameter["pred_early_stop_margin"] = 5.5
"pred_early_stop_freq": 5,
"pred_early_stop_margin": 5.5}
ret = multi_logloss(y_test, gbm.predict(X_test, **pred_parameter)) ret = multi_logloss(y_test, gbm.predict(X_test, **pred_parameter))
assert ret < 0.2 assert ret < 0.2
...@@ -588,6 +586,29 @@ def test_auc_mu(): ...@@ -588,6 +586,29 @@ def test_auc_mu():
assert results_weight['training']['auc_mu'][-1] != results_no_weight['training']['auc_mu'][-1] assert results_weight['training']['auc_mu'][-1] != results_no_weight['training']['auc_mu'][-1]
def test_ranking_prediction_early_stopping():
    """Prediction-time early stopping should change raw ranking scores.

    Trains a rank_xendcg model on the bundled lambdarank example data, then
    predicts twice with ``pred_early_stop`` enabled but different margins.
    A stricter (larger) margin stops iteration at a different point, so the
    two score vectors must not be element-wise close.
    """
    data_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank'
    X_train, y_train = load_svmlight_file(str(data_dir / 'rank.train'))
    q_train = np.loadtxt(str(data_dir / 'rank.train.query'))
    X_test, _ = load_svmlight_file(str(data_dir / 'rank.test'))
    params = {
        'objective': 'rank_xendcg',
        'verbose': -1
    }
    train_set = lgb.Dataset(X_train, y_train, group=q_train, params=params)
    booster = lgb.train(params, train_set, num_boost_round=50)

    def _predict_with_margin(margin):
        # early stopping during prediction, checked every 5 iterations
        return booster.predict(
            X_test,
            pred_early_stop=True,
            pred_early_stop_freq=5,
            pred_early_stop_margin=margin
        )

    ret_early = _predict_with_margin(1.5)
    ret_early_more_strict = _predict_with_margin(5.5)
    # the two margins must produce measurably different predictions
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(ret_early, ret_early_more_strict)
def test_early_stopping(): def test_early_stopping():
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
params = { params = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment