"python-package/vscode:/vscode.git/clone" did not exist on "eb13f39a0dac5e9bcb82e5fadecfe17f82b8a397"
plotting.py 7.86 KB
Newer Older
1
2
3
4
5
# coding: utf-8
# pylint: disable = C0103
"""Plotting Library."""
from __future__ import absolute_import

wxchan's avatar
wxchan committed
6
7
from io import BytesIO

8
9
10
11
12
13
import numpy as np

from .basic import Booster, is_numpy_1d_array
from .sklearn import LGBMModel


wxchan's avatar
wxchan committed
14
15
16
17
18
def check_not_tuple_of_2_elements(obj):
    """Return True unless ``obj`` is a tuple holding exactly 2 elements."""
    # ``len`` is only consulted once ``obj`` is known to be a tuple.
    is_pair = isinstance(obj, tuple) and len(obj) == 2
    return not is_pair


19
20
21
22
def plot_importance(booster, ax=None, height=0.2,
                    xlim=None, ylim=None, title='Feature importance',
                    xlabel='Feature importance', ylabel='Features',
                    importance_type='split', max_num_features=None,
                    ignore_zero=True, figsize=None, grid=True, **kwargs):
    """Plot model feature importances.

    Parameters
    ----------
    booster : Booster, LGBMModel or array
        Booster or LGBMModel instance, or array (1d numpy array or list)
        of feature importances
    ax : matplotlib Axes
        Target axes instance. If None, new figure and axes will be created.
    height : float
        Bar height, passed to ax.barh()
    xlim : tuple of 2 elements
        Tuple passed to ax.set_xlim().
        If None, (0, 1.1 * largest plotted importance) is used.
    ylim : tuple of 2 elements
        Tuple passed to ax.set_ylim().
        If None, (-1, number of plotted features) is used.
    title : str
        Axes title. Pass None to disable.
    xlabel : str
        X axis title label. Pass None to disable.
    ylabel : str
        Y axis title label. Pass None to disable.
    importance_type : str
        How the importance is calculated: "split" or "gain"
        "split" is the number of times a feature is used in a model
        "gain" is the total gain of splits which use the feature
        Not used when ``booster`` is already an array of importances.
    max_num_features : int
        Max number of top features displayed on plot.
        If None or smaller than 1, all features will be displayed.
    ignore_zero : bool
        Ignore features with zero importance
    figsize : tuple of 2 elements
        Figure size. Only used (and validated) when ``ax`` is None.
    grid : bool
        Whether add grid for axes
    **kwargs :
        Other keywords passed to ax.barh()

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` has an unsupported type, or if ``figsize``, ``xlim``
        or ``ylim`` is given but is not a tuple of 2 elements.
    ValueError
        If the importances are empty, or if ``ignore_zero`` filtered out
        every feature so that nothing is left to plot.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError('You must install matplotlib to plot importance.')

    if isinstance(booster, LGBMModel):
        importance = booster.booster_.feature_importance(importance_type=importance_type)
    elif isinstance(booster, Booster):
        importance = booster.feature_importance(importance_type=importance_type)
    elif is_numpy_1d_array(booster) or isinstance(booster, list):
        importance = booster
    else:
        raise TypeError('booster must be Booster, LGBMModel or array instance.')

    # ``len`` rather than truthiness: importance may be a numpy array.
    if not len(importance):
        raise ValueError('Booster feature_importances are empty.')

    # (feature index, importance) pairs in ascending order of importance, so
    # the most important feature ends up as the top bar of the chart.
    tuples = sorted(enumerate(importance), key=lambda x: x[1])
    if ignore_zero:
        tuples = [x for x in tuples if x[1] > 0]
    if not tuples:
        # Without this guard ``zip(*tuples)`` below fails with an opaque
        # "not enough values to unpack" error.
        raise ValueError('No feature with positive importance to plot; '
                         'set ignore_zero=False to plot them anyway.')
    if max_num_features is not None and max_num_features > 0:
        tuples = tuples[-max_num_features:]
    labels, values = zip(*tuples)

    # Validate every user-supplied tuple before touching matplotlib, so that
    # bad arguments never leave a stray figure or half-drawn axes behind.
    if ax is None and figsize is not None and check_not_tuple_of_2_elements(figsize):
        raise TypeError('figsize must be a tuple of 2 elements.')
    if xlim is not None and check_not_tuple_of_2_elements(xlim):
        raise TypeError('xlim must be a tuple of 2 elements.')
    if ylim is not None and check_not_tuple_of_2_elements(ylim):
        raise TypeError('ylim must be a tuple of 2 elements.')

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align='center', height=height, **kwargs)

    # Annotate each bar with its importance value, just right of the bar end.
    for x, y in zip(values, ylocs):
        ax.text(x + 1, y, x, va='center')

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is None:
        # Leave 10% headroom to the right of the longest bar.
        xlim = (0, max(values) * 1.1)
    ax.set_xlim(xlim)

    if ylim is None:
        # One unit of padding below the first and above the last bar.
        ylim = (-1, len(values))
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax
wxchan's avatar
wxchan committed
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238


def _to_graphviz(graph, tree_info, show_info, feature_names):
    """Fill ``graph`` with the nodes and edges of one dumped tree.

    The tree is read from ``tree_info['tree_structure']``. ``show_info`` lists
    the optional fields to append to node labels, and ``feature_names``
    (may be None) maps split feature indices to names. Returns ``graph``.
    """
    internal_fields = {'split_gain', 'internal_value', 'internal_count'}

    def add(root, parent=None, decision=None):
        """Add ``root`` and its subtree, then the edge from ``parent``."""
        if 'split_index' not in root:
            # Leaf node: only its value and, on request, its count are shown.
            name = 'left' + str(root['leaf_index'])
            lines = ['leaf_value:' + str(root['leaf_value'])]
            if 'leaf_count' in show_info:
                lines.append('leaf_count:' + str(root['leaf_count']))
            graph.node(name, label='\n'.join(lines))
        else:
            # Internal (split) node.
            name = 'split' + str(root['split_index'])
            if feature_names is None:
                head = 'split_feature_index:' + str(root['split_feature'])
            else:
                head = 'split_feature_name:' + str(feature_names[root['split_feature']])
            lines = [head, 'threshold:' + str(root['threshold'])]
            lines.extend(info + ':' + str(root[info])
                         for info in show_info if info in internal_fields)
            graph.node(name, label='\n'.join(lines))

            # Edge captions for the two outgoing branches.
            decision_type = root['decision_type']
            if decision_type == 'no_greater':
                branch_labels = ('<=', '>')
            elif decision_type == 'is':
                branch_labels = ('is', "isn't")
            else:
                raise ValueError('Invalid decision type in tree model.')
            for child_key, branch_label in zip(('left_child', 'right_child'), branch_labels):
                add(root[child_key], name, branch_label)

        # The edge to this node is emitted only after its whole subtree.
        if parent is not None:
            graph.edge(parent, name, decision)

    add(tree_info['tree_structure'])
    return graph


def plot_tree(booster, ax=None, tree_index=0, figsize=None,
              graph_attr=None, node_attr=None, edge_attr=None,
              show_info=None):
    """Plot specified tree.

    The tree is rendered to a PNG image with graphviz and that image is then
    shown on the target axes.

    Parameters
    ----------
    booster : Booster, LGBMModel
        Booster or LGBMModel instance.
    ax : matplotlib Axes
        Target axes instance. If None, new figure and axes will be created.
    tree_index : int, default 0
        Specify tree index of target tree.
    figsize : tuple of 2 elements
        Figure size. Only used (and validated) when ``ax`` is None.
    graph_attr : dict
        Mapping of (attribute, value) pairs for the graph.
    node_attr : dict
        Mapping of (attribute, value) pairs set for all nodes.
    edge_attr : dict
        Mapping of (attribute, value) pairs set for all edges.
    show_info : list
        Information shows on nodes.
        options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib or graphviz is not installed.
    TypeError
        If ``booster`` is neither Booster nor LGBMModel, or if ``figsize`` is
        given (with ``ax`` None) but is not a tuple of 2 elements.
    IndexError
        If ``tree_index`` is out of range.
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.image as image
    except ImportError:
        raise ImportError('You must install matplotlib to plot tree.')

    try:
        from graphviz import Digraph
    except ImportError:
        raise ImportError('You must install graphviz to plot tree.')

    # Validate arguments up front; the figure itself is only created once the
    # tree image exists, so a failure never leaves an empty figure behind.
    if ax is None and figsize is not None and check_not_tuple_of_2_elements(figsize):
        raise TypeError('figsize must be a tuple of 2 elements.')

    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError('booster must be Booster or LGBMModel.')

    model = booster.dump_model()
    tree_infos = model['tree_info']
    feature_names = model['feature_names'] if 'feature_names' in model else None

    if tree_index >= len(tree_infos):
        raise IndexError('tree_index is out of range.')
    tree_info = tree_infos[tree_index]

    graph = Digraph(graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr)

    if show_info is None:
        show_info = []
    ret = _to_graphviz(graph, tree_info, show_info, feature_names)

    # Render to PNG bytes in memory and decode them into an image array.
    s = BytesIO(ret.pipe(format='png'))
    img = image.imread(s)

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.imshow(img)
    ax.axis('off')
    return ax