plotting.py 11.4 KB
Newer Older
1
2
3
4
5
# coding: utf-8
# pylint: disable = C0103
"""Plotting Library."""
from __future__ import absolute_import

wxchan's avatar
wxchan committed
6
import warnings
7
from copy import deepcopy
wxchan's avatar
wxchan committed
8
9
from io import BytesIO

10
11
import numpy as np

wxchan's avatar
wxchan committed
12
from .basic import Booster
13
14
15
from .sklearn import LGBMModel


16
def check_not_tuple_of_2_elements(obj, obj_name='obj'):
    """Raise TypeError unless ``obj`` is a tuple of exactly 2 elements.

    Parameters
    ----------
    obj : object
        Object to validate (used for ``xlim``, ``ylim`` and ``figsize``
        arguments in this module).
    obj_name : str
        Name of the argument, used in the error message.

    Raises
    ------
    TypeError
        If ``obj`` is not a tuple, or is a tuple whose length is not 2.
    """
    if not isinstance(obj, tuple) or len(obj) != 2:
        raise TypeError('%s must be a tuple of 2 elements.' % obj_name)


def plot_importance(booster, ax=None, height=0.2,
                    xlim=None, ylim=None, title='Feature importance',
                    xlabel='Feature importance', ylabel='Features',
                    importance_type='split', max_num_features=None,
                    ignore_zero=True, figsize=None, grid=True, **kwargs):
    """Plot model feature importances.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance.
    ax : matplotlib Axes
        Target axes instance. If None, new figure and axes will be created.
    height : float
        Bar height, passed to ax.barh().
    xlim : tuple of 2 elements
        Tuple passed to ax.set_xlim(). If None, defaults to
        (0, 1.1 * largest displayed importance).
    ylim : tuple of 2 elements
        Tuple passed to ax.set_ylim(). If None, defaults to
        (-1, number of displayed features).
    title : str
        Axes title. Pass None to disable.
    xlabel : str
        X axis title label. Pass None to disable.
    ylabel : str
        Y axis title label. Pass None to disable.
    importance_type : str
        How the importance is calculated: "split" or "gain".
        "split" is the number of times a feature is used in a model.
        "gain" is the total gain of splits which use the feature.
    max_num_features : int
        Max number of top features displayed on plot.
        If None or smaller than 1, all features will be displayed.
    ignore_zero : bool
        Ignore features whose importance is not positive.
    figsize : tuple of 2 elements
        Figure size. Only used when ``ax`` is None.
    grid : bool
        Whether to add a grid to the axes.
    **kwargs :
        Other keywords passed to ax.barh().

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` has the wrong type, or ``xlim``/``ylim``/``figsize``
        is not a tuple of 2 elements.
    ValueError
        If the booster reports no feature importances, or no feature is
        left to plot after ``ignore_zero`` filtering.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError('You must install matplotlib to plot importance.')

    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError('booster must be Booster or LGBMModel.')

    importance = booster.feature_importance(importance_type=importance_type)
    feature_name = booster.feature_name()

    if not len(importance):
        raise ValueError('Booster feature_importances are empty.')

    # Ascending order: bars are drawn at y = 0, 1, ..., so the most important
    # feature ends up last (top of the chart).
    tuples = sorted(zip(feature_name, importance), key=lambda x: x[1])
    if ignore_zero:
        tuples = [x for x in tuples if x[1] > 0]
    if not tuples:
        # Without this check ``zip(*tuples)`` below fails with an obscure
        # "not enough values to unpack" error.
        raise ValueError('No feature has positive importance; nothing to plot '
                         '(pass ignore_zero=False to plot them anyway).')
    if max_num_features is not None and max_num_features > 0:
        tuples = tuples[-max_num_features:]
    labels, values = zip(*tuples)

    if ax is None:
        if figsize is not None:
            check_not_tuple_of_2_elements(figsize, 'figsize')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align='center', height=height, **kwargs)

    # Annotate each bar with its value, slightly to the right of the bar end.
    for x, y in zip(values, ylocs):
        ax.text(x + 1, y, x, va='center')

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is not None:
        check_not_tuple_of_2_elements(xlim, 'xlim')
    else:
        xlim = (0, max(values) * 1.1)
    ax.set_xlim(xlim)

    if ylim is not None:
        check_not_tuple_of_2_elements(ylim, 'ylim')
    else:
        ylim = (-1, len(values))
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax


def plot_metric(booster, metric=None, dataset_names=None,
                ax=None, xlim=None, ylim=None,
                title='Metric during training',
                xlabel='Iterations', ylabel='auto',
                figsize=None, grid=True):
    """Plot one metric during training.

    Parameters
    ----------
    booster : dict or LGBMModel
        Evals_result recorded by lightgbm.train() or LGBMModel instance.
        As a dict it maps dataset name -> {metric name -> sequence of
        per-iteration values}.
    metric : str or None
        The metric name to plot.
        Only one metric supported because different metrics have various scales.
        Pass None to pick one automatically (the item returned by
        ``dict.popitem()`` on the first dataset's metrics); a warning is
        issued when more than one metric is available.
    dataset_names : None or list of str
        List (or tuple/set) of the dataset names to plot.
        Pass None to plot all datasets.
    ax : matplotlib Axes
        Target axes instance. If None, new figure and axes will be created.
    xlim : tuple of 2 elements
        Tuple passed to ax.set_xlim(). If None, defaults to
        (0, largest number of recorded iterations).
    ylim : tuple of 2 elements
        Tuple passed to ax.set_ylim(). If None, the plotted value range is
        padded by 20% on each side.
    title : str
        Axes title. Pass None to disable.
    xlabel : str
        X axis title label. Pass None to disable.
    ylabel : str
        Y axis title label. Pass None to disable. Pass 'auto' to use `metric`.
    figsize : tuple of 2 elements
        Figure size. Only used when ``ax`` is None.
    grid : bool
        Whether to add a grid to the axes.

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` has the wrong type, or ``xlim``/``ylim``/``figsize``
        is not a tuple of 2 elements.
    ValueError
        If the eval results are empty, or ``dataset_names`` is not a
        non-empty list/tuple/set.
    KeyError
        If ``metric`` is not recorded for the first plotted dataset.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError('You must install matplotlib to plot metric.')

    # Work on a copy: popitem() below mutates the nested dicts.
    if isinstance(booster, LGBMModel):
        eval_results = deepcopy(booster.evals_result_)
    elif isinstance(booster, dict):
        eval_results = deepcopy(booster)
    else:
        raise TypeError('booster must be dict or LGBMModel.')

    if not eval_results:
        raise ValueError('eval results cannot be empty.')

    if dataset_names is None:
        dataset_names = iter(eval_results.keys())
    elif not isinstance(dataset_names, (list, tuple, set)) or not dataset_names:
        raise ValueError('dataset_names should be iterable and cannot be empty')
    else:
        dataset_names = iter(dataset_names)

    name = next(dataset_names)  # take one as sample
    metrics_for_one = eval_results[name]
    num_metric = len(metrics_for_one)
    if metric is None:
        if num_metric > 1:
            msg = """more than one metric available, picking one to plot."""
            warnings.warn(msg, stacklevel=2)
        metric, results = metrics_for_one.popitem()
    else:
        if metric not in metrics_for_one:
            raise KeyError('No given metric in eval results.')
        results = metrics_for_one[metric]

    # Create the figure only after the inputs are validated, so a failed
    # call does not leave an empty figure behind.
    if ax is None:
        if figsize is not None:
            check_not_tuple_of_2_elements(figsize, 'figsize')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    num_iteration, max_result, min_result = len(results), max(results), min(results)
    ax.plot(range(num_iteration), results, label=name)

    for name in dataset_names:
        results = eval_results[name][metric]
        num_iteration = max(num_iteration, len(results))
        max_result, min_result = max(max(results), max_result), min(min(results), min_result)
        # Each curve gets its own x range, so datasets recorded for a
        # different number of iterations do not make ax.plot() fail.
        ax.plot(range(len(results)), results, label=name)

    ax.legend(loc='best')

    if xlim is not None:
        check_not_tuple_of_2_elements(xlim, 'xlim')
    else:
        xlim = (0, num_iteration)
    ax.set_xlim(xlim)

    if ylim is not None:
        check_not_tuple_of_2_elements(ylim, 'ylim')
    else:
        range_result = max_result - min_result
        ylim = (min_result - range_result * 0.2, max_result + range_result * 0.2)
    ax.set_ylim(ylim)

    if ylabel == 'auto':
        ylabel = metric

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax


def _to_graphviz(graph, tree_info, show_info, feature_names):
    """Convert specified tree to graphviz instance."""

    def add(root, parent=None, decision=None):
        """recursively add node or edge"""
        if 'split_index' in root:  # non-leaf
            name = 'split' + str(root['split_index'])
            if feature_names is not None:
                label = 'split_feature_name:' + str(feature_names[root['split_feature']])
            else:
                label = 'split_feature_index:' + str(root['split_feature'])
            label += '\nthreshold:' + str(root['threshold'])
            for info in show_info:
                if info in {'split_gain', 'internal_value', 'internal_count'}:
                    label += '\n' + info + ':' + str(root[info])
            graph.node(name, label=label)
            if root['decision_type'] == 'no_greater':
                l_dec, r_dec = '<=', '>'
            elif root['decision_type'] == 'is':
                l_dec, r_dec = 'is', "isn't"
            else:
                raise ValueError('Invalid decision type in tree model.')
            add(root['left_child'], name, l_dec)
            add(root['right_child'], name, r_dec)
        else:  # leaf
            name = 'left' + str(root['leaf_index'])
            label = 'leaf_value:' + str(root['leaf_value'])
            if 'leaf_count' in show_info:
                label += '\nleaf_count:' + str(root['leaf_count'])
            graph.node(name, label=label)
        if parent is not None:
            graph.edge(parent, name, decision)

    add(tree_info['tree_structure'])
    return graph


def plot_tree(booster, ax=None, tree_index=0, figsize=None,
              graph_attr=None, node_attr=None, edge_attr=None,
              show_info=None):
    """Plot specified tree.

    Parameters
    ----------
    booster : Booster, LGBMModel
        Booster or LGBMModel instance.
    ax : matplotlib Axes
        Target axes instance. If None, new figure and axes will be created.
    tree_index : int, default 0
        Index of the target tree in ``booster.dump_model()['tree_info']``.
    figsize : tuple of 2 elements
        Figure size. Only used when ``ax`` is None.
    graph_attr : dict
        Mapping of (attribute, value) pairs for the graph.
    node_attr : dict
        Mapping of (attribute, value) pairs set for all nodes.
    edge_attr : dict
        Mapping of (attribute, value) pairs set for all edges.
    show_info : list
        Information shows on nodes.
        options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib or graphviz is not installed.
    TypeError
        If ``booster`` has the wrong type, or ``figsize`` is not a tuple of
        2 elements.
    IndexError
        If ``tree_index`` is out of range.
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.image as image
    except ImportError:
        raise ImportError('You must install matplotlib to plot tree.')

    try:
        from graphviz import Digraph
    except ImportError:
        raise ImportError('You must install graphviz to plot tree.')

    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError('booster must be Booster or LGBMModel.')

    model = booster.dump_model()
    tree_infos = model['tree_info']
    if 'feature_names' in model:
        feature_names = model['feature_names']
    else:
        feature_names = None

    if tree_index < len(tree_infos):
        tree_info = tree_infos[tree_index]
    else:
        raise IndexError('tree_index is out of range.')

    graph = Digraph(graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr)

    if show_info is None:
        show_info = []
    ret = _to_graphviz(graph, tree_info, show_info, feature_names)

    # Render to PNG in memory and decode it as an image array.
    img = image.imread(BytesIO(ret.pipe(format='png')))

    # Create the figure only after validation and rendering have succeeded,
    # so a failed call does not leave an empty figure behind.
    if ax is None:
        if figsize is not None:
            check_not_tuple_of_2_elements(figsize, 'figsize')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.imshow(img)
    ax.axis('off')
    return ax