Commit abaefb54 authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

[python-package] add plot importance (#237)

* add plot importance

* add plot example
parent 46d4eecf
...@@ -14,7 +14,7 @@ before_install: ...@@ -14,7 +14,7 @@ before_install:
install: install:
- sudo apt-get install -y libopenmpi-dev openmpi-bin build-essential - sudo apt-get install -y libopenmpi-dev openmpi-bin build-essential
- conda install --yes atlas numpy scipy scikit-learn pandas - conda install --yes atlas numpy scipy scikit-learn pandas matplotlib
- pip install pep8 - pip install pep8
......
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
+ [record_evaluation](Python-API.md#record_evaluationeval_result) + [record_evaluation](Python-API.md#record_evaluationeval_result)
+ [early_stopping](Python-API.md#early_stoppingstopping_rounds-verbosetrue) + [early_stopping](Python-API.md#early_stoppingstopping_rounds-verbosetrue)
* [Plotting](Python-API.md#plotting)
The methods of each Class is in alphabetical order. The methods of each Class is in alphabetical order.
---- ----
...@@ -347,6 +349,13 @@ The methods of each Class is in alphabetical order. ...@@ -347,6 +349,13 @@ The methods of each Class is in alphabetical order.
Feature importances. Feature importances.
Parameters
----------
importance_type : str, default "split"
How the importance is calculated: "split" or "gain"
"split" is the number of times a feature is used in a model
"gain" is the total gain of splits which use the feature
Returns Returns
------- -------
result : array result : array
...@@ -916,3 +925,45 @@ The methods of each Class is in alphabetical order. ...@@ -916,3 +925,45 @@ The methods of each Class is in alphabetical order.
------- -------
callback : function callback : function
The requested callback function. The requested callback function.
## Plotting
#### plot_importance(booster, ax=None, height=0.2, xlim=None, ylim=None, title='Feature importance', xlabel='Feature importance', ylabel='Features', importance_type='split', max_num_features=None, ignore_zero=True, grid=True, **kwargs):
Plot model feature importances.
Parameters
----------
booster : Booster, LGBMModel or array
Booster or LGBMModel instance, or array of feature importances
ax : matplotlib Axes
Target axes instance. If None, new figure and axes will be created.
height : float
Bar height, passed to ax.barh()
xlim : tuple
Tuple passed to axes.xlim()
ylim : tuple
Tuple passed to axes.ylim()
title : str
Axes title. Pass None to disable.
xlabel : str
X axis title label. Pass None to disable.
ylabel : str
Y axis title label. Pass None to disable.
importance_type : str
How the importance is calculated: "split" or "gain"
"split" is the number of times a feature is used in a model
"gain" is the total gain of splits which use the feature
max_num_features : int
Max number of top features displayed on plot.
If None or smaller than 1, all features will be displayed.
ignore_zero : bool
Ignore features with zero importance
grid : bool
Whether add grid for axes
**kwargs :
Other keywords passed to ax.barh()
Returns
-------
ax : matplotlib Axes
...@@ -6,10 +6,9 @@ Here is an example for LightGBM to use python package. ...@@ -6,10 +6,9 @@ Here is an example for LightGBM to use python package.
For the installation, check the wiki [here](https://github.com/Microsoft/LightGBM/wiki/Installation-Guide). For the installation, check the wiki [here](https://github.com/Microsoft/LightGBM/wiki/Installation-Guide).
You also need scikit-learn and pandas to run the examples, but they are not required for the package itself. You can install them with pip: You also need scikit-learn, pandas and matplotlib (only for plot example) to run the examples, but they are not required for the package itself. You can install them with pip:
``` ```
pip install -U scikit-learn pip install scikit-learn pandas matplotlib -U
pip install -U pandas
``` ```
Now you can run examples in this folder, for example: Now you can run examples in this folder, for example:
......
# coding: utf-8
# pylint: disable = invalid-name, C0111
"""Example: train a small LightGBM regressor and plot its feature importances."""
import lightgbm as lgb
import pandas as pd

# matplotlib is an optional dependency of the package; fail early with a
# clear message if it is missing, since this example cannot run without it.
try:
    import matplotlib.pyplot as plt
except ImportError:
    raise ImportError('You need to install matplotlib for plot_example.py.')

# load or create your dataset
print('Load data...')
raw_frame = pd.read_csv('../regression/regression.train', header=None, sep='\t')
# column 0 holds the regression target, the remaining columns are features
y_train = raw_frame[0]
X_train = raw_frame.drop(0, axis=1)

# create dataset for lightgbm
lgb_train = lgb.Dataset(X_train, y_train)

# specify your configurations as a dict
params = {'verbose': 0}

print('Start training...')
# train a small model (10 boosting rounds is enough for a demo plot)
gbm = lgb.train(params, lgb_train, num_boost_round=10)

print('Plot feature importances...')
# plot feature importances (only the 10 most important features)
ax = lgb.plot_importance(gbm, max_num_features=10)
plt.show()
...@@ -13,6 +13,10 @@ try: ...@@ -13,6 +13,10 @@ try:
from .sklearn import LGBMModel, LGBMRegressor, LGBMClassifier, LGBMRanker from .sklearn import LGBMModel, LGBMRegressor, LGBMClassifier, LGBMRanker
except ImportError: except ImportError:
pass pass
try:
from .plotting import plot_importance
except ImportError:
pass
__version__ = 0.1 __version__ = 0.1
...@@ -20,4 +24,5 @@ __version__ = 0.1 ...@@ -20,4 +24,5 @@ __version__ = 0.1
__all__ = ['Dataset', 'Booster', __all__ = ['Dataset', 'Booster',
'train', 'cv', 'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker', 'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping'] 'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
'plot_importance']
...@@ -1550,6 +1550,13 @@ class Booster(object): ...@@ -1550,6 +1550,13 @@ class Booster(object):
""" """
Feature importances Feature importances
Parameters
----------
importance_type : str, default "split"
How the importance is calculated: "split" or "gain"
"split" is the number of times a feature is used in a model
"gain" is the total gain of splits which use the feature
Returns Returns
------- -------
Array of feature importances Array of feature importances
......
# coding: utf-8
# pylint: disable = C0103
"""Plotting Library."""
from __future__ import absolute_import
import numpy as np
from .basic import Booster, is_numpy_1d_array
from .sklearn import LGBMModel
def plot_importance(booster, ax=None, height=0.2,
                    xlim=None, ylim=None, title='Feature importance',
                    xlabel='Feature importance', ylabel='Features',
                    importance_type='split', max_num_features=None,
                    ignore_zero=True, grid=True, **kwargs):
    """Plot model feature importances.

    Parameters
    ----------
    booster : Booster, LGBMModel or array
        Booster or LGBMModel instance, or array of feature importances
    ax : matplotlib Axes
        Target axes instance. If None, new figure and axes will be created.
    height : float
        Bar height, passed to ax.barh()
    xlim : tuple
        Tuple passed to axes.xlim()
    ylim : tuple
        Tuple passed to axes.ylim()
    title : str
        Axes title. Pass None to disable.
    xlabel : str
        X axis title label. Pass None to disable.
    ylabel : str
        Y axis title label. Pass None to disable.
    importance_type : str
        How the importance is calculated: "split" or "gain"
        "split" is the number of times a feature is used in a model
        "gain" is the total gain of splits which use the feature
    max_num_features : int
        Max number of top features displayed on plot.
        If None or smaller than 1, all features will be displayed.
    ignore_zero : bool
        Ignore features with zero importance
    grid : bool
        Whether add grid for axes
    **kwargs :
        Other keywords passed to ax.barh()

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If ``booster`` is of an unsupported type, the importances are empty,
        or no feature remains to plot after filtering.
    """
    # Imported lazily so the rest of the package works without matplotlib.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError('You must install matplotlib for plotting library')

    # Accept a fitted sklearn wrapper, a raw Booster, or precomputed importances.
    if isinstance(booster, LGBMModel):
        importance = booster.booster_.feature_importance(importance_type=importance_type)
    elif isinstance(booster, Booster):
        importance = booster.feature_importance(importance_type=importance_type)
    elif is_numpy_1d_array(booster) or isinstance(booster, list):
        importance = booster
    else:
        # Mention every accepted type (the original message omitted LGBMModel).
        raise ValueError('booster must be Booster, LGBMModel or array instance')

    if not len(importance):
        raise ValueError('Booster feature_importances are empty')

    # Pair each feature index with its importance, ascending so the most
    # important feature ends up at the top of the horizontal bar chart.
    tuples = sorted(enumerate(importance), key=lambda x: x[1])
    if ignore_zero:
        tuples = [x for x in tuples if x[1] > 0]
    if max_num_features is not None and max_num_features > 0:
        tuples = tuples[-max_num_features:]
    # Guard: without this, zip(*[]) would raise an opaque unpacking error
    # when ignore_zero filtered out every feature.
    if not tuples:
        raise ValueError('No feature has importance > 0; nothing to plot')
    labels, values = zip(*tuples)

    if ax is None:
        _, ax = plt.subplots(1, 1)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align='center', height=height, **kwargs)

    # Annotate each bar with its importance value, slightly right of the bar end.
    for x, y in zip(values, ylocs):
        ax.text(x + 1, y, x, va='center')

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is not None:
        if not isinstance(xlim, tuple) or len(xlim) != 2:
            raise ValueError('xlim must be a tuple of 2 elements')
    else:
        # Leave 10% headroom so the text annotations are not clipped.
        xlim = (0, max(values) * 1.1)
    ax.set_xlim(xlim)

    if ylim is not None:
        if not isinstance(ylim, tuple) or len(ylim) != 2:
            raise ValueError('ylim must be a tuple of 2 elements')
    else:
        ylim = (-1, len(values))
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax
# coding: utf-8
# pylint: skip-file
import unittest
import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
try:
from matplotlib.axes import Axes
MATPLOTLIB_INSTALLED = True
except ImportError:
MATPLOTLIB_INSTALLED = False
class TestBasic(unittest.TestCase):
    """Smoke tests for lgb.plot_importance across the three accepted input types."""

    @unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib not installed')
    def test_plot_importance(self):
        # Small binary-classification fixture shared by all three cases.
        X_train, _, y_train, _ = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1)
        train_data = lgb.Dataset(X_train, y_train)
        params = {
            "objective": "binary",
            "verbose": -1,
            "num_leaves": 3
        }

        # Case 1: raw Booster, all defaults.
        booster = lgb.train(params, train_data, num_boost_round=10)
        default_ax = lgb.plot_importance(booster)
        self.assertIsInstance(default_ax, Axes)
        self.assertEqual(default_ax.get_title(), 'Feature importance')
        self.assertEqual(default_ax.get_xlabel(), 'Feature importance')
        self.assertEqual(default_ax.get_ylabel(), 'Features')
        # breast_cancer has 30 features, so at most 30 bars.
        self.assertLessEqual(len(default_ax.patches), 30)

        # Case 2: sklearn wrapper with custom labels and a single bar color.
        clf = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
        clf.fit(X_train, y_train)
        labeled_ax = lgb.plot_importance(clf, color='r', title='t', xlabel='x', ylabel='y')
        self.assertIsInstance(labeled_ax, Axes)
        self.assertEqual(labeled_ax.get_title(), 't')
        self.assertEqual(labeled_ax.get_xlabel(), 'x')
        self.assertEqual(labeled_ax.get_ylabel(), 'y')
        self.assertLessEqual(len(labeled_ax.patches), 30)
        for patch in labeled_ax.patches:
            self.assertTupleEqual(patch.get_facecolor(), (1., 0, 0, 1.))  # red

        # Case 3: plain importance array, labels disabled, per-bar color cycle.
        bare_ax = lgb.plot_importance(booster.feature_importance(),
                                      color=['r', 'y', 'g', 'b'],
                                      title=None, xlabel=None, ylabel=None)
        self.assertIsInstance(bare_ax, Axes)
        self.assertEqual(bare_ax.get_title(), '')
        self.assertEqual(bare_ax.get_xlabel(), '')
        self.assertEqual(bare_ax.get_ylabel(), '')
        self.assertLessEqual(len(bare_ax.patches), 30)
        self.assertTupleEqual(bare_ax.patches[0].get_facecolor(), (1., 0, 0, 1.))  # r
        self.assertTupleEqual(bare_ax.patches[1].get_facecolor(), (.75, .75, 0, 1.))  # y
        self.assertTupleEqual(bare_ax.patches[2].get_facecolor(), (0, .5, 0, 1.))  # g
        self.assertTupleEqual(bare_ax.patches[3].get_facecolor(), (0, 0, 1., 1.))  # b
# Script entry point: announce the module then run every test in it.
# NOTE(review): no `if __name__ == "__main__":` guard, so importing this
# module also triggers unittest.main() — confirm that is intended.
print("----------------------------------------------------------------------")
print("running test_plotting.py")
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment