Commit 301402c8 authored by Patrick Ford's avatar Patrick Ford Committed by Nikita Titov
Browse files

[python] Output model to a pandas DataFrame (#2592)

* trees_to_df method and unit test added. PEP 8 fixes for integration.

* Co-Authored-By: Nikita Titov <nekit94-08@mail.ru>

Post-review changes

* changes from second round of reviews from striker

* third round of review. formatting and added 2 more tests

* replaced pandas dot attribute accessor with string attribute accessor

* dealt with single tree edge case and minor refactor of tests

* slight refactor for checking if tree is a single node
parent f6b8ecf6
...@@ -7,6 +7,7 @@ import ctypes ...@@ -7,6 +7,7 @@ import ctypes
import os import os
import warnings import warnings
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from collections import OrderedDict
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
...@@ -1858,6 +1859,121 @@ class Booster(object): ...@@ -1858,6 +1859,121 @@ class Booster(object):
self.network = False self.network = False
return self return self
def trees_to_dataframe(self):
    """Parse the fitted model and return in an easy-to-read pandas DataFrame.

    Returns
    -------
    result : pandas DataFrame
        Returns a pandas DataFrame of the parsed model.
    """
    if not PANDAS_INSTALLED:
        raise LightGBMError('This method cannot be run without pandas installed')

    if self.num_trees() == 0:
        raise LightGBMError('There are no trees in this Booster and thus nothing to parse')

    def _is_split_node(tree):
        # Split nodes carry a 'split_index' key; leaf nodes do not.
        return 'split_index' in tree

    def _is_single_node_tree(tree):
        # A tree that never split holds nothing but its constant prediction.
        return set(tree.keys()) == {'leaf_value'}

    def _node_label(tree, tree_index):
        # Build labels like '3-S12' (split 12 of tree 3) or 'L7'
        # ('L'/'S' marks leaf vs. split; prefix dropped when tree_index is None).
        prefix = '' if tree_index is None else str(tree_index) + '-'
        if _is_split_node(tree):
            return prefix + 'S' + str(tree['split_index'])
        # A single-node tree has no 'leaf_index', so fall back to 0.
        return prefix + 'L' + str(tree.get('leaf_index', 0))

    def _node_record(tree, node_depth, tree_index, feature_names, parent_node):
        # One row of the output frame; pairs are listed in final column order.
        if _is_split_node(tree):
            if feature_names is None:
                feature = tree['split_feature']
            else:
                feature = feature_names[tree['split_feature']]
            left = _node_label(tree['left_child'], tree_index)
            right = _node_label(tree['right_child'], tree_index)
            tail = [('split_feature', feature),
                    ('split_gain', tree['split_gain']),
                    ('threshold', tree['threshold']),
                    ('decision_type', tree['decision_type']),
                    ('missing_direction', 'left' if tree['default_left'] else 'right'),
                    ('missing_type', tree['missing_type']),
                    ('value', tree['internal_value']),
                    ('weight', tree['internal_weight']),
                    ('count', tree['internal_count'])]
        else:
            single = _is_single_node_tree(tree)
            left = right = None
            tail = [('split_feature', None),
                    ('split_gain', None),
                    ('threshold', None),
                    ('decision_type', None),
                    ('missing_direction', None),
                    ('missing_type', None),
                    ('value', tree['leaf_value']),
                    # single-node trees expose neither weight nor count
                    ('weight', None if single else tree['leaf_weight']),
                    ('count', None if single else tree['leaf_count'])]
        head = [('tree_index', tree_index),
                ('node_depth', node_depth),
                ('node_index', _node_label(tree, tree_index)),
                ('left_child', left),
                ('right_child', right),
                ('parent_index', parent_node)]
        return OrderedDict(head + tail)

    def _flatten(tree, node_depth=1, tree_index=None,
                 feature_names=None, parent_node=None):
        # Depth-first traversal yielding one record per node, parent first.
        record = _node_record(tree, node_depth, tree_index,
                              feature_names, parent_node)
        records = [record]
        if _is_split_node(tree):
            for side in ('left_child', 'right_child'):
                records.extend(_flatten(tree[side],
                                        node_depth=node_depth + 1,
                                        tree_index=tree_index,
                                        feature_names=feature_names,
                                        parent_node=record['node_index']))
        return records

    model_dict = self.dump_model()
    rows = []
    for tree in model_dict['tree_info']:
        rows.extend(_flatten(tree['tree_structure'],
                             tree_index=tree['tree_index'],
                             feature_names=model_dict['feature_names']))
    return DataFrame(rows, columns=rows[0].keys())
def set_train_data_name(self, name): def set_train_data_name(self, name):
"""Set the name to the training Dataset. """Set the name to the training Dataset.
......
...@@ -328,3 +328,37 @@ class TestBasic(unittest.TestCase): ...@@ -328,3 +328,37 @@ class TestBasic(unittest.TestCase):
lgb_data.set_weight(sequence) lgb_data.set_weight(sequence)
lgb_data.set_init_score(sequence) lgb_data.set_init_score(sequence)
check_asserts(lgb_data) check_asserts(lgb_data)
def test_trees_to_dataframe(self):
    """Check trees_to_dataframe() agrees with feature_importance() and model shape."""

    def _importances_as_array(X, importance_by_col):
        # Map a {column_name: value} dict onto the feature order of X,
        # filling features that never split with 0.
        names = ['Column_' + str(i) for i in range(X.shape[1])]
        return np.array([importance_by_col.get(name, 0.) for name in names])

    X, y = load_breast_cancer(True)
    num_trees = 10
    bst = lgb.train({"objective": "binary"},
                    lgb.Dataset(X, label=y),
                    num_trees)
    tree_df = bst.trees_to_dataframe()

    # Recompute per-feature split counts and total gains from the frame.
    split_dict = (tree_df[~tree_df['split_gain'].isnull()]
                  .groupby('split_feature')
                  .size()
                  .to_dict())
    gains_dict = (tree_df
                  .groupby('split_feature')['split_gain']
                  .sum()
                  .to_dict())

    np.testing.assert_equal(_importances_as_array(X, split_dict),
                            bst.feature_importance('split'))
    np.testing.assert_allclose(_importances_as_array(X, gains_dict),
                               bst.feature_importance('gain'))
    # One tree index per boosting round.
    self.assertEqual(tree_df['tree_index'].nunique(), num_trees)
    # Every root node (depth 1) should account for the whole training set.
    np.testing.assert_equal(tree_df.loc[tree_df['node_depth'] == 1, 'count'].values,
                            len(y))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.