Commit 053d9d8a authored by Tsukasa OMOTO's avatar Tsukasa OMOTO Committed by Guolin Ke
Browse files

python-package: add graphviz.Digraph parameters (#400)

* python-package: add graphviz.Digraph parameters

* examples: add a plottig example with graphviz

* fix tree index in print
parent 223b164e
...@@ -1046,24 +1046,45 @@ The methods of each Class is in alphabetical order. ...@@ -1046,24 +1046,45 @@ The methods of each Class is in alphabetical order.
------- -------
ax : matplotlib Axes ax : matplotlib Axes
#### create_tree_digraph(booster, tree_index=0, graph_attr=None, node_attr=None, edge_attr=None, show_info=None): #### create_tree_digraph(booster, tree_index=0, show_info=None, name=None, comment=None, filename=None, directory=None, format=None, engine=None, encoding=None, graph_attr=None, node_attr=None, edge_attr=None, body=None, strict=False):
Create a digraph of specified tree. Create a digraph of specified tree.
See:
- http://graphviz.readthedocs.io/en/stable/api.html#digraph
Parameters Parameters
---------- ----------
booster : Booster, LGBMModel booster : Booster, LGBMModel
Booster or LGBMModel instance. Booster or LGBMModel instance.
tree_index : int, default 0 tree_index : int, default 0
Specify tree index of target tree. Specify tree index of target tree.
show_info : list
Information shows on nodes.
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
name : str
Graph name used in the source code.
comment : str
Comment added to the first line of the source.
filename : str
Filename for saving the source (defaults to name + '.gv').
directory : str
(Sub)directory for source saving and rendering.
format : str
Rendering output format ('pdf', 'png', ...).
engine : str
Layout command used ('dot', 'neato', ...).
encoding : str
Encoding for saving the source.
graph_attr : dict graph_attr : dict
Mapping of (attribute, value) pairs for the graph. Mapping of (attribute, value) pairs for the graph.
node_attr : dict node_attr : dict
Mapping of (attribute, value) pairs set for all nodes. Mapping of (attribute, value) pairs set for all nodes.
edge_attr : dict edge_attr : dict
Mapping of (attribute, value) pairs set for all edges. Mapping of (attribute, value) pairs set for all edges.
show_info : list body : list of str
Information shows on nodes. Iterable of lines to add to the graph body.
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'. strict : bool
Iterable of lines to add to the graph body.
Returns Returns
------- -------
......
...@@ -53,3 +53,7 @@ plt.show() ...@@ -53,3 +53,7 @@ plt.show()
print('Plot 84th tree...') # one tree use categorical feature to split print('Plot 84th tree...') # one tree use categorical feature to split
ax = lgb.plot_tree(gbm, tree_index=83, figsize=(20, 8), show_info=['split_gain']) ax = lgb.plot_tree(gbm, tree_index=83, figsize=(20, 8), show_info=['split_gain'])
plt.show() plt.show()
print('Plot 84th tree with graphviz...')
graph = lgb.create_tree_digraph(gbm, tree_index=83, name='Tree84')
graph.render(view=True)
...@@ -242,8 +242,14 @@ def plot_metric(booster, metric=None, dataset_names=None, ...@@ -242,8 +242,14 @@ def plot_metric(booster, metric=None, dataset_names=None,
def _to_graphviz(tree_info, show_info, feature_names, def _to_graphviz(tree_info, show_info, feature_names,
graph_attr=None, node_attr=None, edge_attr=None): name=None, comment=None, filename=None, directory=None,
"""Convert specified tree to graphviz instance.""" format=None, engine=None, encoding=None, graph_attr=None,
node_attr=None, edge_attr=None, body=None, strict=False):
"""Convert specified tree to graphviz instance.
See:
- http://graphviz.readthedocs.io/en/stable/api.html#digraph
"""
try: try:
from graphviz import Digraph from graphviz import Digraph
except ImportError: except ImportError:
...@@ -279,31 +285,56 @@ def _to_graphviz(tree_info, show_info, feature_names, ...@@ -279,31 +285,56 @@ def _to_graphviz(tree_info, show_info, feature_names,
if parent is not None: if parent is not None:
graph.edge(parent, name, decision) graph.edge(parent, name, decision)
graph = Digraph(graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr) graph = Digraph(name=name, comment=comment, filename=filename, directory=directory,
format=format, engine=engine, encoding=encoding, graph_attr=graph_attr,
node_attr=node_attr, edge_attr=edge_attr, body=body, strict=strict)
add(tree_info['tree_structure']) add(tree_info['tree_structure'])
return graph return graph
def create_tree_digraph(booster, tree_index=0, graph_attr=None, def create_tree_digraph(booster, tree_index=0, show_info=None,
node_attr=None, edge_attr=None, show_info=None): name=None, comment=None, filename=None, directory=None,
format=None, engine=None, encoding=None, graph_attr=None,
node_attr=None, edge_attr=None, body=None, strict=False):
"""Create a digraph of specified tree. """Create a digraph of specified tree.
See:
- http://graphviz.readthedocs.io/en/stable/api.html#digraph
Parameters Parameters
---------- ----------
booster : Booster, LGBMModel booster : Booster, LGBMModel
Booster or LGBMModel instance. Booster or LGBMModel instance.
tree_index : int, default 0 tree_index : int, default 0
Specify tree index of target tree. Specify tree index of target tree.
show_info : list
Information shows on nodes.
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
name : str
Graph name used in the source code.
comment : str
Comment added to the first line of the source.
filename : str
Filename for saving the source (defaults to name + '.gv').
directory : str
(Sub)directory for source saving and rendering.
format : str
Rendering output format ('pdf', 'png', ...).
engine : str
Layout command used ('dot', 'neato', ...).
encoding : str
Encoding for saving the source.
graph_attr : dict graph_attr : dict
Mapping of (attribute, value) pairs for the graph. Mapping of (attribute, value) pairs for the graph.
node_attr : dict node_attr : dict
Mapping of (attribute, value) pairs set for all nodes. Mapping of (attribute, value) pairs set for all nodes.
edge_attr : dict edge_attr : dict
Mapping of (attribute, value) pairs set for all edges. Mapping of (attribute, value) pairs set for all edges.
show_info : list body : list of str
Information shows on nodes. Iterable of lines to add to the graph body.
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'. strict : bool
Iterable of lines to add to the graph body.
Returns Returns
------- -------
...@@ -330,7 +361,9 @@ def create_tree_digraph(booster, tree_index=0, graph_attr=None, ...@@ -330,7 +361,9 @@ def create_tree_digraph(booster, tree_index=0, graph_attr=None,
show_info = [] show_info = []
graph = _to_graphviz(tree_info, show_info, feature_names, graph = _to_graphviz(tree_info, show_info, feature_names,
graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr) name=name, comment=comment, filename=filename, directory=directory,
format=format, engine=engine, encoding=encoding, graph_attr=graph_attr,
node_attr=node_attr, edge_attr=edge_attr, body=body, strict=strict)
return graph return graph
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment