plotting.py 31.8 KB
Newer Older
1
# coding: utf-8
2
"""Plotting library."""
3
import math
4
from copy import deepcopy
wxchan's avatar
wxchan committed
5
from io import BytesIO
6
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
wxchan's avatar
wxchan committed
7

8
9
import numpy as np

10
from .basic import Booster, _data_from_pandas, _is_zero, _log_warning, _MissingType
11
from .compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED, pd_DataFrame
12
13
from .sklearn import LGBMModel

14
# Names exported by ``from <this module> import *``.
__all__ = [
    "create_tree_digraph",
    "plot_importance",
    "plot_metric",
    "plot_split_value_histogram",
    "plot_tree",
]

22
23
24
if TYPE_CHECKING:
    import matplotlib

25

26
def _check_not_tuple_of_2_elements(obj: Any, obj_name: str) -> None:
    """Raise ``TypeError`` unless ``obj`` is a tuple of exactly 2 elements.

    Parameters
    ----------
    obj : object
        Value to validate.
    obj_name : str
        Name used to refer to ``obj`` in the error message.

    Raises
    ------
    TypeError
        If ``obj`` is not a tuple, or is a tuple whose length is not 2.
    """
    if isinstance(obj, tuple) and len(obj) == 2:
        return
    raise TypeError(f"{obj_name} must be a tuple of 2 elements.")
wxchan's avatar
wxchan committed
30
31


32
def _float2str(value: Union[float, str], precision: Optional[int]) -> str:
    """Render ``value`` as text, rounding numbers to a fixed number of decimals.

    Parameters
    ----------
    value : float or str
        Value to render. Strings are returned unchanged.
    precision : int or None
        Number of digits after the decimal point.
        If None, ``str(value)`` is returned without any rounding.

    Returns
    -------
    text : str
        The formatted value.
    """
    # Strings cannot take the ``f`` format code, so they bypass formatting.
    if precision is None or isinstance(value, str):
        return str(value)
    return f"{value:.{precision}f}"
34
35


36
37
def plot_importance(
    booster: Union[Booster, LGBMModel],
    ax: "Optional[matplotlib.axes.Axes]" = None,
    height: float = 0.2,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    title: Optional[str] = "Feature importance",
    xlabel: Optional[str] = "Feature importance",
    ylabel: Optional[str] = "Features",
    importance_type: str = "auto",
    max_num_features: Optional[int] = None,
    ignore_zero: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    grid: bool = True,
    precision: Optional[int] = 3,
    **kwargs: Any,
) -> Any:
    """Plot model's feature importances.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance which feature importance should be plotted.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    height : float, optional (default=0.2)
        Bar height, passed to ``ax.barh()``.
    xlim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_xlim()``.
        If None, ``(0, 1.1 * largest plotted importance)`` is used.
    ylim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_ylim()``.
        If None, ``(-1, number of plotted features)`` is used.
    title : str or None, optional (default="Feature importance")
        Axes title.
        If None, title is disabled.
    xlabel : str or None, optional (default="Feature importance")
        X-axis title label.
        If None, title is disabled.
        @importance_type@ placeholder can be used, and it will be replaced with the value of ``importance_type`` parameter.
    ylabel : str or None, optional (default="Features")
        Y-axis title label.
        If None, title is disabled.
    importance_type : str, optional (default="auto")
        How the importance is calculated.
        If "auto", if ``booster`` parameter is LGBMModel, ``booster.importance_type`` attribute is used; "split" otherwise.
        If "split", result contains numbers of times the feature is used in a model.
        If "gain", result contains total gains of splits which use the feature.
    max_num_features : int or None, optional (default=None)
        Max number of top features displayed on plot.
        If None or <1, all features will be displayed.
    ignore_zero : bool, optional (default=True)
        Whether to ignore features with zero importance.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
        Only used when ``ax`` is None.
    dpi : int or None, optional (default=None)
        Resolution of the figure.
        Only used when ``ax`` is None.
    grid : bool, optional (default=True)
        Whether to add a grid for axes.
    precision : int or None, optional (default=3)
        Used to restrict the display of floating point values to a certain precision.
    **kwargs
        Other parameters passed to ``ax.barh()``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with model's feature importances.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` is neither Booster nor LGBMModel,
        or if ``figsize``, ``xlim`` or ``ylim`` is given but is not a tuple of 2 elements.
    ValueError
        If the model reports no feature importances,
        or if ``ignore_zero`` is True and no feature has a positive importance.
    """
    if not MATPLOTLIB_INSTALLED:
        raise ImportError("You must install matplotlib and restart your session to plot importance.")
    import matplotlib.pyplot as plt

    if isinstance(booster, LGBMModel):
        if importance_type == "auto":
            importance_type = booster.importance_type
        booster = booster.booster_
    elif isinstance(booster, Booster):
        if importance_type == "auto":
            importance_type = "split"
    else:
        raise TypeError("booster must be Booster or LGBMModel.")

    importance = booster.feature_importance(importance_type=importance_type)
    feature_name = booster.feature_name()

    if len(importance) == 0:
        raise ValueError("Booster's feature_importance is empty.")

    # Ascending order: the most important feature ends up as the top bar.
    tuples = sorted(zip(feature_name, importance), key=lambda x: x[1])
    if ignore_zero:
        tuples = [x for x in tuples if x[1] > 0]
    if max_num_features is not None and max_num_features > 0:
        tuples = tuples[-max_num_features:]
    if not tuples:
        # Only reachable through the ``ignore_zero`` filter; without this check
        # the unpacking below fails with "not enough values to unpack".
        raise ValueError(
            "Cannot plot feature importance: no feature has a positive importance. "
            "Pass ignore_zero=False to plot them anyway."
        )
    labels, values = zip(*tuples)

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, "figsize")
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align="center", height=height, **kwargs)

    # Annotate each bar with its value, placed 1 data unit right of the bar end.
    # Only "gain" importances are rounded to ``precision``.
    for x, y in zip(values, ylocs):
        ax.text(x + 1, y, _float2str(x, precision) if importance_type == "gain" else x, va="center")

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is not None:
        _check_not_tuple_of_2_elements(xlim, "xlim")
    else:
        xlim = (0, max(values) * 1.1)
    ax.set_xlim(xlim)

    if ylim is not None:
        _check_not_tuple_of_2_elements(ylim, "ylim")
    else:
        ylim = (-1, len(values))
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        xlabel = xlabel.replace("@importance_type@", importance_type)
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax
wxchan's avatar
wxchan committed
168
169


170
171
172
173
def plot_split_value_histogram(
    booster: Union[Booster, LGBMModel],
    feature: Union[int, str],
    bins: Union[int, str, None] = None,
    ax: "Optional[matplotlib.axes.Axes]" = None,
    width_coef: float = 0.8,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    title: Optional[str] = "Split value histogram for feature with @index/name@ @feature@",
    xlabel: Optional[str] = "Feature split value",
    ylabel: Optional[str] = "Count",
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    grid: bool = True,
    **kwargs: Any,
) -> Any:
    """Plot split value histogram for the specified feature of the model.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance of which feature split value histogram should be plotted.
    feature : int or str
        The feature name or index the histogram is plotted for.
        If int, interpreted as index.
        If str, interpreted as name.
    bins : int, str or None, optional (default=None)
        The maximum number of bins.
        If None, the number of bins equals number of unique split values.
        If str, it should be one from the list of the supported values by ``numpy.histogram()`` function.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    width_coef : float, optional (default=0.8)
        Coefficient for histogram bar width.
        Bar width is ``width_coef`` times the width of the first bin.
    xlim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_xlim()``.
        If None, the range of the bins widened by 20% on each side is used.
    ylim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_ylim()``.
        If None, ``(0, 1.1 * highest bar)`` is used.
    title : str or None, optional (default="Split value histogram for feature with @index/name@ @feature@")
        Axes title.
        If None, title is disabled.
        @feature@ placeholder can be used, and it will be replaced with the value of ``feature`` parameter.
        @index/name@ placeholder can be used,
        and it will be replaced with ``index`` word in case of ``int`` type ``feature`` parameter
        or ``name`` word in case of ``str`` type ``feature`` parameter.
    xlabel : str or None, optional (default="Feature split value")
        X-axis title label.
        If None, title is disabled.
    ylabel : str or None, optional (default="Count")
        Y-axis title label.
        If None, title is disabled.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
        Only used when ``ax`` is None.
    dpi : int or None, optional (default=None)
        Resolution of the figure.
        Only used when ``ax`` is None.
    grid : bool, optional (default=True)
        Whether to add a grid for axes.
    **kwargs
        Other parameters passed to ``ax.bar()``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with specified model's feature split value histogram.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` is neither Booster nor LGBMModel,
        or if ``figsize``, ``xlim`` or ``ylim`` is given but is not a tuple of 2 elements.
    ValueError
        If every bin of the histogram is empty, i.e. the feature was not used in splitting.
    """
    if not MATPLOTLIB_INSTALLED:
        raise ImportError("You must install matplotlib and restart your session to plot split value histogram.")
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError("booster must be Booster or LGBMModel.")

    hist, split_bins = booster.get_split_value_histogram(feature=feature, bins=bins, xgboost_style=False)
    if np.count_nonzero(hist) == 0:
        raise ValueError(f"Cannot plot split value histogram, because feature {feature} was not used in splitting")

    # ``split_bins`` holds bin edges; bars are drawn at the midpoints between them.
    width = width_coef * (split_bins[1] - split_bins[0])
    centred = (split_bins[:-1] + split_bins[1:]) / 2

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, "figsize")
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    ax.bar(centred, hist, align="center", width=width, **kwargs)

    if xlim is not None:
        _check_not_tuple_of_2_elements(xlim, "xlim")
    else:
        range_result = split_bins[-1] - split_bins[0]
        xlim = (split_bins[0] - range_result * 0.2, split_bins[-1] + range_result * 0.2)
    ax.set_xlim(xlim)

    # Bar heights are counts, so only integer ticks make sense on the y-axis.
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    if ylim is not None:
        _check_not_tuple_of_2_elements(ylim, "ylim")
    else:
        ylim = (0, max(hist) * 1.1)
    ax.set_ylim(ylim)

    if title is not None:
        title = title.replace("@feature@", str(feature))
        title = title.replace("@index/name@", ("name" if isinstance(feature, str) else "index"))
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax


286
287
288
289
def plot_metric(
    booster: Union[Dict, LGBMModel],
    metric: Optional[str] = None,
    dataset_names: Optional[List[str]] = None,
    ax: "Optional[matplotlib.axes.Axes]" = None,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    title: Optional[str] = "Metric during training",
    xlabel: Optional[str] = "Iterations",
    ylabel: Optional[str] = "@metric@",
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    grid: bool = True,
) -> Any:
    """Plot one metric during training.

    Parameters
    ----------
    booster : dict or LGBMModel
        Dictionary returned from ``lightgbm.train()`` or LGBMModel instance.
    metric : str or None, optional (default=None)
        The metric name to plot.
        Only one metric supported because different metrics have various scales.
        If None, the last metric in the dictionary of the first plotted dataset is used.
    dataset_names : list of str, or None, optional (default=None)
        List of the dataset names which are used to calculate metric to plot.
        If None, all datasets are used.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    xlim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_xlim()``.
        If None, ``(0, number of iterations)`` is used.
    ylim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_ylim()``.
        If None, the range of plotted values widened by 20% on each side is used.
    title : str or None, optional (default="Metric during training")
        Axes title.
        If None, title is disabled.
    xlabel : str or None, optional (default="Iterations")
        X-axis title label.
        If None, title is disabled.
    ylabel : str or None, optional (default="@metric@")
        Y-axis title label.
        If None, title is disabled.
        @metric@ placeholder can be used, and it will be replaced with metric name.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
        Only used when ``ax`` is None.
    dpi : int or None, optional (default=None)
        Resolution of the figure.
        Only used when ``ax`` is None.
    grid : bool, optional (default=True)
        Whether to add a grid for axes.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with metric's history over the training.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` is neither dict nor LGBMModel,
        or if ``figsize``, ``xlim`` or ``ylim`` is given but is not a tuple of 2 elements.
    ValueError
        If the eval results are empty, or ``dataset_names`` is empty or not a list, tuple or set.
    KeyError
        If a requested dataset or metric is missing from the eval results.
    """
    if not MATPLOTLIB_INSTALLED:
        raise ImportError("You must install matplotlib and restart your session to plot metric.")
    import matplotlib.pyplot as plt

    # A private copy is taken so nothing here aliases the caller's recorded history.
    if isinstance(booster, LGBMModel):
        eval_results = deepcopy(booster.evals_result_)
    elif isinstance(booster, dict):
        eval_results = deepcopy(booster)
    elif isinstance(booster, Booster):
        raise TypeError(
            "booster must be dict or LGBMModel. To use plot_metric with Booster type, first record the metrics using record_evaluation callback then pass that to plot_metric as argument `booster`"
        )
    else:
        raise TypeError("booster must be dict or LGBMModel.")

    num_data = len(eval_results)

    if not num_data:
        raise ValueError("eval results cannot be empty.")

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, "figsize")
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    if dataset_names is None:
        dataset_names_iter = iter(eval_results.keys())
    elif not isinstance(dataset_names, (list, tuple, set)) or not dataset_names:
        raise ValueError("dataset_names should be iterable and cannot be empty")
    else:
        dataset_names_iter = iter(dataset_names)

    # The first dataset decides which metric is plotted when none was requested.
    name = next(dataset_names_iter)
    metrics_for_one = eval_results[name]
    num_metric = len(metrics_for_one)
    if metric is None:
        if num_metric == 0:
            raise KeyError(f"No metrics recorded for dataset '{name}' in eval results.")
        if num_metric > 1:
            _log_warning("More than one metric available, picking one to plot.")
        # Take the last entry without removing it, so the same dataset can be
        # looked up again if its name is repeated in ``dataset_names``.
        metric = list(metrics_for_one)[-1]
        results = metrics_for_one[metric]
    else:
        if metric not in metrics_for_one:
            raise KeyError("No given metric in eval results.")
        results = metrics_for_one[metric]

    num_iteration = len(results)
    max_result = max(results)
    min_result = min(results)
    x_ = range(num_iteration)
    ax.plot(x_, results, label=name)

    # Remaining datasets: plot the same metric and widen the observed value range.
    for name in dataset_names_iter:
        metrics_for_one = eval_results[name]
        results = metrics_for_one[metric]
        max_result = max(max(results), max_result)
        min_result = min(min(results), min_result)
        ax.plot(x_, results, label=name)

    ax.legend(loc="best")

    if xlim is not None:
        _check_not_tuple_of_2_elements(xlim, "xlim")
    else:
        xlim = (0, num_iteration)
    ax.set_xlim(xlim)

    if ylim is not None:
        _check_not_tuple_of_2_elements(ylim, "ylim")
    else:
        range_result = max_result - min_result
        ylim = (min_result - range_result * 0.2, max_result + range_result * 0.2)
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ylabel = ylabel.replace("@metric@", metric)
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax


426
427
428
429
430
431
432
433
434
def _determine_direction_for_numeric_split(
    fval: float,
    threshold: float,
    missing_type_str: str,
    default_left: bool,
) -> str:
    """Decide which child a feature value is routed to at a numeric split.

    Parameters
    ----------
    fval : float
        Feature value of the sample at the split feature. May be NaN.
    threshold : float
        Split threshold; values less than or equal to it go left.
    missing_type_str : str
        Missing-value handling mode of the node, converted with ``_MissingType(...)``.
    default_left : bool
        Whether values treated as missing are routed to the left child.

    Returns
    -------
    direction : str
        ``"left"`` or ``"right"``.
    """
    missing_type = _MissingType(missing_type_str)

    # Unless the node handles NaN as missing, NaN is compared as 0.0.
    if math.isnan(fval) and missing_type != _MissingType.NAN:
        fval = 0.0

    # NOTE(review): ``_is_zero`` is defined outside this module; presumably a
    # tolerance-based zero test, confirm against its definition.
    zero_is_missing = missing_type == _MissingType.ZERO and _is_zero(fval)
    nan_is_missing = missing_type == _MissingType.NAN and math.isnan(fval)
    if zero_is_missing or nan_is_missing:
        return "left" if default_left else "right"
    return "left" if fval <= threshold else "right"


def _determine_direction_for_categorical_split(fval: float, thresholds: str) -> str:
    """Decide which child a feature value is routed to at a categorical split.

    Parameters
    ----------
    fval : float
        Feature value of the sample at the split feature. It is truncated to an
        integer category before the lookup.
    thresholds : str
        Integer categories that go to the left child, joined with ``"||"``.

    Returns
    -------
    direction : str
        ``"left"`` if the category is listed in ``thresholds``; ``"right"`` otherwise,
        including for NaN and negative values.
    """
    if math.isnan(fval):
        return "right"
    category = int(fval)
    if category < 0:
        return "right"
    left_categories = {int(token) for token in thresholds.split("||")}
    return "left" if category in left_categories else "right"
449
450


451
452
453
454
def _to_graphviz(
    tree_info: Dict[str, Any],
    show_info: List[str],
    feature_names: Union[List[str], None],
    precision: Optional[int],
    orientation: str,
    constraints: Optional[List[int]],
    example_case: Optional[Union[np.ndarray, pd_DataFrame]],
    max_category_values: int,
    **kwargs: Any,
) -> Any:
    """Convert specified tree to graphviz instance.

    See:
      - https://graphviz.readthedocs.io/en/stable/api.html#digraph

    Parameters
    ----------
    tree_info : dict
        Description of one tree; its ``"tree_structure"`` entry holds the nested nodes.
    show_info : list of str
        Names of the optional node details to append to the labels.
    feature_names : list of str or None
        Feature names indexed by ``split_feature``.
        If None, split nodes are labelled with the feature index.
    precision : int or None
        Passed to ``_float2str()`` for thresholds, gains, values and weights.
    orientation : str
        ``"horizontal"`` lays the graph out left-to-right; any other value top-to-bottom.
    constraints : list of int or None
        Per-feature monotone constraints indexed by ``split_feature``.
        Nodes splitting on a feature with value 1 are filled light green, -1 light red.
        If non-empty, a legend node is added.
    example_case : indexable or None
        Feature values of a single sample, indexed by ``split_feature``.
        If not None, the path the sample takes through the tree is highlighted.
    max_category_values : int
        Categorical splits listing more categories than this are abbreviated in
        the label and the full list is put in the node tooltip.
    **kwargs
        Other parameters passed to ``Digraph`` constructor.

    Returns
    -------
    graph : graphviz.Digraph
        The digraph representation of the tree.

    Raises
    ------
    ImportError
        If graphviz is not installed.
    ValueError
        If a node has a decision type other than ``"<="`` or ``"=="``.
    Exception
        If the root of the tree has no ``"internal_count"`` entry (tree with no split).
    """
    if not GRAPHVIZ_INSTALLED:
        raise ImportError("You must install graphviz and restart your session to plot tree.")
    from graphviz import Digraph

    def add(
        root: Dict[str, Any], total_count: int, parent: Optional[str], decision: Optional[str], highlight: bool
    ) -> None:
        """Recursively add node or edge.

        Children are added before the node itself, then the edge from ``parent``.
        """
        fillcolor = "white"
        style = ""
        tooltip = None
        if highlight:
            color = "blue"
            penwidth = "3"
        else:
            color = "black"
            penwidth = "1"
        if "split_index" in root:  # non-leaf
            shape = "rectangle"
            l_dec = "yes"
            r_dec = "no"
            threshold = root["threshold"]
            decision_type = root["decision_type"]
            if decision_type == "<=":
                operator = "&#8804;"
            elif decision_type == "==":
                operator = "="
            else:
                raise ValueError("Invalid decision type in tree model.")
            is_categorical = decision_type == "=="
            name = f"split{root['split_index']}"
            split_feature = root["split_feature"]
            # NOTE(review): feature names are interpolated into an HTML-like label
            # without escaping; confirm names cannot contain '<', '>' or '&'.
            # NOTE(review): this branch has no trailing space after the operator,
            # unlike the index branch below; confirm whether that is intended.
            if feature_names is not None:
                label = f"<B>{feature_names[split_feature]}</B> {operator}"
            else:
                label = f"feature <B>{split_feature}</B> {operator} "
            # Which child the example sample goes to; None when there is no sample.
            direction = None
            if example_case is not None:
                if is_categorical:
                    direction = _determine_direction_for_categorical_split(
                        fval=example_case[split_feature], thresholds=root["threshold"]
                    )
                else:
                    direction = _determine_direction_for_numeric_split(
                        fval=example_case[split_feature],
                        threshold=root["threshold"],
                        missing_type_str=root["missing_type"],
                        default_left=root["default_left"],
                    )
            if is_categorical:
                category_values = root["threshold"].split("||")
                if len(category_values) > max_category_values:
                    # Too many categories for the label: show first two and last,
                    # keep the complete list in the tooltip.
                    tooltip = root["threshold"]
                    threshold = "||".join(category_values[:2]) + "||...||" + category_values[-1]

            label += f"<B>{_float2str(threshold, precision)}</B>"
            for info in ["split_gain", "internal_value", "internal_weight", "internal_count", "data_percentage"]:
                if info in show_info:
                    output = info.split("_")[-1]
                    if info in {"split_gain", "internal_value", "internal_weight"}:
                        label += f"<br/>{_float2str(root[info], precision)} {output}"
                    elif info == "internal_count":
                        label += f"<br/>{output}: {root[info]}"
                    elif info == "data_percentage":
                        label += f"<br/>{_float2str(root['internal_count'] / total_count * 100, 2)}% of data"

            if constraints:
                constraint = constraints[root["split_feature"]]
                if constraint == 1:
                    fillcolor = "#ddffdd"  # light green
                elif constraint == -1:
                    fillcolor = "#ffdddd"  # light red
                style = "filled"
            label = f"<{label}>"
            add(
                root=root["left_child"],
                total_count=total_count,
                parent=name,
                decision=l_dec,
                highlight=highlight and direction == "left",
            )
            add(
                root=root["right_child"],
                total_count=total_count,
                parent=name,
                decision=r_dec,
                highlight=highlight and direction == "right",
            )
        else:  # leaf
            shape = "ellipse"
            name = f"leaf{root['leaf_index']}"
            label = f"leaf {root['leaf_index']}: "
            label += f"<B>{_float2str(root['leaf_value'], precision)}</B>"
            if "leaf_weight" in show_info:
                label += f"<br/>{_float2str(root['leaf_weight'], precision)} weight"
            if "leaf_count" in show_info:
                label += f"<br/>count: {root['leaf_count']}"
            if "data_percentage" in show_info:
                label += f"<br/>{_float2str(root['leaf_count'] / total_count * 100, 2)}% of data"
            label = f"<{label}>"
        graph.node(
            name,
            label=label,
            shape=shape,
            style=style,
            fillcolor=fillcolor,
            color=color,
            penwidth=penwidth,
            tooltip=tooltip,
        )
        if parent is not None:
            graph.edge(parent, name, decision, color=color, penwidth=penwidth)

    graph = Digraph(**kwargs)
    rankdir = "LR" if orientation == "horizontal" else "TB"
    graph.attr("graph", nodesep="0.05", ranksep="0.3", rankdir=rankdir)
    if "internal_count" in tree_info["tree_structure"]:
        add(
            root=tree_info["tree_structure"],
            total_count=tree_info["tree_structure"]["internal_count"],
            parent=None,
            decision=None,
            highlight=example_case is not None,
        )
    else:
        raise Exception("Cannot plot trees with no split")

    if constraints:
        # "#ddffdd" is light green, "#ffdddd" is light red
        legend = """<
            <TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="4">
             <TR>
              <TD COLSPAN="2"><B>Monotone constraints</B></TD>
             </TR>
             <TR>
              <TD>Increasing</TD>
              <TD BGCOLOR="#ddffdd"></TD>
             </TR>
             <TR>
              <TD>Decreasing</TD>
              <TD BGCOLOR="#ffdddd"></TD>
             </TR>
            </TABLE>
           >"""
        graph.node("legend", label=legend, shape="rectangle", color="white")
    return graph


613
614
615
616
617
def create_tree_digraph(
    booster: Union[Booster, LGBMModel],
    tree_index: int = 0,
    show_info: Optional[List[str]] = None,
    precision: Optional[int] = 3,
    orientation: str = "horizontal",
    example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
    max_category_values: int = 10,
    **kwargs: Any,
) -> Any:
    """Create a digraph representation of specified tree.

    Each node in the graph represents a node in the tree.

    Non-leaf nodes have labels like ``Column_10 <= 875.9``, which means
    "this node splits on the feature named "Column_10", with threshold 875.9".

    Leaf nodes have labels like ``leaf 2: 0.422``, which means "this node is a
    leaf node, and the predicted value for records that fall into this node
    is 0.422". The number (``2``) is an internal unique identifier and doesn't
    have any special meaning.

    .. note::

        For more information please visit
        https://graphviz.readthedocs.io/en/stable/api.html#digraph.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance to be converted.
    tree_index : int, optional (default=0)
        The index of a target tree to convert.
    show_info : list of str, or None, optional (default=None)
        What information should be shown in nodes.

            - ``'split_gain'`` : gain from adding this split to the model
            - ``'internal_value'`` : raw predicted value that would be produced by this node if it was a leaf node
            - ``'internal_count'`` : number of records from the training data that fall into this non-leaf node
            - ``'internal_weight'`` : total weight of all nodes that fall into this non-leaf node
            - ``'leaf_count'`` : number of records from the training data that fall into this leaf node
            - ``'leaf_weight'`` : total weight (sum of Hessian) of all observations that fall into this leaf node
            - ``'data_percentage'`` : percentage of training data that fall into this node
    precision : int or None, optional (default=3)
        Used to restrict the display of floating point values to a certain precision.
    orientation : str, optional (default='horizontal')
        Orientation of the tree.
        Can be 'horizontal' or 'vertical'.
    example_case : numpy 2-D array, pandas DataFrame or None, optional (default=None)
        Single row with the same structure as the training data.
        If not None, the plot will highlight the path that sample takes through the tree.

        .. versionadded:: 4.0.0

    max_category_values : int, optional (default=10)
        The maximum number of category values to display in tree nodes, if the number of thresholds is greater than this value, thresholds will be collapsed and displayed on the label tooltip instead.

        .. warning::

            Consider wrapping the SVG string of the tree graph with ``IPython.display.HTML`` when running on JupyterLab to get the `tooltip <https://graphviz.org/docs/attrs/tooltip>`_ working right.

            Example:

            .. code-block:: python

                from IPython.display import HTML

                graph = lgb.create_tree_digraph(clf, max_category_values=5)
                HTML(graph._repr_image_svg_xml())

        .. versionadded:: 4.0.0

    **kwargs
        Other parameters passed to ``Digraph`` constructor.
        Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.

    Returns
    -------
    graph : graphviz.Digraph
        The digraph representation of specified tree.

    Raises
    ------
    TypeError
        If ``booster`` is neither Booster nor LGBMModel.
    IndexError
        If ``tree_index`` is not smaller than the number of trees in the model.
    ValueError
        If ``example_case`` is not a 2-D numpy array or pandas DataFrame with exactly one row.
    """
    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError("booster must be Booster or LGBMModel.")

    model = booster.dump_model()
    tree_infos = model["tree_info"]
    feature_names = model.get("feature_names", None)
    monotone_constraints = model.get("monotone_constraints", None)

    if tree_index >= len(tree_infos):
        raise IndexError("tree_index is out of range.")
    tree_info = tree_infos[tree_index]

    if show_info is None:
        show_info = []

    if example_case is not None:
        if not isinstance(example_case, (np.ndarray, pd_DataFrame)) or example_case.ndim != 2:
            raise ValueError("example_case must be a numpy 2-D array or a pandas DataFrame")
        if example_case.shape[0] != 1:
            raise ValueError("example_case must have a single row.")
        if isinstance(example_case, pd_DataFrame):
            # Convert the frame with the model's stored pandas categories; only the
            # first element of the helper's result (the data) is needed here.
            example_case = _data_from_pandas(
                data=example_case,
                feature_name="auto",
                categorical_feature="auto",
                pandas_categorical=booster.pandas_categorical,
            )[0]
        # Reduce the single-row 2-D input to that one row.
        example_case = example_case[0]

    return _to_graphviz(
        tree_info=tree_info,
        show_info=show_info,
        feature_names=feature_names,
        precision=precision,
        orientation=orientation,
        constraints=monotone_constraints,
        example_case=example_case,
        max_category_values=max_category_values,
        **kwargs,
    )
737

wxchan's avatar
wxchan committed
738

739
740
def plot_tree(
    booster: Union[Booster, LGBMModel],
    ax: "Optional[matplotlib.axes.Axes]" = None,
    tree_index: int = 0,
    figsize: Optional[Tuple[float, float]] = None,
    dpi: Optional[int] = None,
    show_info: Optional[List[str]] = None,
    precision: Optional[int] = 3,
    orientation: str = "horizontal",
    example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
    **kwargs: Any,
) -> Any:
    """Draw a single tree of the model onto a matplotlib axes.

    The tree is first turned into a graphviz digraph via ``create_tree_digraph()``,
    rendered to a PNG image and then displayed on the target axes.

    Every node of the graph corresponds to one node of the tree.

    A non-leaf node is labelled like ``Column_10 <= 875.9``, meaning
    "this node splits on the feature named "Column_10", with threshold 875.9".

    A leaf node is labelled like ``leaf 2: 0.422``, meaning "this node is a
    leaf node, and the predicted value for records that fall into this node
    is 0.422". The number (``2``) is only an internal unique identifier and
    carries no special meaning.

    .. note::

        Prefer ``create_tree_digraph()`` where possible: its output is lossless
        and the returned object can be rendered and displayed directly inside a Jupyter notebook.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance whose tree should be plotted.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Axes to draw on.
        If None, a new figure and axes are created.
    tree_index : int, optional (default=0)
        Index of the tree to plot.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
        Only validated and used when ``ax`` is None.
    dpi : int or None, optional (default=None)
        Resolution of the figure.
        Only used when ``ax`` is None.
    show_info : list of str, or None, optional (default=None)
        Which pieces of information to display in the nodes.

            - ``'split_gain'`` : gain from adding this split to the model
            - ``'internal_value'`` : raw predicted value that would be produced by this node if it was a leaf node
            - ``'internal_count'`` : number of records from the training data that fall into this non-leaf node
            - ``'internal_weight'`` : total weight of all nodes that fall into this non-leaf node
            - ``'leaf_count'`` : number of records from the training data that fall into this leaf node
            - ``'leaf_weight'`` : total weight (sum of Hessian) of all observations that fall into this leaf node
            - ``'data_percentage'`` : percentage of training data that fall into this node
    precision : int or None, optional (default=3)
        Number of decimal places used when displaying floating point values.
    orientation : str, optional (default='horizontal')
        Orientation of the tree.
        Can be 'horizontal' or 'vertical'.
    example_case : numpy 2-D array, pandas DataFrame or None, optional (default=None)
        Single row with the same structure as the training data.
        If not None, the path this sample takes through the tree is highlighted in the plot.

        .. versionadded:: 4.0.0

    **kwargs
        Other parameters, forwarded through ``create_tree_digraph()`` to the ``Digraph`` constructor.
        Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with single tree.
    """
    # Fail early with a helpful message when the optional plotting dependency is missing.
    if not MATPLOTLIB_INSTALLED:
        raise ImportError("You must install matplotlib and restart your session to plot tree.")
    import matplotlib.image
    import matplotlib.pyplot as plt

    if ax is None:
        # figsize only matters when this function creates the figure itself.
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, "figsize")
        _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    digraph = create_tree_digraph(
        booster=booster,
        tree_index=tree_index,
        show_info=show_info,
        precision=precision,
        orientation=orientation,
        example_case=example_case,
        **kwargs,
    )

    # Render the graph to PNG in memory and hand matplotlib a readable stream
    # that is positioned at the start of the image data.
    png_stream = BytesIO(digraph.pipe(format="png"))
    tree_image = matplotlib.image.imread(png_stream)

    ax.imshow(tree_image)
    ax.axis("off")
    return ax