Commit 69a36605 authored by Nikita Titov, committed by Guolin Ke
Browse files

[python] make tree rendering more clear (#1424)

* fixed grammar

* fixed params description in graph plotting functions

* clarified types of attributes in their descriptions

* increased readability of graphs by adding spaces

* added precision parameter to plot tree functions
parent 42d5e571
...@@ -83,7 +83,7 @@ def plot_importance(booster, ax=None, height=0.2, ...@@ -83,7 +83,7 @@ def plot_importance(booster, ax=None, height=0.2,
feature_name = booster.feature_name() feature_name = booster.feature_name()
if not len(importance): if not len(importance):
raise ValueError('Booster feature_importances are empty.') raise ValueError("Booster's feature_importance is empty.")
tuples = sorted(zip(feature_name, importance), key=lambda x: x[1]) tuples = sorted(zip(feature_name, importance), key=lambda x: x[1])
if ignore_zero: if ignore_zero:
...@@ -252,7 +252,7 @@ def plot_metric(booster, metric=None, dataset_names=None, ...@@ -252,7 +252,7 @@ def plot_metric(booster, metric=None, dataset_names=None,
return ax return ax
def _to_graphviz(tree_info, show_info, feature_names, def _to_graphviz(tree_info, show_info, feature_names, precision=None,
name=None, comment=None, filename=None, directory=None, name=None, comment=None, filename=None, directory=None,
format=None, engine=None, encoding=None, graph_attr=None, format=None, engine=None, encoding=None, graph_attr=None,
node_attr=None, edge_attr=None, body=None, strict=False): node_attr=None, edge_attr=None, body=None, strict=False):
...@@ -266,18 +266,23 @@ def _to_graphviz(tree_info, show_info, feature_names, ...@@ -266,18 +266,23 @@ def _to_graphviz(tree_info, show_info, feature_names,
except ImportError: except ImportError:
raise ImportError('You must install graphviz to plot tree.') raise ImportError('You must install graphviz to plot tree.')
def float2str(value, precision=None):
return "{0:.{1}f}".format(value, precision) if precision is not None else str(value)
def add(root, parent=None, decision=None): def add(root, parent=None, decision=None):
"""recursively add node or edge""" """recursively add node or edge"""
if 'split_index' in root: # non-leaf if 'split_index' in root: # non-leaf
name = 'split' + str(root['split_index']) name = 'split{0}'.format(root['split_index'])
if feature_names is not None: if feature_names is not None:
label = 'split_feature_name:' + str(feature_names[root['split_feature']]) label = 'split_feature_name: {0}'.format(feature_names[root['split_feature']])
else: else:
label = 'split_feature_index:' + str(root['split_feature']) label = 'split_feature_index: {0}'.format(root['split_feature'])
label += r'\nthreshold:' + str(root['threshold']) label += r'\nthreshold: {0}'.format(float2str(root['threshold'], precision))
for info in show_info: for info in show_info:
if info in {'split_gain', 'internal_value', 'internal_count'}: if info in {'split_gain', 'internal_value'}:
label += r'\n' + info + ':' + str(root[info]) label += r'\n{0}: {1}'.format(info, float2str(root[info], precision))
elif info == 'internal_count':
label += r'\n{0}: {1}'.format(info, root[info])
graph.node(name, label=label) graph.node(name, label=label)
if root['decision_type'] == '<=': if root['decision_type'] == '<=':
l_dec, r_dec = '<=', '>' l_dec, r_dec = '<=', '>'
...@@ -288,11 +293,11 @@ def _to_graphviz(tree_info, show_info, feature_names, ...@@ -288,11 +293,11 @@ def _to_graphviz(tree_info, show_info, feature_names,
add(root['left_child'], name, l_dec) add(root['left_child'], name, l_dec)
add(root['right_child'], name, r_dec) add(root['right_child'], name, r_dec)
else: # leaf else: # leaf
name = 'leaf' + str(root['leaf_index']) name = 'leaf{0}'.format(root['leaf_index'])
label = 'leaf_index:' + str(root['leaf_index']) label = 'leaf_index: {0}'.format(root['leaf_index'])
label += r'\nleaf_value:' + str(root['leaf_value']) label += r'\nleaf_value: {0}'.format(float2str(root['leaf_value'], precision))
if 'leaf_count' in show_info: if 'leaf_count' in show_info:
label += r'\nleaf_count:' + str(root['leaf_count']) label += r'\nleaf_count: {0}'.format(root['leaf_count'])
graph.node(name, label=label) graph.node(name, label=label)
if parent is not None: if parent is not None:
graph.edge(parent, name, decision) graph.edge(parent, name, decision)
...@@ -305,7 +310,7 @@ def _to_graphviz(tree_info, show_info, feature_names, ...@@ -305,7 +310,7 @@ def _to_graphviz(tree_info, show_info, feature_names,
return graph return graph
def create_tree_digraph(booster, tree_index=0, show_info=None, def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
name=None, comment=None, filename=None, directory=None, name=None, comment=None, filename=None, directory=None,
format=None, engine=None, encoding=None, graph_attr=None, format=None, engine=None, encoding=None, graph_attr=None,
node_attr=None, edge_attr=None, body=None, strict=False): node_attr=None, edge_attr=None, body=None, strict=False):
...@@ -322,9 +327,11 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, ...@@ -322,9 +327,11 @@ def create_tree_digraph(booster, tree_index=0, show_info=None,
Booster or LGBMModel instance. Booster or LGBMModel instance.
tree_index : int, optional (default=0) tree_index : int, optional (default=0)
The index of a target tree to convert. The index of a target tree to convert.
show_info : list or None, optional (default=None) show_info : list of strings or None, optional (default=None)
What information should be showed on nodes. What information should be shown in nodes.
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'. Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision.
name : string or None, optional (default=None) name : string or None, optional (default=None)
Graph name used in the source code. Graph name used in the source code.
comment : string or None, optional (default=None) comment : string or None, optional (default=None)
...@@ -340,12 +347,15 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, ...@@ -340,12 +347,15 @@ def create_tree_digraph(booster, tree_index=0, show_info=None,
Layout command used ('dot', 'neato', ...). Layout command used ('dot', 'neato', ...).
encoding : string or None, optional (default=None) encoding : string or None, optional (default=None)
Encoding for saving the source. Encoding for saving the source.
graph_attr : dict or None, optional (default=None) graph_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for the graph. Mapping of (attribute, value) pairs set for the graph.
node_attr : dict or None, optional (default=None) All attributes and values must be strings or bytes-like objects.
node_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all nodes. Mapping of (attribute, value) pairs set for all nodes.
edge_attr : dict or None, optional (default=None) All attributes and values must be strings or bytes-like objects.
edge_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all edges. Mapping of (attribute, value) pairs set for all edges.
All attributes and values must be strings or bytes-like objects.
body : list of strings or None, optional (default=None) body : list of strings or None, optional (default=None)
Lines to add to the graph body. Lines to add to the graph body.
strict : bool, optional (default=False) strict : bool, optional (default=False)
...@@ -376,7 +386,7 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, ...@@ -376,7 +386,7 @@ def create_tree_digraph(booster, tree_index=0, show_info=None,
if show_info is None: if show_info is None:
show_info = [] show_info = []
graph = _to_graphviz(tree_info, show_info, feature_names, graph = _to_graphviz(tree_info, show_info, feature_names, precision,
name=name, comment=comment, filename=filename, directory=directory, name=name, comment=comment, filename=filename, directory=directory,
format=format, engine=engine, encoding=encoding, graph_attr=graph_attr, format=format, engine=engine, encoding=encoding, graph_attr=graph_attr,
node_attr=node_attr, edge_attr=edge_attr, body=body, strict=strict) node_attr=node_attr, edge_attr=edge_attr, body=body, strict=strict)
...@@ -386,7 +396,7 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, ...@@ -386,7 +396,7 @@ def create_tree_digraph(booster, tree_index=0, show_info=None,
def plot_tree(booster, ax=None, tree_index=0, figsize=None, def plot_tree(booster, ax=None, tree_index=0, figsize=None,
graph_attr=None, node_attr=None, edge_attr=None, graph_attr=None, node_attr=None, edge_attr=None,
show_info=None): show_info=None, precision=None):
"""Plot specified tree. """Plot specified tree.
Parameters Parameters
...@@ -400,15 +410,20 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None, ...@@ -400,15 +410,20 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
The index of a target tree to plot. The index of a target tree to plot.
figsize : tuple of 2 elements or None, optional (default=None) figsize : tuple of 2 elements or None, optional (default=None)
Figure size. Figure size.
graph_attr : dict or None, optional (default=None) graph_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for the graph. Mapping of (attribute, value) pairs set for the graph.
node_attr : dict or None, optional (default=None) All attributes and values must be strings or bytes-like objects.
node_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all nodes. Mapping of (attribute, value) pairs set for all nodes.
edge_attr : dict or None, optional (default=None) All attributes and values must be strings or bytes-like objects.
edge_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all edges. Mapping of (attribute, value) pairs set for all edges.
show_info : list or None, optional (default=None) All attributes and values must be strings or bytes-like objects.
What information should be showed on nodes. show_info : list of strings or None, optional (default=None)
What information should be shown in nodes.
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'. Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision.
Returns Returns
------- -------
...@@ -429,10 +444,11 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None, ...@@ -429,10 +444,11 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
graph = create_tree_digraph( graph = create_tree_digraph(
booster=booster, booster=booster,
tree_index=tree_index, tree_index=tree_index,
show_info=show_info,
precision=precision,
graph_attr=graph_attr, graph_attr=graph_attr,
node_attr=node_attr, node_attr=node_attr,
edge_attr=edge_attr, edge_attr=edge_attr
show_info=show_info
) )
s = BytesIO() s = BytesIO()
......
...@@ -478,7 +478,7 @@ class TestEngine(unittest.TestCase): ...@@ -478,7 +478,7 @@ class TestEngine(unittest.TestCase):
for ret in other_ret: for ret in other_ret:
self.assertAlmostEqual(ret_origin, ret, places=5) self.assertAlmostEqual(ret_origin, ret, places=5)
@unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas not installed') @unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas is not installed')
def test_pandas_categorical(self): def test_pandas_categorical(self):
X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str
"B": np.random.permutation([1, 2, 3] * 100), # int "B": np.random.permutation([1, 2, 3] * 100), # int
......
...@@ -16,7 +16,7 @@ except ImportError: ...@@ -16,7 +16,7 @@ except ImportError:
class TestBasic(unittest.TestCase): class TestBasic(unittest.TestCase):
@unittest.skipIf(not matplotlib_installed, 'matplotlib not installed') @unittest.skipIf(not matplotlib_installed, 'matplotlib is not installed')
def test_plot_importance(self): def test_plot_importance(self):
X_train, _, y_train, _ = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1) X_train, _, y_train, _ = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
train_data = lgb.Dataset(X_train, y_train) train_data = lgb.Dataset(X_train, y_train)
...@@ -62,7 +62,7 @@ class TestBasic(unittest.TestCase): ...@@ -62,7 +62,7 @@ class TestBasic(unittest.TestCase):
def test_plot_tree(self): def test_plot_tree(self):
pass pass
@unittest.skipIf(not matplotlib_installed, 'matplotlib not installed') @unittest.skipIf(not matplotlib_installed, 'matplotlib is not installed')
def test_plot_metrics(self): def test_plot_metrics(self):
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1) X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
train_data = lgb.Dataset(X_train, y_train) train_data = lgb.Dataset(X_train, y_train)
......
...@@ -200,7 +200,7 @@ class TestSklearn(unittest.TestCase): ...@@ -200,7 +200,7 @@ class TestSklearn(unittest.TestCase):
except SkipTest as message: except SkipTest as message:
warnings.warn(message, SkipTestWarning) warnings.warn(message, SkipTestWarning)
@unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas not installed') @unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas is not installed')
def test_pandas_categorical(self): def test_pandas_categorical(self):
X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str
"B": np.random.permutation([1, 2, 3] * 100), # int "B": np.random.permutation([1, 2, 3] * 100), # int
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment