Unverified Commit bca2da97 authored by Belinda Trotta's avatar Belinda Trotta Committed by GitHub
Browse files

Interaction constraints (#3126)

* Add interaction constraints functionality.

* Minor fixes.

* Minor fixes.

* Change lambda to function.

* Fix gpu bug, remove extra blank lines.

* Fix gpu bug.

* Fix style issues.

* Try to fix segfault on MACOS.

* Fix bug.

* Fix bug.

* Fix bugs.

* Change parameter format for R.

* Fix R style issues.

* Change string formatting code.

* Change docs to say R package not supported.

* Remove R functionality, moving to separate PR.

* Keep track of branch features in tree object.

* Only track branch features when feature interactions are enabled.

* Fix lint error.

* Update docs and simplify tests.
parent f5e51649
...@@ -2185,3 +2185,28 @@ class TestEngine(unittest.TestCase): ...@@ -2185,3 +2185,28 @@ class TestEngine(unittest.TestCase):
'split_gain', 'threshold', 'decision_type', 'missing_direction', 'split_gain', 'threshold', 'decision_type', 'missing_direction',
'missing_type', 'weight', 'count'): 'missing_type', 'weight', 'count'):
self.assertIsNone(tree_df.loc[0, col]) self.assertIsNone(tree_df.loc[0, col])
def test_interaction_constraints(self):
    """Check that interaction_constraints restricts which features may co-occur in a tree.

    Three models are trained on the Boston housing data and compared:
    1. no constraint vs. a single constraint group containing ALL features
       (must produce identical predictions, since the constraint is vacuous);
    2. the features partitioned into two disjoint groups (strictly worse
       train MSE than the unconstrained model);
    3. every feature in its own singleton group, i.e. pure additive model
       (strictly worse train MSE than the two-group model).
    """
    X, y = load_boston(True)
    num_features = X.shape[1]
    train_data = lgb.Dataset(X, label=y)
    # check that constraint containing all features is equivalent to no constraint
    params = {'verbose': -1,
              'seed': 0}
    est = lgb.train(params, train_data, num_boost_round=10)
    pred1 = est.predict(X)
    # NOTE: parameter name fixed from 'interation_constraints' (typo) --
    # LightGBM silently ignores unknown parameters, so the typo made this
    # equivalence check vacuous (it compared two unconstrained models).
    est = lgb.train(dict(params, interaction_constraints=[list(range(num_features))]), train_data,
                    num_boost_round=10)
    pred2 = est.predict(X)
    np.testing.assert_allclose(pred1, pred2)
    # check that constraint partitioning the features reduces train accuracy
    est = lgb.train(dict(params, interaction_constraints=[list(range(num_features // 2)),
                                                          list(range(num_features // 2, num_features))]),
                    train_data, num_boost_round=10)
    pred3 = est.predict(X)
    self.assertLess(mean_squared_error(y, pred1), mean_squared_error(y, pred3))
    # check that constraints consisting of single features reduce accuracy further
    est = lgb.train(dict(params, interaction_constraints=[[i] for i in range(num_features)]), train_data,
                    num_boost_round=10)
    pred4 = est.predict(X)
    self.assertLess(mean_squared_error(y, pred3), mean_squared_error(y, pred4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment