Commit 69a36605 authored by Nikita Titov, committed by Guolin Ke
Browse files

[python] make tree rendering more clear (#1424)

* fixed grammar

* fixed params description in graph plotting functions

* clarified types of attributes in their descriptions

* increased readability of graphs by adding spaces

* added precision parameter to plot tree functions
parent 42d5e571
......@@ -83,7 +83,7 @@ def plot_importance(booster, ax=None, height=0.2,
feature_name = booster.feature_name()
if not len(importance):
raise ValueError('Booster feature_importances are empty.')
raise ValueError("Booster's feature_importance is empty.")
tuples = sorted(zip(feature_name, importance), key=lambda x: x[1])
if ignore_zero:
......@@ -252,7 +252,7 @@ def plot_metric(booster, metric=None, dataset_names=None,
return ax
def _to_graphviz(tree_info, show_info, feature_names,
def _to_graphviz(tree_info, show_info, feature_names, precision=None,
name=None, comment=None, filename=None, directory=None,
format=None, engine=None, encoding=None, graph_attr=None,
node_attr=None, edge_attr=None, body=None, strict=False):
......@@ -266,18 +266,23 @@ def _to_graphviz(tree_info, show_info, feature_names,
except ImportError:
raise ImportError('You must install graphviz to plot tree.')
def float2str(value, precision=None):
    """Render a node value as text, optionally with fixed decimal places.

    Parameters
    ----------
    value : object
        Value to render (e.g. a threshold, gain or leaf value).
    precision : int or None, optional (default=None)
        Number of digits after the decimal point.
        If None, ``str(value)`` is returned unchanged.

    Returns
    -------
    text : str
        Fixed-point representation of ``value`` when ``precision`` is given
        and the value supports the ``f`` format code, ``str(value)`` otherwise.
    """
    if precision is None:
        return str(value)
    try:
        return "{0:.{1}f}".format(value, precision)
    except (TypeError, ValueError):
        # The 'f' format code rejects non-numeric values (for instance a
        # threshold stored as a string); show them verbatim instead of
        # aborting the whole rendering.
        return str(value)
def add(root, parent=None, decision=None):
"""recursively add node or edge"""
if 'split_index' in root: # non-leaf
name = 'split' + str(root['split_index'])
name = 'split{0}'.format(root['split_index'])
if feature_names is not None:
label = 'split_feature_name:' + str(feature_names[root['split_feature']])
label = 'split_feature_name: {0}'.format(feature_names[root['split_feature']])
else:
label = 'split_feature_index:' + str(root['split_feature'])
label += r'\nthreshold:' + str(root['threshold'])
label = 'split_feature_index: {0}'.format(root['split_feature'])
label += r'\nthreshold: {0}'.format(float2str(root['threshold'], precision))
for info in show_info:
if info in {'split_gain', 'internal_value', 'internal_count'}:
label += r'\n' + info + ':' + str(root[info])
if info in {'split_gain', 'internal_value'}:
label += r'\n{0}: {1}'.format(info, float2str(root[info], precision))
elif info == 'internal_count':
label += r'\n{0}: {1}'.format(info, root[info])
graph.node(name, label=label)
if root['decision_type'] == '<=':
l_dec, r_dec = '<=', '>'
......@@ -288,11 +293,11 @@ def _to_graphviz(tree_info, show_info, feature_names,
add(root['left_child'], name, l_dec)
add(root['right_child'], name, r_dec)
else: # leaf
name = 'leaf' + str(root['leaf_index'])
label = 'leaf_index:' + str(root['leaf_index'])
label += r'\nleaf_value:' + str(root['leaf_value'])
name = 'leaf{0}'.format(root['leaf_index'])
label = 'leaf_index: {0}'.format(root['leaf_index'])
label += r'\nleaf_value: {0}'.format(float2str(root['leaf_value'], precision))
if 'leaf_count' in show_info:
label += r'\nleaf_count:' + str(root['leaf_count'])
label += r'\nleaf_count: {0}'.format(root['leaf_count'])
graph.node(name, label=label)
if parent is not None:
graph.edge(parent, name, decision)
......@@ -305,7 +310,7 @@ def _to_graphviz(tree_info, show_info, feature_names,
return graph
def create_tree_digraph(booster, tree_index=0, show_info=None,
def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
name=None, comment=None, filename=None, directory=None,
format=None, engine=None, encoding=None, graph_attr=None,
node_attr=None, edge_attr=None, body=None, strict=False):
......@@ -322,9 +327,11 @@ def create_tree_digraph(booster, tree_index=0, show_info=None,
Booster or LGBMModel instance.
tree_index : int, optional (default=0)
The index of a target tree to convert.
show_info : list or None, optional (default=None)
What information should be showed on nodes.
show_info : list of strings or None, optional (default=None)
What information should be shown in nodes.
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision.
name : string or None, optional (default=None)
Graph name used in the source code.
comment : string or None, optional (default=None)
......@@ -340,12 +347,15 @@ def create_tree_digraph(booster, tree_index=0, show_info=None,
Layout command used ('dot', 'neato', ...).
encoding : string or None, optional (default=None)
Encoding for saving the source.
graph_attr : dict or None, optional (default=None)
graph_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for the graph.
node_attr : dict or None, optional (default=None)
All attributes and values must be strings or bytes-like objects.
node_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all nodes.
edge_attr : dict or None, optional (default=None)
All attributes and values must be strings or bytes-like objects.
edge_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all edges.
All attributes and values must be strings or bytes-like objects.
body : list of strings or None, optional (default=None)
Lines to add to the graph body.
strict : bool, optional (default=False)
......@@ -376,7 +386,7 @@ def create_tree_digraph(booster, tree_index=0, show_info=None,
if show_info is None:
show_info = []
graph = _to_graphviz(tree_info, show_info, feature_names,
graph = _to_graphviz(tree_info, show_info, feature_names, precision,
name=name, comment=comment, filename=filename, directory=directory,
format=format, engine=engine, encoding=encoding, graph_attr=graph_attr,
node_attr=node_attr, edge_attr=edge_attr, body=body, strict=strict)
......@@ -386,7 +396,7 @@ def create_tree_digraph(booster, tree_index=0, show_info=None,
def plot_tree(booster, ax=None, tree_index=0, figsize=None,
graph_attr=None, node_attr=None, edge_attr=None,
show_info=None):
show_info=None, precision=None):
"""Plot specified tree.
Parameters
......@@ -400,15 +410,20 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
The index of a target tree to plot.
figsize : tuple of 2 elements or None, optional (default=None)
Figure size.
graph_attr : dict or None, optional (default=None)
graph_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for the graph.
node_attr : dict or None, optional (default=None)
All attributes and values must be strings or bytes-like objects.
node_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all nodes.
edge_attr : dict or None, optional (default=None)
All attributes and values must be strings or bytes-like objects.
edge_attr : dict, list of tuples or None, optional (default=None)
Mapping of (attribute, value) pairs set for all edges.
show_info : list or None, optional (default=None)
What information should be showed on nodes.
All attributes and values must be strings or bytes-like objects.
show_info : list of strings or None, optional (default=None)
What information should be shown in nodes.
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision.
Returns
-------
......@@ -429,10 +444,11 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
graph = create_tree_digraph(
booster=booster,
tree_index=tree_index,
show_info=show_info,
precision=precision,
graph_attr=graph_attr,
node_attr=node_attr,
edge_attr=edge_attr,
show_info=show_info
edge_attr=edge_attr
)
s = BytesIO()
......
......@@ -478,7 +478,7 @@ class TestEngine(unittest.TestCase):
for ret in other_ret:
self.assertAlmostEqual(ret_origin, ret, places=5)
@unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas not installed')
@unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas is not installed')
def test_pandas_categorical(self):
X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str
"B": np.random.permutation([1, 2, 3] * 100), # int
......
......@@ -16,7 +16,7 @@ except ImportError:
class TestBasic(unittest.TestCase):
@unittest.skipIf(not matplotlib_installed, 'matplotlib not installed')
@unittest.skipIf(not matplotlib_installed, 'matplotlib is not installed')
def test_plot_importance(self):
X_train, _, y_train, _ = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
train_data = lgb.Dataset(X_train, y_train)
......@@ -62,7 +62,7 @@ class TestBasic(unittest.TestCase):
def test_plot_tree(self):
pass
@unittest.skipIf(not matplotlib_installed, 'matplotlib not installed')
@unittest.skipIf(not matplotlib_installed, 'matplotlib is not installed')
def test_plot_metrics(self):
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
train_data = lgb.Dataset(X_train, y_train)
......
......@@ -200,7 +200,7 @@ class TestSklearn(unittest.TestCase):
except SkipTest as message:
warnings.warn(message, SkipTestWarning)
@unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas not installed')
@unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas is not installed')
def test_pandas_categorical(self):
X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str
"B": np.random.permutation([1, 2, 3] * 100), # int
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment