plotting.py 31.8 KB
Newer Older
1
# coding: utf-8
2
"""Plotting library."""
3

4
import math
5
from copy import deepcopy
wxchan's avatar
wxchan committed
6
from io import BytesIO
7
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
wxchan's avatar
wxchan committed
8

9
10
import numpy as np

11
from .basic import Booster, _data_from_pandas, _is_zero, _log_warning, _MissingType
12
from .compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED, pd_DataFrame
13
14
from .sklearn import LGBMModel

15
# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "create_tree_digraph",
    "plot_importance",
    "plot_metric",
    "plot_split_value_histogram",
    "plot_tree",
]

23
24
25
if TYPE_CHECKING:
    import matplotlib

26

27
def _check_not_tuple_of_2_elements(obj: Any, obj_name: str) -> None:
28
    """Check object is not tuple or does not have 2 elements."""
29
    if not isinstance(obj, tuple) or len(obj) != 2:
30
        raise TypeError(f"{obj_name} must be a tuple of 2 elements.")
wxchan's avatar
wxchan committed
31
32


33
def _float2str(value: float, precision: Optional[int]) -> str:
34
    return f"{value:.{precision}f}" if precision is not None and not isinstance(value, str) else str(value)
35
36


37
38
def plot_importance(
    booster: Union[Booster, LGBMModel],
    ax: "Optional[matplotlib.axes.Axes]" = None,
    height: float = 0.2,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    title: Optional[str] = "Feature importance",
    xlabel: Optional[str] = "Feature importance",
    ylabel: Optional[str] = "Features",
    importance_type: str = "auto",
    max_num_features: Optional[int] = None,
    ignore_zero: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    grid: bool = True,
    precision: Optional[int] = 3,
    **kwargs: Any,
) -> Any:
    """Plot model's feature importances.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance which feature importance should be plotted.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    height : float, optional (default=0.2)
        Bar height, passed to ``ax.barh()``.
    xlim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_xlim()``.
        If None, ``(0, 1.1 * largest plotted importance)`` is used.
    ylim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_ylim()``.
        If None, ``(-1, number of plotted features)`` is used.
    title : str or None, optional (default="Feature importance")
        Axes title.
        If None, title is disabled.
    xlabel : str or None, optional (default="Feature importance")
        X-axis title label.
        If None, label is disabled.
        @importance_type@ placeholder can be used, and it will be replaced with the value of ``importance_type`` parameter.
    ylabel : str or None, optional (default="Features")
        Y-axis title label.
        If None, label is disabled.
    importance_type : str, optional (default="auto")
        How the importance is calculated.
        If "auto", if ``booster`` parameter is LGBMModel, ``booster.importance_type`` attribute is used; "split" otherwise.
        If "split", result contains numbers of times the feature is used in a model.
        If "gain", result contains total gains of splits which use the feature.
    max_num_features : int or None, optional (default=None)
        Max number of top features displayed on plot.
        If None or <1, all features will be displayed.
    ignore_zero : bool, optional (default=True)
        Whether to ignore features with zero importance.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
    dpi : int or None, optional (default=None)
        Resolution of the figure.
    grid : bool, optional (default=True)
        Whether to add a grid for axes.
    precision : int or None, optional (default=3)
        Used to restrict the display of floating point values to a certain precision.
        Only applied to the bar annotations when ``importance_type`` is "gain".
    **kwargs
        Other parameters passed to ``ax.barh()``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with model's feature importances.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` is neither a Booster nor an LGBMModel,
        or if ``figsize``, ``xlim`` or ``ylim`` is not a tuple of 2 elements.
    ValueError
        If the booster reports no feature importances, or if ``ignore_zero`` is True
        and no feature has a positive importance.
    """
    if MATPLOTLIB_INSTALLED:
        import matplotlib.pyplot as plt
    else:
        raise ImportError("You must install matplotlib and restart your session to plot importance.")

    if isinstance(booster, LGBMModel):
        if importance_type == "auto":
            importance_type = booster.importance_type
        booster = booster.booster_
    elif isinstance(booster, Booster):
        if importance_type == "auto":
            importance_type = "split"
    else:
        raise TypeError("booster must be Booster or LGBMModel.")

    importance = booster.feature_importance(importance_type=importance_type)
    feature_name = booster.feature_name()

    if not len(importance):
        raise ValueError("Booster's feature_importance is empty.")

    # Ascending order, so the most important features end up at the top of the barh plot.
    tuples = sorted(zip(feature_name, importance), key=lambda x: x[1])
    if ignore_zero:
        tuples = [x for x in tuples if x[1] > 0]
    if not tuples:
        # Happens e.g. for a model without any split; without this check the
        # ``zip(*tuples)`` unpacking below fails with an unhelpful error.
        raise ValueError(
            "No feature has positive importance to plot. Set ignore_zero=False to plot zero importances."
        )
    if max_num_features is not None and max_num_features > 0:
        tuples = tuples[-max_num_features:]
    labels, values = zip(*tuples)

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, "figsize")
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align="center", height=height, **kwargs)

    for x, y in zip(values, ylocs):
        # Gain importances are floats and get rounded; split counts are shown as-is.
        ax.text(x + 1, y, _float2str(x, precision) if importance_type == "gain" else x, va="center")

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is not None:
        _check_not_tuple_of_2_elements(xlim, "xlim")
    else:
        xlim = (0, max(values) * 1.1)
    ax.set_xlim(xlim)

    if ylim is not None:
        _check_not_tuple_of_2_elements(ylim, "ylim")
    else:
        ylim = (-1, len(values))
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        xlabel = xlabel.replace("@importance_type@", importance_type)
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax
wxchan's avatar
wxchan committed
169
170


171
172
173
174
def plot_split_value_histogram(
    booster: Union[Booster, LGBMModel],
    feature: Union[int, str],
    bins: Union[int, str, None] = None,
    ax: "Optional[matplotlib.axes.Axes]" = None,
    width_coef: float = 0.8,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    title: Optional[str] = "Split value histogram for feature with @index/name@ @feature@",
    xlabel: Optional[str] = "Feature split value",
    ylabel: Optional[str] = "Count",
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    grid: bool = True,
    **kwargs: Any,
) -> Any:
    """Plot a histogram of the split thresholds the model uses for one feature.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance of which feature split value histogram should be plotted.
    feature : int or str
        The feature the histogram is plotted for.
        An int is interpreted as a feature index, a str as a feature name.
    bins : int, str or None, optional (default=None)
        The maximum number of bins.
        If None, the number of bins equals number of unique split values.
        If str, it should be one from the list of the supported values by ``numpy.histogram()`` function.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    width_coef : float, optional (default=0.8)
        Bar width as a fraction of the bin width.
    xlim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_xlim()``.
        If None, the bin range widened by 20% on each side is used.
    ylim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_ylim()``.
        If None, ``(0, 1.1 * highest count)`` is used.
    title : str or None, optional (default="Split value histogram for feature with @index/name@ @feature@")
        Axes title.
        If None, title is disabled.
        The @feature@ placeholder is replaced with the value of ``feature``.
        The @index/name@ placeholder is replaced with the word ``name`` when ``feature``
        is a ``str`` and with the word ``index`` otherwise.
    xlabel : str or None, optional (default="Feature split value")
        X-axis title label.
        If None, label is disabled.
    ylabel : str or None, optional (default="Count")
        Y-axis title label.
        If None, label is disabled.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
    dpi : int or None, optional (default=None)
        Resolution of the figure.
    grid : bool, optional (default=True)
        Whether to add a grid for axes.
    **kwargs
        Other parameters passed to ``ax.bar()``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with specified model's feature split value histogram.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` has the wrong type, or ``figsize``/``xlim``/``ylim`` is not a tuple of 2 elements.
    ValueError
        If ``feature`` is never used for splitting.
    """
    if not MATPLOTLIB_INSTALLED:
        raise ImportError("You must install matplotlib and restart your session to plot split value histogram.")
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError("booster must be Booster or LGBMModel.")

    counts, edges = booster.get_split_value_histogram(feature=feature, bins=bins, xgboost_style=False)
    if np.count_nonzero(counts) == 0:
        raise ValueError(f"Cannot plot split value histogram, because feature {feature} was not used in splitting")

    bar_width = width_coef * (edges[1] - edges[0])
    bin_centers = (edges[:-1] + edges[1:]) / 2

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, "figsize")
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    ax.bar(bin_centers, counts, align="center", width=bar_width, **kwargs)

    if xlim is None:
        span = edges[-1] - edges[0]
        margin = span * 0.2
        xlim = (edges[0] - margin, edges[-1] + margin)
    else:
        _check_not_tuple_of_2_elements(xlim, "xlim")
    ax.set_xlim(xlim)

    # Counts are integers, so keep the y ticks on whole numbers.
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    if ylim is None:
        ylim = (0, max(counts) * 1.1)
    else:
        _check_not_tuple_of_2_elements(ylim, "ylim")
    ax.set_ylim(ylim)

    if title is not None:
        kind_word = "name" if isinstance(feature, str) else "index"
        title = title.replace("@feature@", str(feature)).replace("@index/name@", kind_word)
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax


287
288
289
290
def plot_metric(
    booster: Union[Dict, LGBMModel],
    metric: Optional[str] = None,
    dataset_names: Optional[List[str]] = None,
    ax: "Optional[matplotlib.axes.Axes]" = None,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    title: Optional[str] = "Metric during training",
    xlabel: Optional[str] = "Iterations",
    ylabel: Optional[str] = "@metric@",
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    grid: bool = True,
) -> Any:
    """Plot one metric during training.

    Parameters
    ----------
    booster : dict or LGBMModel
        Dictionary returned from ``lightgbm.train()`` or LGBMModel instance.
    metric : str or None, optional (default=None)
        The metric name to plot.
        Only one metric supported because different metrics have various scales.
        If None, one metric is picked from the results of the first plotted dataset
        (the last one stored there), and a warning is logged if several metrics are available.
    dataset_names : list of str, or None, optional (default=None)
        List of the dataset names which are used to calculate metric to plot.
        If None, all datasets are used.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    xlim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_xlim()``.
        If None, ``(0, number of iterations)`` is used.
    ylim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_ylim()``.
        If None, the range of plotted values widened by 20% on each side is used.
    title : str or None, optional (default="Metric during training")
        Axes title.
        If None, title is disabled.
    xlabel : str or None, optional (default="Iterations")
        X-axis title label.
        If None, label is disabled.
    ylabel : str or None, optional (default="@metric@")
        Y-axis title label.
        If 'auto', metric name is used.
        If None, label is disabled.
        @metric@ placeholder can be used, and it will be replaced with metric name.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
    dpi : int or None, optional (default=None)
        Resolution of the figure.
    grid : bool, optional (default=True)
        Whether to add a grid for axes.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with metric's history over the training.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` is not a dict or LGBMModel, or ``figsize``/``xlim``/``ylim`` is not a tuple of 2 elements.
    ValueError
        If the evaluation results are empty, or ``dataset_names`` is not a non-empty list, tuple or set.
    KeyError
        If ``metric`` is not present in the evaluation results.
    """
    if MATPLOTLIB_INSTALLED:
        import matplotlib.pyplot as plt
    else:
        raise ImportError("You must install matplotlib and restart your session to plot metric.")

    # Work on a copy: picking the default metric below uses ``popitem()``, which mutates the dict.
    if isinstance(booster, LGBMModel):
        eval_results = deepcopy(booster.evals_result_)
    elif isinstance(booster, dict):
        eval_results = deepcopy(booster)
    elif isinstance(booster, Booster):
        raise TypeError(
            "booster must be dict or LGBMModel. To use plot_metric with Booster type, first record the metrics using record_evaluation callback then pass that to plot_metric as argument `booster`"
        )
    else:
        raise TypeError("booster must be dict or LGBMModel.")

    num_data = len(eval_results)

    if not num_data:
        raise ValueError("eval results cannot be empty.")

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, "figsize")
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    if dataset_names is None:
        dataset_names_iter = iter(eval_results.keys())
    elif not isinstance(dataset_names, (list, tuple, set)) or not dataset_names:
        raise ValueError("dataset_names should be iterable and cannot be empty")
    else:
        dataset_names_iter = iter(dataset_names)

    name = next(dataset_names_iter)  # take one as sample
    metrics_for_one = eval_results[name]
    num_metric = len(metrics_for_one)
    if metric is None:
        if num_metric > 1:
            _log_warning("More than one metric available, picking one to plot.")
        # popitem() returns the most recently inserted metric of this dataset.
        metric, results = metrics_for_one.popitem()
    else:
        if metric not in metrics_for_one:
            raise KeyError("No given metric in eval results.")
        results = metrics_for_one[metric]
    num_iteration = len(results)
    max_result = max(results)
    min_result = min(results)
    x_ = range(num_iteration)
    ax.plot(x_, results, label=name)

    for name in dataset_names_iter:
        metrics_for_one = eval_results[name]
        results = metrics_for_one[metric]
        max_result = max(*results, max_result)
        min_result = min(*results, min_result)
        ax.plot(x_, results, label=name)

    ax.legend(loc="best")

    if xlim is not None:
        _check_not_tuple_of_2_elements(xlim, "xlim")
    else:
        xlim = (0, num_iteration)
    ax.set_xlim(xlim)

    if ylim is not None:
        _check_not_tuple_of_2_elements(ylim, "ylim")
    else:
        range_result = max_result - min_result
        ylim = (min_result - range_result * 0.2, max_result + range_result * 0.2)
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        if ylabel == "auto":
            # Documented alias: 'auto' means "use the metric name", like the default "@metric@".
            ylabel = "@metric@"
        ylabel = ylabel.replace("@metric@", metric)
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax


427
428
429
430
431
432
433
434
435
def _determine_direction_for_numeric_split(
    fval: float,
    threshold: float,
    missing_type_str: str,
    default_left: bool,
) -> str:
    """Return which child ("left" or "right") a numerical split sends ``fval`` to.

    A NaN value is replaced by 0.0 unless the node's missing type is NaN. Values that
    count as missing for this node (zero for a "zero" missing type, NaN for a "NaN"
    missing type) follow ``default_left``; everything else goes left iff
    ``fval <= threshold``.
    """
    missing_type = _MissingType(missing_type_str)
    if math.isnan(fval) and missing_type != _MissingType.NAN:
        fval = 0.0

    treat_as_missing = (missing_type == _MissingType.ZERO and _is_zero(fval)) or (
        missing_type == _MissingType.NAN and math.isnan(fval)
    )
    if treat_as_missing:
        return "left" if default_left else "right"
    return "left" if fval <= threshold else "right"


def _determine_direction_for_categorical_split(fval: float, thresholds: str) -> str:
    if math.isnan(fval) or int(fval) < 0:
447
448
449
        return "right"
    int_thresholds = {int(t) for t in thresholds.split("||")}
    return "left" if int(fval) in int_thresholds else "right"
450
451


452
453
454
455
def _to_graphviz(
    tree_info: Dict[str, Any],
    show_info: List[str],
    feature_names: Union[List[str], None],
    precision: Optional[int],
    orientation: str,
    constraints: Optional[List[int]],
    example_case: Optional[Union[np.ndarray, pd_DataFrame]],
    max_category_values: int,
    **kwargs: Any,
) -> Any:
    """Convert specified tree to graphviz instance.

    See:
      - https://graphviz.readthedocs.io/en/stable/api.html#digraph

    Parameters
    ----------
    tree_info : dict
        One entry of the ``"tree_info"`` list of ``Booster.dump_model()``; its
        ``"tree_structure"`` key holds the nested node dictionaries.
    show_info : list of str
        Names of the extra node attributes to render in the labels.
    feature_names : list of str or None
        Feature names indexed by ``split_feature``. If None, features are shown by index.
    precision : int or None
        Number of decimals used for floating point values. If None, values are not rounded.
    orientation : str
        "horizontal" lays the tree out left-to-right, any other value top-to-bottom.
    constraints : list of int or None
        Monotone constraint per feature (1 increasing, -1 decreasing). When non-empty,
        constrained split nodes are filled with a color and a legend node is added.
    example_case : array-like or None
        A single sample indexed by feature position (``create_tree_digraph`` passes a
        1-D row). If not None, the path this sample takes is highlighted.
    max_category_values : int
        Categorical splits with more category values than this are abbreviated in the
        label and the full list is moved into the node tooltip.
    **kwargs
        Other parameters passed to the ``Digraph`` constructor.

    Returns
    -------
    graph : graphviz.Digraph
        The digraph representation of the tree.

    Raises
    ------
    ImportError
        If graphviz is not installed.
    ValueError
        If the tree has no split, or a node has an unknown decision type.
    """
    if GRAPHVIZ_INSTALLED:
        from graphviz import Digraph
    else:
        raise ImportError("You must install graphviz and restart your session to plot tree.")

    import html

    def add(
        root: Dict[str, Any], total_count: int, parent: Optional[str], decision: Optional[str], highlight: bool
    ) -> None:
        """Recursively add node or edge."""
        fillcolor = "white"
        style = ""
        tooltip = None
        if highlight:
            color = "blue"
            penwidth = "3"
        else:
            color = "black"
            penwidth = "1"
        if "split_index" in root:  # non-leaf
            shape = "rectangle"
            l_dec = "yes"
            r_dec = "no"
            threshold = root["threshold"]
            if root["decision_type"] == "<=":
                operator = "&#8804;"
            elif root["decision_type"] == "==":
                operator = "="
            else:
                raise ValueError("Invalid decision type in tree model.")
            name = f"split{root['split_index']}"
            split_feature = root["split_feature"]
            if feature_names is not None:
                # Feature names are user-provided text inside an HTML-like label,
                # so characters such as '<', '>' and '&' must be escaped.
                feature_label = html.escape(str(feature_names[split_feature]), quote=False)
                label = f"<B>{feature_label}</B> {operator} "
            else:
                label = f"feature <B>{split_feature}</B> {operator} "
            direction = None
            if example_case is not None:
                if root["decision_type"] == "==":
                    direction = _determine_direction_for_categorical_split(
                        fval=example_case[split_feature], thresholds=root["threshold"]
                    )
                else:
                    direction = _determine_direction_for_numeric_split(
                        fval=example_case[split_feature],
                        threshold=root["threshold"],
                        missing_type_str=root["missing_type"],
                        default_left=root["default_left"],
                    )
            if root["decision_type"] == "==":
                # Long category lists are abbreviated in the label; the full list goes to the tooltip.
                category_values = root["threshold"].split("||")
                if len(category_values) > max_category_values:
                    tooltip = root["threshold"]
                    threshold = "||".join(category_values[:2]) + "||...||" + category_values[-1]

            label += f"<B>{_float2str(threshold, precision)}</B>"
            for info in ["split_gain", "internal_value", "internal_weight", "internal_count", "data_percentage"]:
                if info in show_info:
                    output = info.split("_")[-1]
                    if info in {"split_gain", "internal_value", "internal_weight"}:
                        label += f"<br/>{_float2str(root[info], precision)} {output}"
                    elif info == "internal_count":
                        label += f"<br/>{output}: {root[info]}"
                    elif info == "data_percentage":
                        label += f"<br/>{_float2str(root['internal_count'] / total_count * 100, 2)}% of data"

            if constraints:
                if constraints[root["split_feature"]] == 1:
                    fillcolor = "#ddffdd"  # light green
                if constraints[root["split_feature"]] == -1:
                    fillcolor = "#ffdddd"  # light red
                style = "filled"
            label = f"<{label}>"
            # Only the child the example case actually goes to stays highlighted.
            add(
                root=root["left_child"],
                total_count=total_count,
                parent=name,
                decision=l_dec,
                highlight=highlight and direction == "left",
            )
            add(
                root=root["right_child"],
                total_count=total_count,
                parent=name,
                decision=r_dec,
                highlight=highlight and direction == "right",
            )
        else:  # leaf
            shape = "ellipse"
            name = f"leaf{root['leaf_index']}"
            label = f"leaf {root['leaf_index']}: "
            label += f"<B>{_float2str(root['leaf_value'], precision)}</B>"
            if "leaf_weight" in show_info:
                label += f"<br/>{_float2str(root['leaf_weight'], precision)} weight"
            if "leaf_count" in show_info:
                label += f"<br/>count: {root['leaf_count']}"
            if "data_percentage" in show_info:
                label += f"<br/>{_float2str(root['leaf_count'] / total_count * 100, 2)}% of data"
            label = f"<{label}>"
        # ``graph`` is the Digraph created below, before the first call to ``add``.
        graph.node(
            name,
            label=label,
            shape=shape,
            style=style,
            fillcolor=fillcolor,
            color=color,
            penwidth=penwidth,
            tooltip=tooltip,
        )
        if parent is not None:
            graph.edge(parent, name, decision, color=color, penwidth=penwidth)

    graph = Digraph(**kwargs)
    rankdir = "LR" if orientation == "horizontal" else "TB"
    graph.attr("graph", nodesep="0.05", ranksep="0.3", rankdir=rankdir)
    if "internal_count" in tree_info["tree_structure"]:
        add(
            root=tree_info["tree_structure"],
            total_count=tree_info["tree_structure"]["internal_count"],
            parent=None,
            decision=None,
            highlight=example_case is not None,
        )
    else:
        raise ValueError("Cannot plot trees with no split")

    if constraints:
        # "#ddffdd" is light green, "#ffdddd" is light red
        legend = """<
            <TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="4">
             <TR>
              <TD COLSPAN="2"><B>Monotone constraints</B></TD>
             </TR>
             <TR>
              <TD>Increasing</TD>
              <TD BGCOLOR="#ddffdd"></TD>
             </TR>
             <TR>
              <TD>Decreasing</TD>
              <TD BGCOLOR="#ffdddd"></TD>
             </TR>
            </TABLE>
           >"""
        graph.node("legend", label=legend, shape="rectangle", color="white")

    return graph


614
615
616
617
618
def create_tree_digraph(
    booster: Union[Booster, LGBMModel],
    tree_index: int = 0,
    show_info: Optional[List[str]] = None,
    precision: Optional[int] = 3,
    orientation: str = "horizontal",
    example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
    max_category_values: int = 10,
    **kwargs: Any,
) -> Any:
    """Create a digraph representation of specified tree.

    Every node of the returned graph corresponds to one node of the tree.

    Non-leaf nodes have labels like ``Column_10 <= 875.9``, meaning
    "this node splits on the feature named "Column_10", with threshold 875.9".

    Leaf nodes have labels like ``leaf 2: 0.422``, meaning "this node is a
    leaf node, and the predicted value for records that fall into this node
    is 0.422". The number (``2``) is only an internal unique identifier.

    .. note::

        For more information please visit
        https://graphviz.readthedocs.io/en/stable/api.html#digraph.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance to be converted.
    tree_index : int, optional (default=0)
        The index of a target tree to convert.
    show_info : list of str, or None, optional (default=None)
        What information should be shown in nodes.

            - ``'split_gain'`` : gain from adding this split to the model
            - ``'internal_value'`` : raw predicted value that would be produced by this node if it was a leaf node
            - ``'internal_count'`` : number of records from the training data that fall into this non-leaf node
            - ``'internal_weight'`` : total weight of all nodes that fall into this non-leaf node
            - ``'leaf_count'`` : number of records from the training data that fall into this leaf node
            - ``'leaf_weight'`` : total weight (sum of Hessian) of all observations that fall into this leaf node
            - ``'data_percentage'`` : percentage of training data that fall into this node
    precision : int or None, optional (default=3)
        Used to restrict the display of floating point values to a certain precision.
    orientation : str, optional (default='horizontal')
        Orientation of the tree.
        Can be 'horizontal' or 'vertical'.
    example_case : numpy 2-D array, pandas DataFrame or None, optional (default=None)
        Single row with the same structure as the training data.
        If not None, the plot will highlight the path that sample takes through the tree.

        .. versionadded:: 4.0.0

    max_category_values : int, optional (default=10)
        The maximum number of category values to display in tree nodes, if the number of thresholds is greater than this value, thresholds will be collapsed and displayed on the label tooltip instead.

        .. warning::

            Consider wrapping the SVG string of the tree graph with ``IPython.display.HTML`` when running on JupyterLab to get the `tooltip <https://graphviz.org/docs/attrs/tooltip>`_ working right.

            Example:

            .. code-block:: python

                from IPython.display import HTML

                graph = lgb.create_tree_digraph(clf, max_category_values=5)
                HTML(graph._repr_image_svg_xml())

        .. versionadded:: 4.0.0

    **kwargs
        Other parameters passed to ``Digraph`` constructor.
        Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.

    Returns
    -------
    graph : graphviz.Digraph
        The digraph representation of specified tree.

    Raises
    ------
    TypeError
        If ``booster`` is neither a Booster nor an LGBMModel.
    IndexError
        If ``tree_index`` is not smaller than the number of trees.
    ValueError
        If ``example_case`` is not a single-row 2-D numpy array or pandas DataFrame.
    """
    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError("booster must be Booster or LGBMModel.")

    dumped = booster.dump_model()
    all_trees = dumped["tree_info"]
    if tree_index >= len(all_trees):
        raise IndexError("tree_index is out of range.")
    selected_tree = all_trees[tree_index]

    names = dumped.get("feature_names", None)
    monotone = dumped.get("monotone_constraints", None)
    info_to_show = [] if show_info is None else show_info

    row = example_case
    if row is not None:
        if not isinstance(row, (np.ndarray, pd_DataFrame)) or row.ndim != 2:
            raise ValueError("example_case must be a numpy 2-D array or a pandas DataFrame")
        if row.shape[0] != 1:
            raise ValueError("example_case must have a single row.")
        if isinstance(row, pd_DataFrame):
            # Encode the DataFrame (incl. categorical columns) the same way as the training data.
            row = _data_from_pandas(
                data=row,
                feature_name="auto",
                categorical_feature="auto",
                pandas_categorical=booster.pandas_categorical,
            )[0]
        # Reduce the single-row 2-D input to a 1-D vector indexed by feature position.
        row = row[0]

    return _to_graphviz(
        tree_info=selected_tree,
        show_info=info_to_show,
        feature_names=names,
        precision=precision,
        orientation=orientation,
        constraints=monotone,
        example_case=row,
        max_category_values=max_category_values,
        **kwargs,
    )
738

wxchan's avatar
wxchan committed
739

740
741
def plot_tree(
    booster: Union[Booster, LGBMModel],
    ax: "Optional[matplotlib.axes.Axes]" = None,
    tree_index: int = 0,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    show_info: Optional[List[str]] = None,
    precision: Optional[int] = 3,
    orientation: str = "horizontal",
    example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
    **kwargs: Any,
) -> Any:
    """Plot specified tree.

    Each node in the graph represents a node in the tree.

    Non-leaf nodes have labels like ``Column_10 <= 875.9``, which means
    "this node splits on the feature named "Column_10", with threshold 875.9".

    Leaf nodes have labels like ``leaf 2: 0.422``, which means "this node is a
    leaf node, and the predicted value for records that fall into this node
    is 0.422". The number (``2``) is an internal unique identifier and doesn't
    have any special meaning.

    .. note::

        It is preferable to use ``create_tree_digraph()`` because of its lossless quality,
        and its returned objects can also be rendered and displayed directly inside a Jupyter notebook.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance to be plotted.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    tree_index : int, optional (default=0)
        The index of a target tree to plot.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
        Only used when ``ax`` is None.
    dpi : int or None, optional (default=None)
        Resolution of the figure.
        Only used when ``ax`` is None.
    show_info : list of str, or None, optional (default=None)
        What information should be shown in nodes.

            - ``'split_gain'`` : gain from adding this split to the model
            - ``'internal_value'`` : raw predicted value that would be produced by this node if it was a leaf node
            - ``'internal_count'`` : number of records from the training data that fall into this non-leaf node
            - ``'internal_weight'`` : total weight of all nodes that fall into this non-leaf node
            - ``'leaf_count'`` : number of records from the training data that fall into this leaf node
            - ``'leaf_weight'`` : total weight (sum of Hessian) of all observations that fall into this leaf node
            - ``'data_percentage'`` : percentage of training data that fall into this node
    precision : int or None, optional (default=3)
        Used to restrict the display of floating point values to a certain precision.
    orientation : str, optional (default='horizontal')
        Orientation of the tree.
        Can be 'horizontal' or 'vertical'.
    example_case : numpy 2-D array, pandas DataFrame or None, optional (default=None)
        Single row with the same structure as the training data.
        If not None, the plot will highlight the path that sample takes through the tree.

        .. versionadded:: 4.0.0

    **kwargs
        Other parameters passed to ``create_tree_digraph()`` (for example ``max_category_values``);
        the remaining ones are passed on to the ``Digraph`` constructor.
        Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with single tree.
    """
    # Fail fast, before any figure is created, when matplotlib is unavailable.
    if not MATPLOTLIB_INSTALLED:
        raise ImportError("You must install matplotlib and restart your session to plot tree.")
    import matplotlib.image
    import matplotlib.pyplot as plt

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, "figsize")
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    digraph = create_tree_digraph(
        booster=booster,
        tree_index=tree_index,
        show_info=show_info,
        precision=precision,
        orientation=orientation,
        example_case=example_case,
        **kwargs,
    )

    # Rasterize the graphviz rendering to PNG and decode it into an image array
    # that matplotlib can draw on the target axes.
    png_bytes = digraph.pipe(format="png")
    with BytesIO(png_bytes) as png_stream:
        tree_image = matplotlib.image.imread(png_stream)

    ax.imshow(tree_image)
    ax.axis("off")
    return ax