Commit 611cf5d4 authored by Nikita Titov's avatar Nikita Titov Committed by Tsukasa OMOTO
Browse files

[python] added plot_split_value_histogram function (#2043)

* added plot_split_value_histogram function

* updated init module

* added plot split value histogram example

* added plot_split_value_histogram to notebook

* added test

* fixed pylint

* updated API docs

* fixed grammar

* set y ticks to int value in more sufficient way
parent 65c7779d
...@@ -62,6 +62,8 @@ Plotting ...@@ -62,6 +62,8 @@ Plotting
.. autofunction:: lightgbm.plot_importance .. autofunction:: lightgbm.plot_importance
.. autofunction:: lightgbm.plot_split_value_histogram
.. autofunction:: lightgbm.plot_metric .. autofunction:: lightgbm.plot_metric
.. autofunction:: lightgbm.plot_tree .. autofunction:: lightgbm.plot_tree
......
...@@ -57,5 +57,6 @@ Examples include: ...@@ -57,5 +57,6 @@ Examples include:
- Train and record eval results for further plotting - Train and record eval results for further plotting
- Plot metrics recorded during training - Plot metrics recorded during training
- Plot feature importances - Plot feature importances
- Plot split value histogram
- Plot one specified tree - Plot one specified tree
- Plot one specified tree with Graphviz - Plot one specified tree with Graphviz
...@@ -50,6 +50,10 @@ print('Plotting feature importances...') ...@@ -50,6 +50,10 @@ print('Plotting feature importances...')
ax = lgb.plot_importance(gbm, max_num_features=10) ax = lgb.plot_importance(gbm, max_num_features=10)
plt.show() plt.show()
print('Plotting split value histogram...')
ax = lgb.plot_split_value_histogram(gbm, feature='f26', bins='auto')
plt.show()
print('Plotting 54th tree...') # one tree use categorical feature to split print('Plotting 54th tree...') # one tree use categorical feature to split
ax = lgb.plot_tree(gbm, tree_index=53, figsize=(15, 15), show_info=['split_gain']) ax = lgb.plot_tree(gbm, tree_index=53, figsize=(15, 15), show_info=['split_gain'])
plt.show() plt.show()
......
...@@ -19,7 +19,8 @@ try: ...@@ -19,7 +19,8 @@ try:
except ImportError: except ImportError:
pass pass
try: try:
from .plotting import plot_importance, plot_metric, plot_tree, create_tree_digraph from .plotting import (plot_importance, plot_split_value_histogram, plot_metric,
plot_tree, create_tree_digraph)
except ImportError: except ImportError:
pass pass
...@@ -34,7 +35,7 @@ __all__ = ['Dataset', 'Booster', ...@@ -34,7 +35,7 @@ __all__ = ['Dataset', 'Booster',
'train', 'cv', 'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker', 'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
'plot_importance', 'plot_metric', 'plot_tree', 'create_tree_digraph'] 'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree', 'create_tree_digraph']
# REMOVEME: remove warning after 2.3.0 version release # REMOVEME: remove warning after 2.3.0 version release
if system() == 'Darwin': if system() == 'Darwin':
......
...@@ -141,6 +141,110 @@ def plot_importance(booster, ax=None, height=0.2, ...@@ -141,6 +141,110 @@ def plot_importance(booster, ax=None, height=0.2,
return ax return ax
def plot_split_value_histogram(booster, feature, bins=None, ax=None, width_coef=0.8,
xlim=None, ylim=None,
title='Split value histogram for feature with @index/name@ @feature@',
xlabel='Feature split value', ylabel='Count',
figsize=None, grid=True, **kwargs):
"""Plot split value histogram for the specified feature of the model.
Parameters
----------
booster : Booster or LGBMModel
Booster or LGBMModel instance of which feature split value histogram should be plotted.
feature : int or string
The feature name or index the histogram is plotted for.
If int, interpreted as index.
If string, interpreted as name.
bins : int, string or None, optional (default=None)
The maximum number of bins.
If None, the number of bins equals number of unique split values.
If string, it should be one from the list of the supported values by ``numpy.histogram()`` function.
ax : matplotlib.axes.Axes or None, optional (default=None)
Target axes instance.
If None, new figure and axes will be created.
width_coef : float, optional (default=0.8)
Coefficient for histogram bar width.
xlim : tuple of 2 elements or None, optional (default=None)
Tuple passed to ``ax.xlim()``.
ylim : tuple of 2 elements or None, optional (default=None)
Tuple passed to ``ax.ylim()``.
title : string or None, optional (default="Split value histogram for feature with @index/name@ @feature@")
Axes title.
If None, title is disabled.
@feature@ placeholder can be used, and it will be replaced with the value of ``feature`` parameter.
@index/name@ placeholder can be used,
and it will be replaced with ``index`` word in case of ``int`` type ``feature`` parameter
or ``name`` word in case of ``string`` type ``feature`` parameter.
xlabel : string or None, optional (default="Feature split value")
X-axis title label.
If None, title is disabled.
ylabel : string or None, optional (default="Count")
Y-axis title label.
If None, title is disabled.
figsize : tuple of 2 elements or None, optional (default=None)
Figure size.
grid : bool, optional (default=True)
Whether to add a grid for axes.
**kwargs
Other parameters passed to ``ax.bar()``.
Returns
-------
ax : matplotlib.axes.Axes
The plot with specified model's feature split value histogram.
"""
if MATPLOTLIB_INSTALLED:
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
else:
raise ImportError('You must install matplotlib to plot split value histogram.')
if isinstance(booster, LGBMModel):
booster = booster.booster_
elif not isinstance(booster, Booster):
raise TypeError('booster must be Booster or LGBMModel.')
hist, bins = booster.get_split_value_histogram(feature=feature, bins=bins, xgboost_style=False)
if np.count_nonzero(hist) == 0:
raise ValueError('Cannot plot split value histogram, '
'because feature {} was not used in splitting'.format(feature))
width = width_coef * (bins[1] - bins[0])
centred = (bins[:-1] + bins[1:]) / 2
if ax is None:
if figsize is not None:
_check_not_tuple_of_2_elements(figsize, 'figsize')
_, ax = plt.subplots(1, 1, figsize=figsize)
ax.bar(centred, hist, align='center', width=width, **kwargs)
if xlim is not None:
_check_not_tuple_of_2_elements(xlim, 'xlim')
else:
range_result = bins[-1] - bins[0]
xlim = (bins[0] - range_result * 0.2, bins[-1] + range_result * 0.2)
ax.set_xlim(xlim)
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
if ylim is not None:
_check_not_tuple_of_2_elements(ylim, 'ylim')
else:
ylim = (0, max(hist) * 1.1)
ax.set_ylim(ylim)
if title is not None:
title = title.replace('@feature@', str(feature))
title = title.replace('@index/name@', ('name' if isinstance(feature, string_type) else 'index'))
ax.set_title(title)
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
ax.grid(grid)
return ax
def plot_metric(booster, metric=None, dataset_names=None, def plot_metric(booster, metric=None, dataset_names=None,
ax=None, xlim=None, ylim=None, ax=None, xlim=None, ylim=None,
title='Metric during training', title='Metric during training',
......
...@@ -60,6 +60,45 @@ class TestBasic(unittest.TestCase): ...@@ -60,6 +60,45 @@ class TestBasic(unittest.TestCase):
self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.)) # g self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.)) # g
self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.)) # b self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.)) # b
@unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
def test_plot_split_value_histogram(self):
gbm0 = lgb.train(self.params, self.train_data, num_boost_round=10)
ax0 = lgb.plot_split_value_histogram(gbm0, 27)
self.assertIsInstance(ax0, matplotlib.axes.Axes)
self.assertEqual(ax0.get_title(), 'Split value histogram for feature with index 27')
self.assertEqual(ax0.get_xlabel(), 'Feature split value')
self.assertEqual(ax0.get_ylabel(), 'Count')
self.assertLessEqual(len(ax0.patches), 2)
gbm1 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
gbm1.fit(self.X_train, self.y_train)
ax1 = lgb.plot_split_value_histogram(gbm1, gbm1.booster_.feature_name()[27], figsize=(10, 5),
title='Histogram for feature @index/name@ @feature@',
xlabel='x', ylabel='y', color='r')
self.assertIsInstance(ax1, matplotlib.axes.Axes)
self.assertEqual(ax1.get_title(),
'Histogram for feature name {}'.format(gbm1.booster_.feature_name()[27]))
self.assertEqual(ax1.get_xlabel(), 'x')
self.assertEqual(ax1.get_ylabel(), 'y')
self.assertLessEqual(len(ax1.patches), 2)
for patch in ax1.patches:
self.assertTupleEqual(patch.get_facecolor(), (1., 0, 0, 1.)) # red
ax2 = lgb.plot_split_value_histogram(gbm0, 27, bins=10, color=['r', 'y', 'g', 'b'],
title=None, xlabel=None, ylabel=None)
self.assertIsInstance(ax2, matplotlib.axes.Axes)
self.assertEqual(ax2.get_title(), '')
self.assertEqual(ax2.get_xlabel(), '')
self.assertEqual(ax2.get_ylabel(), '')
self.assertEqual(len(ax2.patches), 10)
self.assertTupleEqual(ax2.patches[0].get_facecolor(), (1., 0, 0, 1.)) # r
self.assertTupleEqual(ax2.patches[1].get_facecolor(), (.75, .75, 0, 1.)) # y
self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.)) # g
self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.)) # b
self.assertRaises(ValueError, lgb.plot_split_value_histogram, gbm0, 0) # was not used in splitting
@unittest.skipIf(not MATPLOTLIB_INSTALLED or not GRAPHVIZ_INSTALLED, 'matplotlib or graphviz is not installed') @unittest.skipIf(not MATPLOTLIB_INSTALLED or not GRAPHVIZ_INSTALLED, 'matplotlib or graphviz is not installed')
def test_plot_tree(self): def test_plot_tree(self):
gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True) gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment