Unverified Commit b161f334 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] fix trees_to_dataframe and enhance test (#2690)

* transfer and enhance test for trees_to_dataframe

* fixed bug in Python 2
parent c7e90393
...@@ -1901,7 +1901,7 @@ class Booster(object): ...@@ -1901,7 +1901,7 @@ class Booster(object):
return feature_name return feature_name
def _is_single_node_tree(tree): def _is_single_node_tree(tree):
return tree.keys() == {'leaf_value'} return set(tree.keys()) == {'leaf_value'}
# Create the node record, and populate universal data members # Create the node record, and populate universal data members
node = OrderedDict() node = OrderedDict()
......
...@@ -328,37 +328,3 @@ class TestBasic(unittest.TestCase): ...@@ -328,37 +328,3 @@ class TestBasic(unittest.TestCase):
lgb_data.set_weight(sequence) lgb_data.set_weight(sequence)
lgb_data.set_init_score(sequence) lgb_data.set_init_score(sequence)
check_asserts(lgb_data) check_asserts(lgb_data)
def test_trees_to_dataframe(self):
    """Cross-check ``Booster.trees_to_dataframe`` against the booster itself.

    Trains a 10-iteration binary model on the breast-cancer data and verifies
    that statistics re-derived from the returned frame agree with the model:

    * per-feature split counts match ``feature_importance('split')``;
    * per-feature summed split gains match ``feature_importance('gain')``;
    * the number of distinct ``tree_index`` values equals the number of
      boosting rounds requested;
    * every node at ``node_depth == 1`` reports ``count == len(y)``.
    """
    # The frame is consumed through the pandas API below, so skip (rather
    # than error) when pandas is unavailable.
    if not lgb.compat.PANDAS_INSTALLED:
        self.skipTest('pandas is not installed')

    def _imptcs_to_numpy(X, impcts_dict):
        # Look up one value per column of X under the key 'Column_<i>'
        # (presumably the names LightGBM gives unnamed columns -- confirm).
        # Features missing from the dict contribute 0.
        cols = ['Column_' + str(i) for i in range(X.shape[1])]
        return np.array([impcts_dict.get(col, 0.) for col in cols])

    # Pass ``return_X_y`` by keyword: scikit-learn releases that make the
    # loader's options keyword-only reject the positional form.
    X, y = load_breast_cancer(return_X_y=True)
    data = lgb.Dataset(X, label=y)
    num_trees = 10
    bst = lgb.train({"objective": "binary"}, data, num_trees)
    tree_df = bst.trees_to_dataframe()

    # Rows with a non-null split_gain are the ones counted as splits.
    split_rows = tree_df[~tree_df['split_gain'].isnull()]
    split_dict = split_rows.groupby('split_feature').size().to_dict()
    gains_dict = (tree_df
                  .groupby('split_feature')['split_gain']
                  .sum()
                  .to_dict())

    tree_split = _imptcs_to_numpy(X, split_dict)
    tree_gains = _imptcs_to_numpy(X, gains_dict)
    mod_split = bst.feature_importance('split')
    mod_gains = bst.feature_importance('gain')
    num_trees_from_df = tree_df['tree_index'].nunique()
    obs_counts_from_df = tree_df.loc[tree_df['node_depth'] == 1, 'count'].values

    np.testing.assert_equal(tree_split, mod_split)
    # Gains are floats accumulated in a different order, so compare with tolerance.
    np.testing.assert_allclose(tree_gains, mod_gains)
    self.assertEqual(num_trees_from_df, num_trees)
    np.testing.assert_equal(obs_counts_from_df, len(y))
...@@ -1810,3 +1810,54 @@ class TestEngine(unittest.TestCase): ...@@ -1810,3 +1810,54 @@ class TestEngine(unittest.TestCase):
predicted = est.predict(new_x) predicted = est.predict(new_x)
self.assertNotAlmostEqual(predicted[0], predicted[1]) self.assertNotAlmostEqual(predicted[0], predicted[1])
self.assertAlmostEqual(predicted[1], predicted[2]) self.assertAlmostEqual(predicted[1], predicted[2])
@unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed')
def test_trees_to_dataframe(self):
    """Cross-check ``Booster.trees_to_dataframe`` against the booster itself.

    Part 1 trains a 10-iteration binary model on the breast-cancer data and
    verifies that statistics re-derived from the returned frame agree with
    the model: split counts and summed gains per feature match
    ``feature_importance``, the number of distinct ``tree_index`` values
    equals the requested number of rounds, and every ``node_depth == 1``
    node reports ``count == len(y)``.

    Part 2 covers the single-leaf edge case: the frame must contain exactly
    one row whose split-related columns are all ``None``.
    """
    def _imptcs_to_numpy(X, impcts_dict):
        # Look up one value per column of X under the key 'Column_<i>'
        # (presumably the names LightGBM gives unnamed columns -- confirm).
        # Features missing from the dict contribute 0.
        cols = ['Column_' + str(i) for i in range(X.shape[1])]
        return [impcts_dict.get(col, 0.) for col in cols]

    # Pass ``return_X_y`` by keyword: scikit-learn releases that make the
    # loader's options keyword-only reject the positional form.
    X, y = load_breast_cancer(return_X_y=True)
    data = lgb.Dataset(X, label=y)
    num_trees = 10
    bst = lgb.train({"objective": "binary", "verbose": -1}, data, num_trees)
    tree_df = bst.trees_to_dataframe()

    # Rows with a non-null split_gain are the ones counted as splits.
    split_rows = tree_df[~tree_df['split_gain'].isnull()]
    split_dict = split_rows.groupby('split_feature').size().to_dict()
    gains_dict = (tree_df
                  .groupby('split_feature')['split_gain']
                  .sum()
                  .to_dict())

    tree_split = _imptcs_to_numpy(X, split_dict)
    tree_gains = _imptcs_to_numpy(X, gains_dict)
    mod_split = bst.feature_importance('split')
    mod_gains = bst.feature_importance('gain')
    num_trees_from_df = tree_df['tree_index'].nunique()
    obs_counts_from_df = tree_df.loc[tree_df['node_depth'] == 1, 'count'].values

    np.testing.assert_equal(tree_split, mod_split)
    # Gains are floats accumulated in a different order, so compare with tolerance.
    np.testing.assert_allclose(tree_gains, mod_gains)
    self.assertEqual(num_trees_from_df, num_trees)
    np.testing.assert_equal(obs_counts_from_df, len(y))

    # test edge case with one leaf
    # Every feature value is identical; the assertions below expect the
    # resulting frame to describe a single one-leaf tree.
    X = np.ones((10, 2))
    y = np.random.rand(10)
    data = lgb.Dataset(X, label=y)
    bst = lgb.train({"objective": "binary", "verbose": -1}, data, num_trees)
    tree_df = bst.trees_to_dataframe()

    self.assertEqual(len(tree_df), 1)
    self.assertEqual(tree_df.loc[0, 'tree_index'], 0)
    self.assertEqual(tree_df.loc[0, 'node_depth'], 1)
    self.assertEqual(tree_df.loc[0, 'node_index'], "0-L0")
    self.assertIsNotNone(tree_df.loc[0, 'value'])
    # A lone leaf has no split, children or parent, so these must be None.
    for col in ('left_child', 'right_child', 'parent_index', 'split_feature',
                'split_gain', 'threshold', 'decision_type', 'missing_direction',
                'missing_type', 'weight', 'count'):
        self.assertIsNone(tree_df.loc[0, col])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment