plotting.py 31.9 KB
Newer Older
1
# coding: utf-8
2
"""Plotting library."""
3

4
import math
5
from copy import deepcopy
wxchan's avatar
wxchan committed
6
from io import BytesIO
7
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
wxchan's avatar
wxchan committed
8

9
10
import numpy as np

11
from .basic import Booster, _data_from_pandas, _is_zero, _log_warning, _MissingType
12
from .compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED, pd_DataFrame
13
14
from .sklearn import LGBMModel

15
__all__ = [
    "create_tree_digraph",
    "plot_importance",
    "plot_metric",
    "plot_split_value_histogram",
    "plot_tree",
]

if TYPE_CHECKING:
    # matplotlib appears only inside string annotations such as
    # ``"Optional[matplotlib.axes.Axes]"``; it is imported here for type checkers only.
    import matplotlib


def _check_not_tuple_of_2_elements(obj: Any, obj_name: str) -> None:
28
    """Check object is not tuple or does not have 2 elements."""
29
    if not isinstance(obj, tuple) or len(obj) != 2:
30
        raise TypeError(f"{obj_name} must be a tuple of 2 elements.")
wxchan's avatar
wxchan committed
31
32


33
def _float2str(value: float, precision: Optional[int]) -> str:
34
    return f"{value:.{precision}f}" if precision is not None and not isinstance(value, str) else str(value)
35
36


37
38
def plot_importance(
    booster: Union[Booster, LGBMModel],
    ax: "Optional[matplotlib.axes.Axes]" = None,
    height: float = 0.2,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    title: Optional[str] = "Feature importance",
    xlabel: Optional[str] = "Feature importance",
    ylabel: Optional[str] = "Features",
    importance_type: str = "auto",
    max_num_features: Optional[int] = None,
    ignore_zero: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    grid: bool = True,
    precision: Optional[int] = 3,
    **kwargs: Any,
) -> Any:
    """Plot model's feature importances.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance which feature importance should be plotted.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    height : float, optional (default=0.2)
        Bar height, passed to ``ax.barh()``.
    xlim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_xlim()``.
        If None, ``(0, largest plotted importance * 1.1)`` is used.
    ylim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_ylim()``.
        If None, ``(-1, number of plotted features)`` is used.
    title : str or None, optional (default="Feature importance")
        Axes title.
        If None, title is disabled.
    xlabel : str or None, optional (default="Feature importance")
        X-axis title label.
        If None, x-axis label is disabled.
        @importance_type@ placeholder can be used, and it will be replaced with the value of ``importance_type`` parameter.
    ylabel : str or None, optional (default="Features")
        Y-axis title label.
        If None, y-axis label is disabled.
    importance_type : str, optional (default="auto")
        How the importance is calculated.
        If "auto", if ``booster`` parameter is LGBMModel, ``booster.importance_type`` attribute is used; "split" otherwise.
        If "split", result contains numbers of times the feature is used in a model.
        If "gain", result contains total gains of splits which use the feature.
    max_num_features : int or None, optional (default=None)
        Max number of top features displayed on plot.
        If None or <1, all features will be displayed.
    ignore_zero : bool, optional (default=True)
        Whether to ignore features with zero importance.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
    dpi : int or None, optional (default=None)
        Resolution of the figure.
    grid : bool, optional (default=True)
        Whether to add a grid for axes.
    precision : int or None, optional (default=3)
        Used to restrict the display of floating point values to a certain precision.
        Only applied to the bar value annotations when ``importance_type`` is "gain".
    **kwargs
        Other parameters passed to ``ax.barh()``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with model's feature importances.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` is neither a Booster nor an LGBMModel, or a limit/size argument is not a 2-tuple.
    ValueError
        If the booster reports no feature importances, or if no feature is left to plot
        after dropping zero-importance features (``ignore_zero=True``).
    """
    if MATPLOTLIB_INSTALLED:
        import matplotlib.pyplot as plt  # noqa: PLC0415
    else:
        raise ImportError("You must install matplotlib and restart your session to plot importance.")

    # Resolve "auto" before unwrapping the sklearn estimator, since only the estimator knows its importance_type.
    if isinstance(booster, LGBMModel):
        if importance_type == "auto":
            importance_type = booster.importance_type
        booster = booster.booster_
    elif isinstance(booster, Booster):
        if importance_type == "auto":
            importance_type = "split"
    else:
        raise TypeError("booster must be Booster or LGBMModel.")

    importance = booster.feature_importance(importance_type=importance_type)
    feature_name = booster.feature_name()

    if not len(importance):
        raise ValueError("Booster's feature_importance is empty.")

    # Ascending order, so the most important feature ends up at the top of the horizontal bar chart.
    tuples = sorted(zip(feature_name, importance), key=lambda x: x[1])
    if ignore_zero:
        tuples = [x for x in tuples if x[1] > 0]
    if not tuples:
        # Without this check ``zip(*tuples)`` below fails with an opaque unpacking error.
        raise ValueError(
            "No features with non-zero importance to plot. Set ignore_zero=False to plot them anyway."
        )
    if max_num_features is not None and max_num_features > 0:
        tuples = tuples[-max_num_features:]
    labels, values = zip(*tuples)

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, "figsize")
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align="center", height=height, **kwargs)

    # Annotate each bar with its value; gains are floats and get rounded, split counts are shown as-is.
    for x, y in zip(values, ylocs):
        ax.text(x + 1, y, _float2str(x, precision) if importance_type == "gain" else x, va="center")

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is not None:
        _check_not_tuple_of_2_elements(xlim, "xlim")
    else:
        xlim = (0, max(values) * 1.1)
    ax.set_xlim(xlim)

    if ylim is not None:
        _check_not_tuple_of_2_elements(ylim, "ylim")
    else:
        ylim = (-1, len(values))
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        xlabel = xlabel.replace("@importance_type@", importance_type)
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax


def plot_split_value_histogram(
    booster: Union[Booster, LGBMModel],
    feature: Union[int, str],
    bins: Union[int, str, None] = None,
    ax: "Optional[matplotlib.axes.Axes]" = None,
    width_coef: float = 0.8,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    title: Optional[str] = "Split value histogram for feature with @index/name@ @feature@",
    xlabel: Optional[str] = "Feature split value",
    ylabel: Optional[str] = "Count",
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    grid: bool = True,
    **kwargs: Any,
) -> Any:
    """Plot split value histogram for the specified feature of the model.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance of which feature split value histogram should be plotted.
    feature : int or str
        The feature name or index the histogram is plotted for.
        If int, interpreted as index.
        If str, interpreted as name.
    bins : int, str or None, optional (default=None)
        The maximum number of bins.
        If None, the number of bins equals number of unique split values.
        If str, it should be one from the list of the supported values by ``numpy.histogram()`` function.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    width_coef : float, optional (default=0.8)
        Coefficient for histogram bar width (multiplied by the width of one bin).
    xlim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_xlim()``.
        If None, the range of the bin edges padded by 20% on each side is used.
    ylim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_ylim()``.
        If None, ``(0, highest count * 1.1)`` is used.
    title : str or None, optional (default="Split value histogram for feature with @index/name@ @feature@")
        Axes title.
        If None, title is disabled.
        @feature@ placeholder can be used, and it will be replaced with the value of ``feature`` parameter.
        @index/name@ placeholder can be used,
        and it will be replaced with ``index`` word in case of ``int`` type ``feature`` parameter
        or ``name`` word in case of ``str`` type ``feature`` parameter.
    xlabel : str or None, optional (default="Feature split value")
        X-axis title label.
        If None, x-axis label is disabled.
    ylabel : str or None, optional (default="Count")
        Y-axis title label.
        If None, y-axis label is disabled.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
    dpi : int or None, optional (default=None)
        Resolution of the figure.
    grid : bool, optional (default=True)
        Whether to add a grid for axes.
    **kwargs
        Other parameters passed to ``ax.bar()``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with specified model's feature split value histogram.
    """
    if not MATPLOTLIB_INSTALLED:
        raise ImportError("You must install matplotlib and restart your session to plot split value histogram.")
    import matplotlib.pyplot as plt  # noqa: PLC0415
    from matplotlib.ticker import MaxNLocator  # noqa: PLC0415

    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError("booster must be Booster or LGBMModel.")

    counts, bin_edges = booster.get_split_value_histogram(feature=feature, bins=bins, xgboost_style=False)
    if not np.count_nonzero(counts):
        raise ValueError(f"Cannot plot split value histogram, because feature {feature} was not used in splitting")

    # Bars are centred between consecutive edges; all bins share the width of the first one.
    bar_width = width_coef * (bin_edges[1] - bin_edges[0])
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, "figsize")
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    ax.bar(bin_centers, counts, align="center", width=bar_width, **kwargs)

    if xlim is None:
        edge_span = bin_edges[-1] - bin_edges[0]
        xlim = (bin_edges[0] - edge_span * 0.2, bin_edges[-1] + edge_span * 0.2)
    else:
        _check_not_tuple_of_2_elements(xlim, "xlim")
    ax.set_xlim(xlim)

    # Counts are whole numbers, so keep the y-axis ticks on integers.
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    if ylim is None:
        ylim = (0, max(counts) * 1.1)
    else:
        _check_not_tuple_of_2_elements(ylim, "ylim")
    ax.set_ylim(ylim)

    if title is not None:
        index_or_name = "name" if isinstance(feature, str) else "index"
        ax.set_title(title.replace("@feature@", str(feature)).replace("@index/name@", index_or_name))
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax


def plot_metric(
    booster: Union[Dict, LGBMModel],
    metric: Optional[str] = None,
    dataset_names: Optional[List[str]] = None,
    ax: "Optional[matplotlib.axes.Axes]" = None,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    title: Optional[str] = "Metric during training",
    xlabel: Optional[str] = "Iterations",
    ylabel: Optional[str] = "@metric@",
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    grid: bool = True,
) -> Any:
    """Plot one metric during training.

    Parameters
    ----------
    booster : dict or LGBMModel
        Dictionary returned from ``lightgbm.train()`` or LGBMModel instance.
    metric : str or None, optional (default=None)
        The metric name to plot.
        Only one metric supported because different metrics have various scales.
        If None, the last metric recorded for the first plotted dataset is used
        (a warning is logged if that dataset recorded more than one metric).
    dataset_names : list, tuple or set of str, or None, optional (default=None)
        List of the dataset names which are used to calculate metric to plot.
        If None, all datasets are used.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    xlim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_xlim()``.
        If None, ``(0, number of iterations)`` is used.
    ylim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_ylim()``.
        If None, the range of plotted values padded by 20% on each side is used.
    title : str or None, optional (default="Metric during training")
        Axes title.
        If None, title is disabled.
    xlabel : str or None, optional (default="Iterations")
        X-axis title label.
        If None, x-axis label is disabled.
    ylabel : str or None, optional (default="@metric@")
        Y-axis title label.
        If None, y-axis label is disabled.
        @metric@ placeholder can be used, and it will be replaced with metric name.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
    dpi : int or None, optional (default=None)
        Resolution of the figure.
    grid : bool, optional (default=True)
        Whether to add a grid for axes.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with metric's history over the training.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` is not a dict or LGBMModel, or a limit/size argument is not a 2-tuple.
    ValueError
        If the eval results are empty or ``dataset_names`` is invalid.
    KeyError
        If a requested dataset is missing, or the chosen metric is missing for some plotted dataset.
    """
    if MATPLOTLIB_INSTALLED:
        import matplotlib.pyplot as plt  # noqa: PLC0415
    else:
        raise ImportError("You must install matplotlib and restart your session to plot metric.")

    # Work on a copy so nothing below can mutate the caller's results.
    if isinstance(booster, LGBMModel):
        eval_results = deepcopy(booster.evals_result_)
    elif isinstance(booster, dict):
        eval_results = deepcopy(booster)
    elif isinstance(booster, Booster):
        raise TypeError(
            "booster must be dict or LGBMModel. To use plot_metric with Booster type, first record the metrics using record_evaluation callback then pass that to plot_metric as argument `booster`"
        )
    else:
        raise TypeError("booster must be dict or LGBMModel.")

    if not eval_results:
        raise ValueError("eval results cannot be empty.")

    if dataset_names is None:
        names = list(eval_results.keys())
    elif not isinstance(dataset_names, (list, tuple, set)) or not dataset_names:
        raise ValueError("dataset_names should be iterable and cannot be empty")
    else:
        names = list(dataset_names)

    # Validate everything before creating a figure, so a bad call doesn't leave a half-drawn plot behind.
    unknown_names = [name for name in names if name not in eval_results]
    if unknown_names:
        raise KeyError(f"Dataset(s) not found in eval results: {unknown_names}")

    metrics_for_one = eval_results[names[0]]  # the first dataset decides which metric is plotted
    if metric is None:
        if not metrics_for_one:
            raise KeyError(f"No metrics recorded for dataset '{names[0]}'.")
        if len(metrics_for_one) > 1:
            _log_warning("More than one metric available, picking one to plot.")
        metric = list(metrics_for_one)[-1]  # same choice as dict.popitem(): the last recorded metric
    elif metric not in metrics_for_one:
        raise KeyError("No given metric in eval results.")

    for name in names[1:]:
        if metric not in eval_results[name]:
            raise KeyError(f"No given metric in eval results of dataset '{name}'.")

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, "figsize")
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    all_results = [eval_results[name][metric] for name in names]
    # The x-axis is sized by the first dataset; the others are expected to have the same length.
    num_iteration = len(all_results[0])
    x_ = range(num_iteration)
    for name, results in zip(names, all_results):
        ax.plot(x_, results, label=name)
    ax.legend(loc="best")

    max_result = max(max(results) for results in all_results)
    min_result = min(min(results) for results in all_results)

    if xlim is not None:
        _check_not_tuple_of_2_elements(xlim, "xlim")
    else:
        xlim = (0, num_iteration)
    ax.set_xlim(xlim)

    if ylim is not None:
        _check_not_tuple_of_2_elements(ylim, "ylim")
    else:
        range_result = max_result - min_result
        ylim = (min_result - range_result * 0.2, max_result + range_result * 0.2)
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ylabel = ylabel.replace("@metric@", metric)
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax


def _determine_direction_for_numeric_split(
    *,
    fval: float,
    threshold: float,
    missing_type_str: str,
    default_left: bool,
) -> str:
    """Decide which child ("left" or "right") a numeric split sends ``fval`` to.

    Parameters
    ----------
    fval : float
        Feature value of the example being traced.
    threshold : float
        Split threshold; values ``<= threshold`` go left.
    missing_type_str : str
        Node's missing-value handling, converted with ``_MissingType``.
    default_left : bool
        Direction taken by values treated as missing.

    Returns
    -------
    str
        "left" or "right".
    """
    missing_type = _MissingType(missing_type_str)
    fval_is_nan = math.isnan(fval)
    if fval_is_nan and missing_type != _MissingType.NAN:
        # Unless the node handles NaN explicitly, NaN is treated as 0.0.
        fval = 0.0
        fval_is_nan = False

    zero_is_missing = missing_type == _MissingType.ZERO and _is_zero(fval)
    nan_is_missing = missing_type == _MissingType.NAN and fval_is_nan
    if zero_is_missing or nan_is_missing:
        return "left" if default_left else "right"
    return "left" if fval <= threshold else "right"


def _determine_direction_for_categorical_split(fval: float, thresholds: str) -> str:
    if math.isnan(fval) or int(fval) < 0:
448
449
450
        return "right"
    int_thresholds = {int(t) for t in thresholds.split("||")}
    return "left" if int(fval) in int_thresholds else "right"
451
452


453
def _to_graphviz(
    *,
    tree_info: Dict[str, Any],
    show_info: List[str],
    feature_names: Union[List[str], None],
    precision: Optional[int],
    orientation: str,
    constraints: Optional[List[int]],
    example_case: Optional[Union[np.ndarray, pd_DataFrame]],
    max_category_values: int,
    **kwargs: Any,
) -> Any:
    """Convert specified tree to graphviz instance.

    Parameters
    ----------
    tree_info : dict
        One tree from the model dump; must contain ``"tree_structure"``.
    show_info : list of str
        Extra per-node fields to include in node labels.
    feature_names : list of str or None
        Names shown in split labels; if None, the numeric feature index is shown instead.
    precision : int or None
        Number of decimals for floating point values; None means plain ``str()``.
    orientation : str
        "horizontal" lays the graph out left-to-right; any other value top-to-bottom.
    constraints : list of int or None
        Per-feature monotone constraints; splits on features with 1 are filled light green,
        with -1 light red, and a legend node is added.
    example_case : array-like or None
        A single row of feature values indexable by feature index; when given,
        the path it takes through the tree is highlighted.
    max_category_values : int
        Categorical splits with more values than this get their full value list in the
        node tooltip, and (when there are more than 3 values) a shortened label.
    **kwargs
        Other parameters passed to the ``Digraph`` constructor.

    Returns
    -------
    graph : graphviz.Digraph
        The digraph representation of the tree.

    Raises
    ------
    ImportError
        If graphviz is not installed.
    ValueError
        If a node has an unknown decision type, or the tree has no split.

    See:
      - https://graphviz.readthedocs.io/en/stable/api.html#digraph
    """
    if GRAPHVIZ_INSTALLED:
        from graphviz import Digraph  # noqa: PLC0415
    else:
        raise ImportError("You must install graphviz and restart your session to plot tree.")
    import html  # noqa: PLC0415

    def add(
        root: Dict[str, Any], total_count: int, parent: Optional[str], decision: Optional[str], highlight: bool
    ) -> None:
        """Recursively add ``root`` (and its subtree) as nodes, plus the edge from ``parent``."""
        fillcolor = "white"
        style = ""
        tooltip = None
        if highlight:
            color = "blue"
            penwidth = "3"
        else:
            color = "black"
            penwidth = "1"
        if "split_index" in root:  # non-leaf
            shape = "rectangle"
            l_dec = "yes"
            r_dec = "no"
            threshold = root["threshold"]
            if root["decision_type"] == "<=":
                operator = "&#8804;"
            elif root["decision_type"] == "==":
                operator = "="
            else:
                raise ValueError("Invalid decision type in tree model.")
            name = f"split{root['split_index']}"
            split_feature = root["split_feature"]
            if feature_names is not None:
                # Labels are Graphviz HTML-like strings: escape '<', '>' and '&' in user-provided names.
                feature_label = html.escape(str(feature_names[split_feature]), quote=False)
                label = f"<B>{feature_label}</B> {operator} "
            else:
                label = f"feature <B>{split_feature}</B> {operator} "
            direction = None
            if example_case is not None:
                if root["decision_type"] == "==":
                    direction = _determine_direction_for_categorical_split(
                        fval=example_case[split_feature], thresholds=root["threshold"]
                    )
                else:
                    direction = _determine_direction_for_numeric_split(
                        fval=example_case[split_feature],
                        threshold=root["threshold"],
                        missing_type_str=root["missing_type"],
                        default_left=root["default_left"],
                    )
            if root["decision_type"] == "==":
                category_values = root["threshold"].split("||")
                if len(category_values) > max_category_values:
                    tooltip = root["threshold"]
                    # The collapsed form shows 3 values; with 3 or fewer it would only duplicate them.
                    if len(category_values) > 3:
                        threshold = "||".join(category_values[:2]) + "||...||" + category_values[-1]

            label += f"<B>{_float2str(threshold, precision)}</B>"
            for info in ["split_gain", "internal_value", "internal_weight", "internal_count", "data_percentage"]:
                if info in show_info:
                    output = info.split("_")[-1]
                    if info in {"split_gain", "internal_value", "internal_weight"}:
                        label += f"<br/>{_float2str(root[info], precision)} {output}"
                    elif info == "internal_count":
                        label += f"<br/>{output}: {root[info]}"
                    elif info == "data_percentage":
                        label += f"<br/>{_float2str(root['internal_count'] / total_count * 100, 2)}% of data"

            if constraints:
                if constraints[root["split_feature"]] == 1:
                    fillcolor = "#ddffdd"  # light green
                if constraints[root["split_feature"]] == -1:
                    fillcolor = "#ffdddd"  # light red
                style = "filled"
            label = f"<{label}>"
            # Only the branch the example actually follows stays highlighted.
            add(
                root=root["left_child"],
                total_count=total_count,
                parent=name,
                decision=l_dec,
                highlight=highlight and direction == "left",
            )
            add(
                root=root["right_child"],
                total_count=total_count,
                parent=name,
                decision=r_dec,
                highlight=highlight and direction == "right",
            )
        else:  # leaf
            shape = "ellipse"
            name = f"leaf{root['leaf_index']}"
            label = f"leaf {root['leaf_index']}: "
            label += f"<B>{_float2str(root['leaf_value'], precision)}</B>"
            if "leaf_weight" in show_info:
                label += f"<br/>{_float2str(root['leaf_weight'], precision)} weight"
            if "leaf_count" in show_info:
                label += f"<br/>count: {root['leaf_count']}"
            if "data_percentage" in show_info:
                label += f"<br/>{_float2str(root['leaf_count'] / total_count * 100, 2)}% of data"
            label = f"<{label}>"
        graph.node(
            name,
            label=label,
            shape=shape,
            style=style,
            fillcolor=fillcolor,
            color=color,
            penwidth=penwidth,
            tooltip=tooltip,
        )
        if parent is not None:
            graph.edge(parent, name, decision, color=color, penwidth=penwidth)

    graph = Digraph(**kwargs)
    rankdir = "LR" if orientation == "horizontal" else "TB"
    graph.attr("graph", nodesep="0.05", ranksep="0.3", rankdir=rankdir)
    # A tree without splits has a bare leaf as root, which carries no "internal_count".
    if "internal_count" in tree_info["tree_structure"]:
        add(
            root=tree_info["tree_structure"],
            total_count=tree_info["tree_structure"]["internal_count"],
            parent=None,
            decision=None,
            highlight=example_case is not None,
        )
    else:
        raise ValueError("Cannot plot trees with no split")

    if constraints:
        # "#ddffdd" is light green, "#ffdddd" is light red
        legend = """<
            <TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="4">
             <TR>
              <TD COLSPAN="2"><B>Monotone constraints</B></TD>
             </TR>
             <TR>
              <TD>Increasing</TD>
              <TD BGCOLOR="#ddffdd"></TD>
             </TR>
             <TR>
              <TD>Decreasing</TD>
              <TD BGCOLOR="#ffdddd"></TD>
             </TR>
            </TABLE>
           >"""
        graph.node("legend", label=legend, shape="rectangle", color="white")

    return graph


def create_tree_digraph(
    booster: Union[Booster, LGBMModel],
    tree_index: int = 0,
    show_info: Optional[List[str]] = None,
    precision: Optional[int] = 3,
    orientation: str = "horizontal",
    example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
    max_category_values: int = 10,
    **kwargs: Any,
) -> Any:
    """Create a digraph representation of specified tree.

    Each node in the graph represents a node in the tree.

    Non-leaf nodes have labels like ``Column_10 <= 875.9``, which means
    "this node splits on the feature named "Column_10", with threshold 875.9".

    Leaf nodes have labels like ``leaf 2: 0.422``, which means "this node is a
    leaf node, and the predicted value for records that fall into this node
    is 0.422". The number (``2``) is an internal unique identifier and doesn't
    have any special meaning.

    .. note::

        For more information please visit
        https://graphviz.readthedocs.io/en/stable/api.html#digraph.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance to be converted.
    tree_index : int, optional (default=0)
        The index of a target tree to convert.
    show_info : list of str, or None, optional (default=None)
        What information should be shown in nodes.

            - ``'split_gain'`` : gain from adding this split to the model
            - ``'internal_value'`` : raw predicted value that would be produced by this node if it was a leaf node
            - ``'internal_count'`` : number of records from the training data that fall into this non-leaf node
            - ``'internal_weight'`` : total weight of all records that fall into this non-leaf node
            - ``'leaf_count'`` : number of records from the training data that fall into this leaf node
            - ``'leaf_weight'`` : total weight (sum of Hessian) of all observations that fall into this leaf node
            - ``'data_percentage'`` : percentage of training data that fall into this node
    precision : int or None, optional (default=3)
        Used to restrict the display of floating point values to a certain precision.
    orientation : str, optional (default='horizontal')
        Orientation of the tree.
        Can be 'horizontal' or 'vertical'.
    example_case : numpy 2-D array, pandas DataFrame or None, optional (default=None)
        Single row with the same structure as the training data.
        If not None, the plot will highlight the path that sample takes through the tree.

        .. versionadded:: 4.0.0

    max_category_values : int, optional (default=10)
        The maximum number of category values to display in tree nodes, if the number of thresholds is greater than this value, thresholds will be collapsed and displayed on the label tooltip instead.

        .. warning::

            Consider wrapping the SVG string of the tree graph with ``IPython.display.HTML`` when running on JupyterLab to get the `tooltip <https://graphviz.org/docs/attrs/tooltip>`_ working right.

            Example:

            .. code-block:: python

                from IPython.display import HTML

                graph = lgb.create_tree_digraph(clf, max_category_values=5)
                HTML(graph._repr_image_svg_xml())

        .. versionadded:: 4.0.0

    **kwargs
        Other parameters passed to ``Digraph`` constructor.
        Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.

    Returns
    -------
    graph : graphviz.Digraph
        The digraph representation of specified tree.

    Raises
    ------
    TypeError
        If ``booster`` is neither a Booster nor an LGBMModel.
    IndexError
        If ``tree_index`` is not smaller than the number of trees.
    ValueError
        If ``example_case`` is not a single-row numpy 2-D array or pandas DataFrame.
    """
    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError("booster must be Booster or LGBMModel.")

    dumped = booster.dump_model()
    all_trees = dumped["tree_info"]
    if tree_index >= len(all_trees):
        raise IndexError("tree_index is out of range.")
    tree_info = all_trees[tree_index]
    feature_names = dumped.get("feature_names", None)
    monotone_constraints = dumped.get("monotone_constraints", None)

    show_info = [] if show_info is None else show_info

    if example_case is not None:
        is_supported_type = isinstance(example_case, (np.ndarray, pd_DataFrame))
        if not is_supported_type or example_case.ndim != 2:
            raise ValueError("example_case must be a numpy 2-D array or a pandas DataFrame")
        if example_case.shape[0] != 1:
            raise ValueError("example_case must have a single row.")
        if isinstance(example_case, pd_DataFrame):
            # Encode pandas categoricals the same way the model saw them during training.
            converted = _data_from_pandas(
                data=example_case,
                feature_name="auto",
                categorical_feature="auto",
                pandas_categorical=booster.pandas_categorical,
            )
            example_case = converted[0]
        # Keep only the single row; nodes index it by feature position.
        example_case = example_case[0]

    return _to_graphviz(
        tree_info=tree_info,
        show_info=show_info,
        feature_names=feature_names,
        precision=precision,
        orientation=orientation,
        constraints=monotone_constraints,
        example_case=example_case,
        max_category_values=max_category_values,
        **kwargs,
    )


def plot_tree(
    booster: Union[Booster, LGBMModel],
    ax: "Optional[matplotlib.axes.Axes]" = None,
    tree_index: int = 0,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    show_info: Optional[List[str]] = None,
    precision: Optional[int] = 3,
    orientation: str = "horizontal",
    example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
    **kwargs: Any,
) -> Any:
    """Plot specified tree.

    Each node in the graph represents a node in the tree.

    Non-leaf nodes have labels like ``Column_10 <= 875.9``, which means
    "this node splits on the feature named "Column_10", with threshold 875.9".

    Leaf nodes have labels like ``leaf 2: 0.422``, which means "this node is a
    leaf node, and the predicted value for records that fall into this node
    is 0.422". The number (``2``) is an internal unique identifier and doesn't
    have any special meaning.

    .. note::

        It is preferable to use ``create_tree_digraph()`` because of its lossless quality
        and returned objects can be also rendered and displayed directly inside a Jupyter notebook.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance to be plotted.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    tree_index : int, optional (default=0)
        The index of a target tree to plot.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
        Only used when ``ax`` is None.
    dpi : int or None, optional (default=None)
        Resolution of the figure.
        Only used when ``ax`` is None.
    show_info : list of str, or None, optional (default=None)
        What information should be shown in nodes.

            - ``'split_gain'`` : gain from adding this split to the model
            - ``'internal_value'`` : raw predicted value that would be produced by this node if it was a leaf node
            - ``'internal_count'`` : number of records from the training data that fall into this non-leaf node
            - ``'internal_weight'`` : total weight (sum of Hessian) of all observations that fall into this non-leaf node
            - ``'leaf_count'`` : number of records from the training data that fall into this leaf node
            - ``'leaf_weight'`` : total weight (sum of Hessian) of all observations that fall into this leaf node
            - ``'data_percentage'`` : percentage of training data that fall into this node
    precision : int or None, optional (default=3)
        Used to restrict the display of floating point values to a certain precision.
    orientation : str, optional (default='horizontal')
        Orientation of the tree.
        Can be 'horizontal' or 'vertical'.
    example_case : numpy 2-D array, pandas DataFrame or None, optional (default=None)
        Single row with the same structure as the training data.
        If not None, the plot will highlight the path that sample takes through the tree.

        .. versionadded:: 4.0.0

    **kwargs
        Other parameters passed to ``Digraph`` constructor (forwarded via ``create_tree_digraph()``).
        Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with single tree.
    """
    if MATPLOTLIB_INSTALLED:
        import matplotlib.image  # noqa: PLC0415
        import matplotlib.pyplot as plt  # noqa: PLC0415
    else:
        raise ImportError("You must install matplotlib and restart your session to plot tree.")

    # figsize/dpi are only applied when this function creates the figure itself;
    # a caller-supplied ``ax`` keeps whatever figure it already belongs to.
    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, "figsize")
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    graph = create_tree_digraph(
        booster=booster,
        tree_index=tree_index,
        show_info=show_info,
        precision=precision,
        orientation=orientation,
        example_case=example_case,
        **kwargs,
    )

    # Render the digraph to PNG bytes in memory and decode them into a pixel array.
    # imread returns a fully decoded array, so the buffer can be closed right after.
    with BytesIO(graph.pipe(format="png")) as png_buffer:
        img = matplotlib.image.imread(png_buffer)

    ax.imshow(img)
    # The image is a rendered diagram, so axis ticks/frame carry no meaning.
    ax.axis("off")
    return ax