Commit 8980fc72 authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

[python-package] add plot tree (#262)

* add plot tree

* add docs

* add example

* add test

* fix test

* fix decision type

* add show_info

* use feature name if available
parent 5c5dce37
...@@ -23,12 +23,12 @@ script: ...@@ -23,12 +23,12 @@ script:
- mkdir build && cd build && cmake .. && make -j - mkdir build && cd build && cmake .. && make -j
- cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py - cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py
- cd $TRAVIS_BUILD_DIR/python-package && python setup.py install - cd $TRAVIS_BUILD_DIR/python-package && python setup.py install
- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_engine.py && python test_sklearn.py - cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_engine.py && python test_sklearn.py && python test_plotting.py
- cd $TRAVIS_BUILD_DIR && pep8 --ignore=E501 . - cd $TRAVIS_BUILD_DIR && pep8 --ignore=E501 .
- rm -rf build && mkdir build && cd build && cmake -DUSE_MPI=ON ..&& make -j - rm -rf build && mkdir build && cd build && cmake -DUSE_MPI=ON ..&& make -j
- cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py - cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py
- cd $TRAVIS_BUILD_DIR/python-package && python setup.py install - cd $TRAVIS_BUILD_DIR/python-package && python setup.py install
- cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_engine.py && python test_sklearn.py - cd $TRAVIS_BUILD_DIR/tests/python_package_test && python test_basic.py && python test_engine.py && python test_sklearn.py && python test_plotting.py
notifications: notifications:
email: false email: false
......
...@@ -928,22 +928,22 @@ The methods of each Class is in alphabetical order. ...@@ -928,22 +928,22 @@ The methods of each Class is in alphabetical order.
##Plotting ##Plotting
####plot_importance(booster, ax=None, height=0.2, xlim=None, ylim=None, title='Feature importance', xlabel='Feature importance', ylabel='Features', importance_type='split', max_num_features=None, ignore_zero=True, grid=True, **kwargs): ####plot_importance(booster, ax=None, height=0.2, xlim=None, ylim=None, title='Feature importance', xlabel='Feature importance', ylabel='Features', importance_type='split', max_num_features=None, ignore_zero=True, figsize=None, grid=True, **kwargs):
Plot model feature importances. Plot model feature importances.
Parameters Parameters
---------- ----------
booster : Booster, LGBMModel or array booster : Booster, LGBMModel or array
Booster or LGBMModel instance, or array of feature importances Booster or LGBMModel instance, or array of feature importances.
ax : matplotlib Axes ax : matplotlib Axes
Target axes instance. If None, new figure and axes will be created. Target axes instance. If None, new figure and axes will be created.
height : float height : float
Bar height, passed to ax.barh() Bar height, passed to ax.barh().
xlim : tuple xlim : tuple of 2 elements
Tuple passed to axes.xlim() Tuple passed to axes.xlim().
ylim : tuple ylim : tuple of 2 elements
Tuple passed to axes.ylim() Tuple passed to axes.ylim().
title : str title : str
Axes title. Pass None to disable. Axes title. Pass None to disable.
xlabel : str xlabel : str
...@@ -951,18 +951,47 @@ The methods of each Class is in alphabetical order. ...@@ -951,18 +951,47 @@ The methods of each Class is in alphabetical order.
ylabel : str ylabel : str
Y axis title label. Pass None to disable. Y axis title label. Pass None to disable.
importance_type : str importance_type : str
How the importance is calculated: "split" or "gain" How the importance is calculated: "split" or "gain".
"split" is the number of times a feature is used in a model "split" is the number of times a feature is used in a model.
"gain" is the total gain of splits which use the feature "gain" is the total gain of splits which use the feature.
max_num_features : int max_num_features : int
Max number of top features displayed on plot. Max number of top features displayed on plot.
If None or smaller than 1, all features will be displayed. If None or smaller than 1, all features will be displayed.
ignore_zero : bool ignore_zero : bool
Ignore features with zero importance Ignore features with zero importance.
figsize : tuple of 2 elements
Figure size.
grid : bool grid : bool
Whether add grid for axes Whether add grid for axes.
**kwargs : **kwargs :
Other keywords passed to ax.barh() Other keywords passed to ax.barh().
Returns
-------
ax : matplotlib Axes
####plot_tree(booster, ax=None, tree_index=0, figsize=None, graph_attr=None, node_attr=None, edge_attr=None, show_info=None):
Plot specified tree.
Parameters
----------
booster : Booster, LGBMModel
Booster or LGBMModel instance.
ax : matplotlib Axes
Target axes instance. If None, new figure and axes will be created.
tree_index : int, default 0
Specify tree index of target tree.
figsize : tuple of 2 elements
Figure size.
graph_attr : dict
Mapping of (attribute, value) pairs for the graph.
node_attr : dict
Mapping of (attribute, value) pairs set for all nodes.
edge_attr : dict
Mapping of (attribute, value) pairs set for all edges.
show_info : list
Information to show on nodes.
Options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
Returns Returns
------- -------
......
...@@ -20,6 +20,7 @@ lgb_train = lgb.Dataset(X_train, y_train) ...@@ -20,6 +20,7 @@ lgb_train = lgb.Dataset(X_train, y_train)
# specify your configurations as a dict # specify your configurations as a dict
params = { params = {
'num_leaves': 5,
'verbose': 0 'verbose': 0
} }
...@@ -27,9 +28,16 @@ print('Start training...') ...@@ -27,9 +28,16 @@ print('Start training...')
# train # train
gbm = lgb.train(params, gbm = lgb.train(params,
lgb_train, lgb_train,
num_boost_round=10) num_boost_round=100,
feature_name=['f' + str(i + 1) for i in range(28)],
categorical_feature=[21])
print('Plot feature importances...') print('Plot feature importances...')
# plot feature importances # plot feature importances
ax = lgb.plot_importance(gbm, max_num_features=10) ax = lgb.plot_importance(gbm, max_num_features=10)
plt.show() plt.show()
print('Plot 84th tree...')
# plot tree
lgb.plot_tree(gbm, tree_index=83, figsize=(20, 8), show_info=['split_gain'])
plt.show()
...@@ -14,7 +14,7 @@ try: ...@@ -14,7 +14,7 @@ try:
except ImportError: except ImportError:
pass pass
try: try:
from .plotting import plot_importance from .plotting import plot_importance, plot_tree
except ImportError: except ImportError:
pass pass
...@@ -25,4 +25,4 @@ __all__ = ['Dataset', 'Booster', ...@@ -25,4 +25,4 @@ __all__ = ['Dataset', 'Booster',
'train', 'cv', 'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker', 'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
'plot_importance'] 'plot_importance', 'plot_tree']
...@@ -3,17 +3,24 @@ ...@@ -3,17 +3,24 @@
"""Plotting Library.""" """Plotting Library."""
from __future__ import absolute_import from __future__ import absolute_import
from io import BytesIO
import numpy as np import numpy as np
from .basic import Booster, is_numpy_1d_array from .basic import Booster, is_numpy_1d_array
from .sklearn import LGBMModel from .sklearn import LGBMModel
def check_not_tuple_of_2_elements(obj):
    """Return True unless ``obj`` is a tuple holding exactly 2 elements."""
    is_pair = isinstance(obj, tuple) and len(obj) == 2
    return not is_pair
def plot_importance(booster, ax=None, height=0.2, def plot_importance(booster, ax=None, height=0.2,
xlim=None, ylim=None, title='Feature importance', xlim=None, ylim=None, title='Feature importance',
xlabel='Feature importance', ylabel='Features', xlabel='Feature importance', ylabel='Features',
importance_type='split', max_num_features=None, importance_type='split', max_num_features=None,
ignore_zero=True, grid=True, **kwargs): ignore_zero=True, figsize=None, grid=True, **kwargs):
"""Plot model feature importances. """Plot model feature importances.
Parameters Parameters
...@@ -24,9 +31,9 @@ def plot_importance(booster, ax=None, height=0.2, ...@@ -24,9 +31,9 @@ def plot_importance(booster, ax=None, height=0.2,
Target axes instance. If None, new figure and axes will be created. Target axes instance. If None, new figure and axes will be created.
height : float height : float
Bar height, passed to ax.barh() Bar height, passed to ax.barh()
xlim : tuple xlim : tuple of 2 elements
Tuple passed to axes.xlim() Tuple passed to axes.xlim()
ylim : tuple ylim : tuple of 2 elements
Tuple passed to axes.ylim() Tuple passed to axes.ylim()
title : str title : str
Axes title. Pass None to disable. Axes title. Pass None to disable.
...@@ -43,6 +50,8 @@ def plot_importance(booster, ax=None, height=0.2, ...@@ -43,6 +50,8 @@ def plot_importance(booster, ax=None, height=0.2,
If None or smaller than 1, all features will be displayed. If None or smaller than 1, all features will be displayed.
ignore_zero : bool ignore_zero : bool
Ignore features with zero importance Ignore features with zero importance
figsize : tuple of 2 elements
Figure size
grid : bool grid : bool
Whether add grid for axes Whether add grid for axes
**kwargs : **kwargs :
...@@ -55,7 +64,7 @@ def plot_importance(booster, ax=None, height=0.2, ...@@ -55,7 +64,7 @@ def plot_importance(booster, ax=None, height=0.2,
try: try:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
except ImportError: except ImportError:
raise ImportError('You must install matplotlib for plotting library') raise ImportError('You must install matplotlib to plot importance.')
if isinstance(booster, LGBMModel): if isinstance(booster, LGBMModel):
importance = booster.booster_.feature_importance(importance_type=importance_type) importance = booster.booster_.feature_importance(importance_type=importance_type)
...@@ -64,10 +73,10 @@ def plot_importance(booster, ax=None, height=0.2, ...@@ -64,10 +73,10 @@ def plot_importance(booster, ax=None, height=0.2,
elif is_numpy_1d_array(booster) or isinstance(booster, list): elif is_numpy_1d_array(booster) or isinstance(booster, list):
importance = booster importance = booster
else: else:
raise ValueError('booster must be Booster or array instance') raise TypeError('booster must be Booster, LGBMModel or array instance.')
if not len(importance): if not len(importance):
raise ValueError('Booster feature_importances are empty') raise ValueError('Booster feature_importances are empty.')
tuples = sorted(enumerate(importance), key=lambda x: x[1]) tuples = sorted(enumerate(importance), key=lambda x: x[1])
if ignore_zero: if ignore_zero:
...@@ -77,7 +86,9 @@ def plot_importance(booster, ax=None, height=0.2, ...@@ -77,7 +86,9 @@ def plot_importance(booster, ax=None, height=0.2,
labels, values = zip(*tuples) labels, values = zip(*tuples)
if ax is None: if ax is None:
_, ax = plt.subplots(1, 1) if figsize is not None and check_not_tuple_of_2_elements(figsize):
raise TypeError('figsize must be a tuple of 2 elements.')
_, ax = plt.subplots(1, 1, figsize=figsize)
ylocs = np.arange(len(values)) ylocs = np.arange(len(values))
ax.barh(ylocs, values, align='center', height=height, **kwargs) ax.barh(ylocs, values, align='center', height=height, **kwargs)
...@@ -89,15 +100,15 @@ def plot_importance(booster, ax=None, height=0.2, ...@@ -89,15 +100,15 @@ def plot_importance(booster, ax=None, height=0.2,
ax.set_yticklabels(labels) ax.set_yticklabels(labels)
if xlim is not None: if xlim is not None:
if not isinstance(xlim, tuple) or len(xlim) != 2: if check_not_tuple_of_2_elements(xlim):
raise ValueError('xlim must be a tuple of 2 elements') raise TypeError('xlim must be a tuple of 2 elements.')
else: else:
xlim = (0, max(values) * 1.1) xlim = (0, max(values) * 1.1)
ax.set_xlim(xlim) ax.set_xlim(xlim)
if ylim is not None: if ylim is not None:
if not isinstance(ylim, tuple) or len(ylim) != 2: if check_not_tuple_of_2_elements(ylim):
raise ValueError('ylim must be a tuple of 2 elements') raise TypeError('ylim must be a tuple of 2 elements.')
else: else:
ylim = (-1, len(values)) ylim = (-1, len(values))
ax.set_ylim(ylim) ax.set_ylim(ylim)
...@@ -110,3 +121,118 @@ def plot_importance(booster, ax=None, height=0.2, ...@@ -110,3 +121,118 @@ def plot_importance(booster, ax=None, height=0.2,
ax.set_ylabel(ylabel) ax.set_ylabel(ylabel)
ax.grid(grid) ax.grid(grid)
return ax return ax
def _to_graphviz(graph, tree_info, show_info, feature_names):
    """Convert specified tree to graphviz instance.

    Parameters
    ----------
    graph : object with ``node(name, label=...)`` and ``edge(tail, head, label)``
        Graph the tree is written into; it is modified in place.
    tree_info : dict
        Description of one tree; the root node is read from
        ``tree_info['tree_structure']``.
    show_info : list
        Extra fields appended to node labels. 'split_gain', 'internal_value'
        and 'internal_count' apply to split nodes, 'leaf_count' to leaves;
        any other entry is ignored.
    feature_names : list or None
        If not None, split nodes are labelled with
        ``feature_names[split_feature]`` instead of the raw feature index.

    Returns
    -------
    graph : the same object that was passed in.

    Raises
    ------
    ValueError
        If a split node has a decision type other than 'no_greater' or 'is'.
    """
    # Fields that may be shown on split (non-leaf) nodes; built once
    # instead of on every loop iteration.
    split_fields = frozenset(('split_gain', 'internal_value', 'internal_count'))

    def add(root, parent=None, decision=None):
        """Recursively add the node ``root`` and the edge from its parent."""
        if 'split_index' in root:  # non-leaf
            name = 'split' + str(root['split_index'])
            if feature_names is not None:
                label = 'split_feature_name:' + str(feature_names[root['split_feature']])
            else:
                label = 'split_feature_index:' + str(root['split_feature'])
            label += '\nthreshold:' + str(root['threshold'])
            for info in show_info:
                if info in split_fields:
                    label += '\n' + info + ':' + str(root[info])
            graph.node(name, label=label)
            if root['decision_type'] == 'no_greater':
                l_dec, r_dec = '<=', '>'
            elif root['decision_type'] == 'is':
                l_dec, r_dec = 'is', "isn't"
            else:
                raise ValueError('Invalid decision type in tree model.')
            add(root['left_child'], name, l_dec)
            add(root['right_child'], name, r_dec)
        else:  # leaf
            # Node identifier is 'leaf<N>' (was misspelled 'left<N>').
            name = 'leaf' + str(root['leaf_index'])
            label = 'leaf_value:' + str(root['leaf_value'])
            if 'leaf_count' in show_info:
                label += '\nleaf_count:' + str(root['leaf_count'])
            graph.node(name, label=label)
        # The root has no parent and therefore no incoming edge.
        if parent is not None:
            graph.edge(parent, name, decision)

    add(tree_info['tree_structure'])
    return graph
def plot_tree(booster, ax=None, tree_index=0, figsize=None,
              graph_attr=None, node_attr=None, edge_attr=None,
              show_info=None):
    """Plot specified tree.

    Parameters
    ----------
    booster : Booster, LGBMModel
        Booster or LGBMModel instance.
    ax : matplotlib Axes
        Target axes instance. If None, new figure and axes will be created.
    tree_index : int, default 0
        Specify tree index of target tree.
    figsize : tuple of 2 elements
        Figure size. Only used (and validated) when ``ax`` is None.
    graph_attr : dict
        Mapping of (attribute, value) pairs for the graph.
    node_attr : dict
        Mapping of (attribute, value) pairs set for all nodes.
    edge_attr : dict
        Mapping of (attribute, value) pairs set for all edges.
    show_info : list
        Information to show on nodes.
        Options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib or graphviz is not installed.
    TypeError
        If ``figsize`` is given (with ``ax`` None) but is not a tuple of
        2 elements, or if ``booster`` is neither a Booster nor an LGBMModel.
    IndexError
        If ``tree_index`` is out of range.
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.image as image
    except ImportError:
        raise ImportError('You must install matplotlib to plot tree.')

    try:
        from graphviz import Digraph
    except ImportError:
        raise ImportError('You must install graphviz to plot tree.')

    # Validate figsize up front, but create the figure only once the tree has
    # been rendered, so a failure below does not leave an empty figure behind.
    if ax is None and figsize is not None and check_not_tuple_of_2_elements(figsize):
        raise TypeError('figsize must be a tuple of 2 elements.')

    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError('booster must be Booster or LGBMModel.')

    model = booster.dump_model()
    tree_infos = model['tree_info']
    # Feature names are optional in the dumped model.
    feature_names = model['feature_names'] if 'feature_names' in model else None

    if tree_index < len(tree_infos):
        tree_info = tree_infos[tree_index]
    else:
        raise IndexError('tree_index is out of range.')

    graph = Digraph(graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr)

    if show_info is None:
        show_info = []
    ret = _to_graphviz(graph, tree_info, show_info, feature_names)

    # Render to PNG in memory and hand the bytes to matplotlib as a file-like object.
    png_stream = BytesIO(ret.pipe(format='png'))
    img = image.imread(png_stream)

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.imshow(img)
    ax.axis('off')
    return ax
...@@ -7,15 +7,16 @@ from sklearn.datasets import load_breast_cancer ...@@ -7,15 +7,16 @@ from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
try: try:
from matplotlib.axes import Axes import matplotlib
MATPLOTLIB_INSTALLED = True matplotlib.use('Agg')
matplotlib_installed = True
except ImportError: except ImportError:
MATPLOTLIB_INSTALLED = False matplotlib_installed = False
class TestBasic(unittest.TestCase): class TestBasic(unittest.TestCase):
@unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib not installed') @unittest.skipIf(not matplotlib_installed, 'matplotlib not installed')
def test_plot_importance(self): def test_plot_importance(self):
X_train, _, y_train, _ = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1) X_train, _, y_train, _ = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
train_data = lgb.Dataset(X_train, y_train) train_data = lgb.Dataset(X_train, y_train)
...@@ -27,7 +28,7 @@ class TestBasic(unittest.TestCase): ...@@ -27,7 +28,7 @@ class TestBasic(unittest.TestCase):
} }
gbm0 = lgb.train(params, train_data, num_boost_round=10) gbm0 = lgb.train(params, train_data, num_boost_round=10)
ax0 = lgb.plot_importance(gbm0) ax0 = lgb.plot_importance(gbm0)
self.assertIsInstance(ax0, Axes) self.assertIsInstance(ax0, matplotlib.axes.Axes)
self.assertEqual(ax0.get_title(), 'Feature importance') self.assertEqual(ax0.get_title(), 'Feature importance')
self.assertEqual(ax0.get_xlabel(), 'Feature importance') self.assertEqual(ax0.get_xlabel(), 'Feature importance')
self.assertEqual(ax0.get_ylabel(), 'Features') self.assertEqual(ax0.get_ylabel(), 'Features')
...@@ -37,7 +38,7 @@ class TestBasic(unittest.TestCase): ...@@ -37,7 +38,7 @@ class TestBasic(unittest.TestCase):
gbm1.fit(X_train, y_train) gbm1.fit(X_train, y_train)
ax1 = lgb.plot_importance(gbm1, color='r', title='t', xlabel='x', ylabel='y') ax1 = lgb.plot_importance(gbm1, color='r', title='t', xlabel='x', ylabel='y')
self.assertIsInstance(ax1, Axes) self.assertIsInstance(ax1, matplotlib.axes.Axes)
self.assertEqual(ax1.get_title(), 't') self.assertEqual(ax1.get_title(), 't')
self.assertEqual(ax1.get_xlabel(), 'x') self.assertEqual(ax1.get_xlabel(), 'x')
self.assertEqual(ax1.get_ylabel(), 'y') self.assertEqual(ax1.get_ylabel(), 'y')
...@@ -48,7 +49,7 @@ class TestBasic(unittest.TestCase): ...@@ -48,7 +49,7 @@ class TestBasic(unittest.TestCase):
ax2 = lgb.plot_importance(gbm0.feature_importance(), ax2 = lgb.plot_importance(gbm0.feature_importance(),
color=['r', 'y', 'g', 'b'], color=['r', 'y', 'g', 'b'],
title=None, xlabel=None, ylabel=None) title=None, xlabel=None, ylabel=None)
self.assertIsInstance(ax2, Axes) self.assertIsInstance(ax2, matplotlib.axes.Axes)
self.assertEqual(ax2.get_title(), '') self.assertEqual(ax2.get_title(), '')
self.assertEqual(ax2.get_xlabel(), '') self.assertEqual(ax2.get_xlabel(), '')
self.assertEqual(ax2.get_ylabel(), '') self.assertEqual(ax2.get_ylabel(), '')
...@@ -58,6 +59,10 @@ class TestBasic(unittest.TestCase): ...@@ -58,6 +59,10 @@ class TestBasic(unittest.TestCase):
self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.)) # g self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.)) # g
self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.)) # b self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.)) # b
@unittest.skip('Graphviz are not executables on Travis')
def test_plot_tree(self):
pass
print("----------------------------------------------------------------------") print("----------------------------------------------------------------------")
print("running test_plotting.py") print("running test_plotting.py")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment