plotting.py 14 KB
Newer Older
1
2
3
4
5
# coding: utf-8
# pylint: disable = C0103
"""Plotting Library."""
from __future__ import absolute_import

wxchan's avatar
wxchan committed
6
import warnings
7
from copy import deepcopy
wxchan's avatar
wxchan committed
8
9
from io import BytesIO

10
11
import numpy as np

wxchan's avatar
wxchan committed
12
from .basic import Booster
13
14
15
from .sklearn import LGBMModel


16
def check_not_tuple_of_2_elements(obj, obj_name='obj'):
    """Validate that an argument is a tuple of exactly two elements.

    Parameters
    ----------
    obj : object
        Value to validate.
    obj_name : str
        Name used to identify ``obj`` in the error message.

    Raises
    ------
    TypeError
        If ``obj`` is not a tuple, or is a tuple whose length is not 2.
    """
    is_pair = isinstance(obj, tuple) and len(obj) == 2
    if is_pair:
        return
    raise TypeError('%s must be a tuple of 2 elements.' % obj_name)


def plot_importance(booster, ax=None, height=0.2,
                    xlim=None, ylim=None, title='Feature importance',
                    xlabel='Feature importance', ylabel='Features',
                    importance_type='split', max_num_features=None,
                    ignore_zero=True, figsize=None, grid=True, **kwargs):
    """Plot model feature importances as a horizontal bar chart.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance.
    ax : matplotlib Axes
        Target axes instance. If None, new figure and axes will be created.
    height : float
        Bar height, passed to ax.barh().
    xlim : tuple of 2 elements
        Tuple passed to ax.set_xlim().
        If None, defaults to (0, 1.1 * largest plotted importance).
    ylim : tuple of 2 elements
        Tuple passed to ax.set_ylim().
        If None, defaults to (-1, number of plotted features).
    title : str
        Axes title. Pass None to disable.
    xlabel : str
        X axis title label. Pass None to disable.
    ylabel : str
        Y axis title label. Pass None to disable.
    importance_type : str
        How the importance is calculated: "split" or "gain".
        "split" is the number of times a feature is used in a model.
        "gain" is the total gain of splits which use the feature.
    max_num_features : int
        Max number of top features displayed on plot.
        If None or smaller than 1, all features will be displayed.
    ignore_zero : bool
        Ignore features with zero importance.
    figsize : tuple of 2 elements
        Figure size. Only used when ``ax`` is None.
    grid : bool
        Whether to add a grid to the axes.
    **kwargs :
        Other keywords passed to ax.barh().

    Returns
    -------
    ax : matplotlib Axes
        The axes the importances were drawn on.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` is neither a Booster nor an LGBMModel, or if
        ``figsize``/``xlim``/``ylim`` is given but is not a tuple of 2 elements.
    ValueError
        If the booster reports no importances, or if no feature is left to
        plot after filtering.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError('You must install matplotlib to plot importance.')

    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError('booster must be Booster or LGBMModel.')

    importance = booster.feature_importance(importance_type=importance_type)
    feature_name = booster.feature_name()

    if not len(importance):
        raise ValueError('Booster feature_importances are empty.')

    # Ascending order: barh draws the first bar at the bottom, so the most
    # important feature ends up at the top of the chart.
    tuples = sorted(zip(feature_name, importance), key=lambda x: x[1])
    if ignore_zero:
        tuples = [x for x in tuples if x[1] > 0]
    if max_num_features is not None and max_num_features > 0:
        tuples = tuples[-max_num_features:]
    if not tuples:
        # Without this check ``zip(*tuples)`` below fails with an opaque
        # "not enough values to unpack" error.
        raise ValueError('No feature importances to plot (with ignore_zero=True, '
                         'features with zero importance are skipped).')
    labels, values = zip(*tuples)

    if ax is None:
        if figsize is not None:
            check_not_tuple_of_2_elements(figsize, 'figsize')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align='center', height=height, **kwargs)

    # Annotate each bar with its value, just to the right of the bar's end.
    for x, y in zip(values, ylocs):
        ax.text(x + 1, y, x, va='center')

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is not None:
        check_not_tuple_of_2_elements(xlim, 'xlim')
    else:
        max_value = max(values)
        # When every plotted importance is 0 (possible with ignore_zero=False),
        # (0, 0) would be a degenerate range; fall back to a unit-wide axis.
        xlim = (0, max_value * 1.1 if max_value > 0 else 1)
    ax.set_xlim(xlim)

    if ylim is not None:
        check_not_tuple_of_2_elements(ylim, 'ylim')
    else:
        ylim = (-1, len(values))
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax


def plot_metric(booster, metric=None, dataset_names=None,
                ax=None, xlim=None, ylim=None,
                title='Metric during training',
                xlabel='Iterations', ylabel='auto',
                figsize=None, grid=True):
    """Plot one metric during training.

    Parameters
    ----------
    booster : dict or LGBMModel
        Evals_result recorded by lightgbm.train() or LGBMModel instance.
    metric : str or None
        The metric name to plot.
        Only one metric supported because different metrics have various scales.
        Pass None to pick one of the available metrics arbitrarily
        (a warning is issued when more than one is available).
    dataset_names : None or list, tuple or set of str
        Names of the datasets to plot.
        Pass None to plot all datasets.
    ax : matplotlib Axes
        Target axes instance. If None, new figure and axes will be created.
    xlim : tuple of 2 elements
        Tuple passed to ax.set_xlim(). If None, defaults to (0, number of iterations).
    ylim : tuple of 2 elements
        Tuple passed to ax.set_ylim(). If None, the observed metric range is
        padded by 20% on each side.
    title : str
        Axes title. Pass None to disable.
    xlabel : str
        X axis title label. Pass None to disable.
    ylabel : str
        Y axis title label. Pass None to disable. Pass 'auto' to use `metric`.
    figsize : tuple of 2 elements
        Figure size. Only used when ``ax`` is None.
    grid : bool
        Whether to add a grid to the axes.

    Returns
    -------
    ax : matplotlib Axes
        The axes the metric curves were drawn on.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` is neither a dict nor an LGBMModel, or if
        ``figsize``/``xlim``/``ylim`` is given but is not a tuple of 2 elements.
    ValueError
        If the eval results are empty or ``dataset_names`` is invalid.
    KeyError
        If ``metric`` is not recorded for the first plotted dataset.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError('You must install matplotlib to plot metric.')

    # Deep copy because popitem() below mutates the per-dataset metric dict,
    # and the caller's results must stay untouched.
    if isinstance(booster, LGBMModel):
        eval_results = deepcopy(booster.evals_result_)
    elif isinstance(booster, dict):
        eval_results = deepcopy(booster)
    else:
        raise TypeError('booster must be dict or LGBMModel.')

    num_data = len(eval_results)

    if not num_data:
        raise ValueError('eval results cannot be empty.')

    if ax is None:
        if figsize is not None:
            check_not_tuple_of_2_elements(figsize, 'figsize')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    if dataset_names is None:
        dataset_names = iter(eval_results.keys())
    elif not isinstance(dataset_names, (list, tuple, set)) or not dataset_names:
        raise ValueError('dataset_names should be iterable and cannot be empty')
    else:
        dataset_names = iter(dataset_names)

    name = next(dataset_names)  # take one as sample
    metrics_for_one = eval_results[name]
    num_metric = len(metrics_for_one)
    if metric is None:
        if num_metric > 1:
            msg = """more than one metric available, picking one to plot."""
            warnings.warn(msg, stacklevel=2)
        metric, results = metrics_for_one.popitem()
    else:
        if metric not in metrics_for_one:
            raise KeyError('No given metric in eval results.')
        results = metrics_for_one[metric]
    num_iteration, max_result, min_result = len(results), max(results), min(results)
    x_ = range(num_iteration)
    ax.plot(x_, results, label=name)

    # Remaining datasets: plot the same metric and widen the observed range.
    for name in dataset_names:
        metrics_for_one = eval_results[name]
        results = metrics_for_one[metric]
        max_result, min_result = max(max(results), max_result), min(min(results), min_result)
        ax.plot(x_, results, label=name)

    ax.legend(loc='best')

    if xlim is not None:
        check_not_tuple_of_2_elements(xlim, 'xlim')
    else:
        xlim = (0, num_iteration)
    ax.set_xlim(xlim)

    if ylim is not None:
        check_not_tuple_of_2_elements(ylim, 'ylim')
    else:
        range_result = max_result - min_result
        if range_result == 0:
            # A constant metric would give a zero-height y range; pad
            # relative to the value itself (or by 1 when the value is 0).
            range_result = abs(max_result) or 1.0
        ylim = (min_result - range_result * 0.2, max_result + range_result * 0.2)
    ax.set_ylim(ylim)

    if ylabel == 'auto':
        ylabel = metric

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax


def _to_graphviz(tree_info, show_info, feature_names,
                 name=None, comment=None, filename=None, directory=None,
                 format=None, engine=None, encoding=None, graph_attr=None,
                 node_attr=None, edge_attr=None, body=None, strict=False):
    """Convert specified tree to graphviz instance.

    See:
      - http://graphviz.readthedocs.io/en/stable/api.html#digraph

    Parameters
    ----------
    tree_info : dict
        One tree of the dumped model; its ``'tree_structure'`` entry holds the
        nested node dicts that are walked recursively.
    show_info : list of str
        Extra fields to display: 'split_gain', 'internal_value' and
        'internal_count' on split nodes, 'leaf_count' on leaves.
    feature_names : list of str or None
        Names indexed by a node's ``split_feature``. If None, the feature
        index is displayed instead.
    name, comment, filename, directory, format, engine, encoding,
    graph_attr, node_attr, edge_attr, body, strict :
        Passed unchanged to ``graphviz.Digraph``.

    Returns
    -------
    graph : graphviz Digraph

    Raises
    ------
    ImportError
        If graphviz is not installed.
    ValueError
        If a split node has an unsupported ``decision_type``.
    """
    try:
        from graphviz import Digraph
    except ImportError:
        raise ImportError('You must install graphviz to plot tree.')

    # Split-node fields requested by the caller, in the caller's order.
    split_info = [info for info in show_info
                  if info in ('split_gain', 'internal_value', 'internal_count')]
    show_leaf_count = 'leaf_count' in show_info

    def add(root, parent=None, decision=None):
        """Recursively add node ``root`` and its subtree to ``graph``,
        linking it from ``parent`` with an edge labelled ``decision``."""
        if 'split_index' in root:  # non-leaf
            node_name = 'split' + str(root['split_index'])
            if feature_names is not None:
                label = 'split_feature_name:' + str(feature_names[root['split_feature']])
            else:
                label = 'split_feature_index:' + str(root['split_feature'])
            label += '\nthreshold:' + str(root['threshold'])
            for info in split_info:
                label += '\n' + info + ':' + str(root[info])
            graph.node(node_name, label=label)
            if root['decision_type'] == 'no_greater':
                l_dec, r_dec = '<=', '>'
            elif root['decision_type'] == 'is':
                l_dec, r_dec = 'is', "isn't"
            else:
                raise ValueError('Invalid decision type in tree model.')
            add(root['left_child'], node_name, l_dec)
            add(root['right_child'], node_name, r_dec)
        else:  # leaf
            # 'leaf' prefix keeps leaf node ids distinct from 'split' ids
            # (previously misspelled as 'left').
            node_name = 'leaf' + str(root['leaf_index'])
            label = 'leaf_value:' + str(root['leaf_value'])
            if show_leaf_count:
                label += '\nleaf_count:' + str(root['leaf_count'])
            graph.node(node_name, label=label)
        if parent is not None:
            graph.edge(parent, node_name, decision)

    # ``add`` closes over ``graph``; it must exist before the first call.
    graph = Digraph(name=name, comment=comment, filename=filename, directory=directory,
                    format=format, engine=engine, encoding=encoding, graph_attr=graph_attr,
                    node_attr=node_attr, edge_attr=edge_attr, body=body, strict=strict)
    add(tree_info['tree_structure'])

    return graph


def create_tree_digraph(booster, tree_index=0, show_info=None,
                        name=None, comment=None, filename=None, directory=None,
                        format=None, engine=None, encoding=None, graph_attr=None,
                        node_attr=None, edge_attr=None, body=None, strict=False):
    """Create a digraph of specified tree.

    See:
      - http://graphviz.readthedocs.io/en/stable/api.html#digraph

    Parameters
    ----------
    booster : Booster, LGBMModel
        Booster or LGBMModel instance.
    tree_index : int, default 0
        Specify tree index of target tree.
    show_info : list or None
        Information shown on nodes.
        Options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
        None shows no extra information.
    name : str
        Graph name used in the source code.
    comment : str
        Comment added to the first line of the source.
    filename : str
        Filename for saving the source (defaults to name + '.gv').
    directory : str
        (Sub)directory for source saving and rendering.
    format : str
        Rendering output format ('pdf', 'png', ...).
    engine : str
        Layout command used ('dot', 'neato', ...).
    encoding : str
        Encoding for saving the source.
    graph_attr : dict
        Mapping of (attribute, value) pairs for the graph.
    node_attr : dict
        Mapping of (attribute, value) pairs set for all nodes.
    edge_attr : dict
        Mapping of (attribute, value) pairs set for all edges.
    body : list of str
        Iterable of lines to add to the graph body.
    strict : bool
        Rendering should merge multi-edges.

    Returns
    -------
    graph : graphviz Digraph

    Raises
    ------
    TypeError
        If ``booster`` is neither a Booster nor an LGBMModel.
    IndexError
        If ``tree_index`` is not smaller than the number of trees.
    """
    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError('booster must be Booster or LGBMModel.')

    model = booster.dump_model()
    tree_infos = model['tree_info']
    # None when the dump carries no feature names; indices are shown instead.
    feature_names = model.get('feature_names')

    if tree_index >= len(tree_infos):
        raise IndexError('tree_index is out of range.')
    tree_info = tree_infos[tree_index]

    if show_info is None:
        show_info = []

    return _to_graphviz(tree_info, show_info, feature_names,
                        name=name, comment=comment, filename=filename, directory=directory,
                        format=format, engine=engine, encoding=encoding, graph_attr=graph_attr,
                        node_attr=node_attr, edge_attr=edge_attr, body=body, strict=strict)


def plot_tree(booster, ax=None, tree_index=0, figsize=None,
              graph_attr=None, node_attr=None, edge_attr=None,
              show_info=None):
    """Draw one tree of the model onto matplotlib axes.

    The tree is converted to a graphviz digraph with ``create_tree_digraph``
    (so graphviz must be installed), rendered to PNG in memory, decoded with
    ``matplotlib.image.imread`` and shown with ``ax.imshow``. The axis
    decorations are switched off.

    Parameters
    ----------
    booster : Booster, LGBMModel
        Booster or LGBMModel instance.
    ax : matplotlib Axes
        Target axes instance. If None, new figure and axes will be created.
    tree_index : int, default 0
        Index of the tree to draw.
    figsize : tuple of 2 elements
        Figure size. Only used when ``ax`` is None.
    graph_attr : dict
        Mapping of (attribute, value) pairs for the graph.
    node_attr : dict
        Mapping of (attribute, value) pairs set for all nodes.
    edge_attr : dict
        Mapping of (attribute, value) pairs set for all edges.
    show_info : list
        Information shown on nodes.
        Options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.

    Returns
    -------
    ax : matplotlib Axes
        The axes holding the rendered tree image.
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.image as image
    except ImportError:
        raise ImportError('You must install matplotlib to plot tree.')

    if ax is None:
        if figsize is not None:
            check_not_tuple_of_2_elements(figsize, 'figsize')
        ax = plt.subplots(1, 1, figsize=figsize)[1]

    digraph = create_tree_digraph(booster=booster, tree_index=tree_index,
                                  graph_attr=graph_attr, node_attr=node_attr,
                                  edge_attr=edge_attr, show_info=show_info)

    # Render straight into an in-memory buffer positioned at offset 0.
    png_buffer = BytesIO(digraph.pipe(format='png'))
    tree_image = image.imread(png_buffer)

    ax.imshow(tree_image)
    ax.axis('off')
    return ax