Commit 61ad133b authored by Tsukasa OMOTO's avatar Tsukasa OMOTO Committed by Guolin Ke
Browse files

python-package: refine plot_tree (#388)

* python-package: refine plot_tree

This change splits plot_tree to two methods:
1. method of creating digraph of tree
2. method of plotting the digraph with matplotlib

* fix doc

* fix doc
parent 546056e7
...@@ -1045,3 +1045,27 @@ The methods of each Class is in alphabetical order. ...@@ -1045,3 +1045,27 @@ The methods of each Class is in alphabetical order.
Returns Returns
------- -------
ax : matplotlib Axes ax : matplotlib Axes
#### create_tree_digraph(booster, tree_index=0, graph_attr=None, node_attr=None, edge_attr=None, show_info=None):
Create a digraph of specified tree.
Parameters
----------
booster : Booster, LGBMModel
Booster or LGBMModel instance.
tree_index : int, default 0
Specify tree index of target tree.
graph_attr : dict
Mapping of (attribute, value) pairs for the graph.
node_attr : dict
Mapping of (attribute, value) pairs set for all nodes.
edge_attr : dict
Mapping of (attribute, value) pairs set for all edges.
show_info : list
Information shows on nodes.
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
Returns
-------
graph : graphviz Digraph
...@@ -16,7 +16,7 @@ try: ...@@ -16,7 +16,7 @@ try:
except ImportError: except ImportError:
pass pass
try: try:
from .plotting import plot_importance, plot_metric, plot_tree from .plotting import plot_importance, plot_metric, plot_tree, create_tree_digraph
except ImportError: except ImportError:
pass pass
...@@ -27,4 +27,4 @@ __all__ = ['Dataset', 'Booster', ...@@ -27,4 +27,4 @@ __all__ = ['Dataset', 'Booster',
'train', 'cv', 'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker', 'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
'plot_importance', 'plot_metric', 'plot_tree'] 'plot_importance', 'plot_metric', 'plot_tree', 'create_tree_digraph']
...@@ -241,8 +241,13 @@ def plot_metric(booster, metric=None, dataset_names=None, ...@@ -241,8 +241,13 @@ def plot_metric(booster, metric=None, dataset_names=None,
return ax return ax
def _to_graphviz(graph, tree_info, show_info, feature_names): def _to_graphviz(tree_info, show_info, feature_names,
graph_attr=None, node_attr=None, edge_attr=None):
"""Convert specified tree to graphviz instance.""" """Convert specified tree to graphviz instance."""
try:
from graphviz import Digraph
except ImportError:
raise ImportError('You must install graphviz to plot tree.')
def add(root, parent=None, decision=None): def add(root, parent=None, decision=None):
"""recursively add node or edge""" """recursively add node or edge"""
...@@ -274,7 +279,59 @@ def _to_graphviz(graph, tree_info, show_info, feature_names): ...@@ -274,7 +279,59 @@ def _to_graphviz(graph, tree_info, show_info, feature_names):
if parent is not None: if parent is not None:
graph.edge(parent, name, decision) graph.edge(parent, name, decision)
graph = Digraph(graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr)
add(tree_info['tree_structure']) add(tree_info['tree_structure'])
return graph
def create_tree_digraph(booster, tree_index=0, graph_attr=None,
node_attr=None, edge_attr=None, show_info=None):
"""Create a digraph of specified tree.
Parameters
----------
booster : Booster, LGBMModel
Booster or LGBMModel instance.
tree_index : int, default 0
Specify tree index of target tree.
graph_attr : dict
Mapping of (attribute, value) pairs for the graph.
node_attr : dict
Mapping of (attribute, value) pairs set for all nodes.
edge_attr : dict
Mapping of (attribute, value) pairs set for all edges.
show_info : list
Information shows on nodes.
options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
Returns
-------
graph : graphviz Digraph
"""
if isinstance(booster, LGBMModel):
booster = booster.booster_
elif not isinstance(booster, Booster):
raise TypeError('booster must be Booster or LGBMModel.')
model = booster.dump_model()
tree_infos = model['tree_info']
if 'feature_names' in model:
feature_names = model['feature_names']
else:
feature_names = None
if tree_index < len(tree_infos):
tree_info = tree_infos[tree_index]
else:
raise IndexError('tree_index is out of range.')
if show_info is None:
show_info = []
graph = _to_graphviz(tree_info, show_info, feature_names,
graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr)
return graph return graph
...@@ -313,41 +370,22 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None, ...@@ -313,41 +370,22 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
except ImportError: except ImportError:
raise ImportError('You must install matplotlib to plot tree.') raise ImportError('You must install matplotlib to plot tree.')
try:
from graphviz import Digraph
except ImportError:
raise ImportError('You must install graphviz to plot tree.')
if ax is None: if ax is None:
if figsize is not None: if figsize is not None:
check_not_tuple_of_2_elements(figsize, 'figsize') check_not_tuple_of_2_elements(figsize, 'figsize')
_, ax = plt.subplots(1, 1, figsize=figsize) _, ax = plt.subplots(1, 1, figsize=figsize)
if isinstance(booster, LGBMModel): graph = create_tree_digraph(
booster = booster.booster_ booster=booster,
elif not isinstance(booster, Booster): tree_index=tree_index,
raise TypeError('booster must be Booster or LGBMModel.') graph_attr=graph_attr,
node_attr=node_attr,
model = booster.dump_model() edge_attr=edge_attr,
tree_infos = model['tree_info'] show_info=show_info
if 'feature_names' in model: )
feature_names = model['feature_names']
else:
feature_names = None
if tree_index < len(tree_infos):
tree_info = tree_infos[tree_index]
else:
raise IndexError('tree_index is out of range.')
graph = Digraph(graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr)
if show_info is None:
show_info = []
ret = _to_graphviz(graph, tree_info, show_info, feature_names)
s = BytesIO() s = BytesIO()
s.write(ret.pipe(format='png')) s.write(graph.pipe(format='png'))
s.seek(0) s.seek(0)
img = image.imread(s) img = image.imread(s)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment