Unverified Commit d563aff9 authored by Belinda Trotta's avatar Belinda Trotta Committed by GitHub
Browse files

Fix bug with interaction constraints (#3189)

* Fix bug: crashes when interaction_constraints is nonempty and not all features are used.

* Fix python lint error.
parent 72849466
...@@ -117,7 +117,9 @@ class ColSampler { ...@@ -117,7 +117,9 @@ class ColSampler {
} else { } else {
for (int feat : allowed_features) { for (int feat : allowed_features) {
int inner_feat = train_data_->InnerFeatureIndex(feat); int inner_feat = train_data_->InnerFeatureIndex(feat);
ret[inner_feat] = 1; if (inner_feat >= 0) {
ret[inner_feat] = 1;
}
} }
return ret; return ret;
} }
......
...@@ -2195,7 +2195,7 @@ class TestEngine(unittest.TestCase): ...@@ -2195,7 +2195,7 @@ class TestEngine(unittest.TestCase):
'seed': 0} 'seed': 0}
est = lgb.train(params, train_data, num_boost_round=10) est = lgb.train(params, train_data, num_boost_round=10)
pred1 = est.predict(X) pred1 = est.predict(X)
est = lgb.train(dict(params, interation_constraints=[list(range(num_features))]), train_data, est = lgb.train(dict(params, interaction_constraints=[list(range(num_features))]), train_data,
num_boost_round=10) num_boost_round=10)
pred2 = est.predict(X) pred2 = est.predict(X)
np.testing.assert_allclose(pred1, pred2) np.testing.assert_allclose(pred1, pred2)
...@@ -2210,3 +2210,10 @@ class TestEngine(unittest.TestCase): ...@@ -2210,3 +2210,10 @@ class TestEngine(unittest.TestCase):
num_boost_round=10) num_boost_round=10)
pred4 = est.predict(X) pred4 = est.predict(X)
self.assertLess(mean_squared_error(y, pred3), mean_squared_error(y, pred4)) self.assertLess(mean_squared_error(y, pred3), mean_squared_error(y, pred4))
# test that interaction constraints work when not all features are used
X = np.concatenate([np.zeros((X.shape[0], 1)), X], axis=1)
num_features = X.shape[1]
train_data = lgb.Dataset(X, label=y)
est = lgb.train(dict(params, interaction_constraints=[[0] + list(range(2, num_features)),
[1] + list(range(2, num_features))]),
train_data, num_boost_round=10)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment