# coding: utf-8
# pylint: skip-file
import unittest

import lightgbm as lgb
from lightgbm.compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# matplotlib and graphviz are optional dependencies of LightGBM; the tests
# that need them are skipped (via the *_INSTALLED flags) when they are absent.
if MATPLOTLIB_INSTALLED:
    import matplotlib
    # Select the non-interactive Agg backend up front so plotting works on
    # machines without a display (e.g. CI).
    matplotlib.use('Agg')

if GRAPHVIZ_INSTALLED:
    import graphviz

class TestBasic(unittest.TestCase):
    """Tests for LightGBM's plotting helpers.

    Covers ``lgb.plot_importance``, ``lgb.plot_split_value_histogram``,
    ``lgb.plot_tree``, ``lgb.create_tree_digraph`` and ``lgb.plot_metric``,
    using small binary classifiers trained on scikit-learn's breast-cancer
    data. Tests that need matplotlib and/or graphviz are skipped when those
    optional packages are not installed.
    """

    def setUp(self):
        """Build a fresh 90/10 train/test split, training Dataset and params."""
        # ``return_X_y`` is passed by keyword: newer scikit-learn releases make
        # the loader arguments keyword-only, so the old positional
        # ``load_breast_cancer(True)`` form warns or fails there.
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            *load_breast_cancer(return_X_y=True), test_size=0.1, random_state=1)
        self.train_data = lgb.Dataset(self.X_train, self.y_train)
        # Rebuilt before every test, so a test may mutate it (test_plot_metrics does).
        self.params = {
            "objective": "binary",
            "verbose": -1,
            "num_leaves": 3
        }

    def tearDown(self):
        """Close any matplotlib figures the test opened."""
        # The plotting helpers are called without an explicit ``ax`` here, so each
        # call presumably creates its own figure; close them so they do not pile
        # up for the lifetime of the test process.
        if MATPLOTLIB_INSTALLED:
            import matplotlib.pyplot as plt
            plt.close('all')

    @unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
    def test_plot_importance(self):
        # Default labels on a model from the low-level ``lgb.train`` API.
        gbm0 = lgb.train(self.params, self.train_data, num_boost_round=10)
        ax0 = lgb.plot_importance(gbm0)
        self.assertIsInstance(ax0, matplotlib.axes.Axes)
        self.assertEqual(ax0.get_title(), 'Feature importance')
        self.assertEqual(ax0.get_xlabel(), 'Feature importance')
        self.assertEqual(ax0.get_ylabel(), 'Features')
        self.assertLessEqual(len(ax0.patches), 30)

        # Custom labels and a single colour on a scikit-learn-API model.
        gbm1 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
        gbm1.fit(self.X_train, self.y_train)

        ax1 = lgb.plot_importance(gbm1, color='r', title='t', xlabel='x', ylabel='y')
        self.assertIsInstance(ax1, matplotlib.axes.Axes)
        self.assertEqual(ax1.get_title(), 't')
        self.assertEqual(ax1.get_xlabel(), 'x')
        self.assertEqual(ax1.get_ylabel(), 'y')
        self.assertLessEqual(len(ax1.patches), 30)
        for patch in ax1.patches:
            self.assertTupleEqual(patch.get_facecolor(), (1., 0, 0, 1.))  # red

        # ``None`` labels are rendered empty; a colour list is applied bar by bar.
        ax2 = lgb.plot_importance(gbm0, color=['r', 'y', 'g', 'b'],
                                  title=None, xlabel=None, ylabel=None)
        self.assertIsInstance(ax2, matplotlib.axes.Axes)
        self.assertEqual(ax2.get_title(), '')
        self.assertEqual(ax2.get_xlabel(), '')
        self.assertEqual(ax2.get_ylabel(), '')
        self.assertLessEqual(len(ax2.patches), 30)
        self.assertTupleEqual(ax2.patches[0].get_facecolor(), (1., 0, 0, 1.))  # r
        self.assertTupleEqual(ax2.patches[1].get_facecolor(), (.75, .75, 0, 1.))  # y
        self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.))  # g
        self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.))  # b

    @unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
    def test_plot_split_value_histogram(self):
        # Feature selected by index, default labels.
        gbm0 = lgb.train(self.params, self.train_data, num_boost_round=10)
        ax0 = lgb.plot_split_value_histogram(gbm0, 27)
        self.assertIsInstance(ax0, matplotlib.axes.Axes)
        self.assertEqual(ax0.get_title(), 'Split value histogram for feature with index 27')
        self.assertEqual(ax0.get_xlabel(), 'Feature split value')
        self.assertEqual(ax0.get_ylabel(), 'Count')
        self.assertLessEqual(len(ax0.patches), 2)

        # Feature selected by name; the title placeholders are substituted.
        gbm1 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
        gbm1.fit(self.X_train, self.y_train)

        ax1 = lgb.plot_split_value_histogram(gbm1, gbm1.booster_.feature_name()[27], figsize=(10, 5),
                                             title='Histogram for feature @index/name@ @feature@',
                                             xlabel='x', ylabel='y', color='r')
        self.assertIsInstance(ax1, matplotlib.axes.Axes)
        self.assertEqual(ax1.get_title(),
                         'Histogram for feature name {}'.format(gbm1.booster_.feature_name()[27]))
        self.assertEqual(ax1.get_xlabel(), 'x')
        self.assertEqual(ax1.get_ylabel(), 'y')
        self.assertLessEqual(len(ax1.patches), 2)
        for patch in ax1.patches:
            self.assertTupleEqual(patch.get_facecolor(), (1., 0, 0, 1.))  # red

        # Explicit bin count, empty labels and a per-bar colour list.
        ax2 = lgb.plot_split_value_histogram(gbm0, 27, bins=10, color=['r', 'y', 'g', 'b'],
                                             title=None, xlabel=None, ylabel=None)
        self.assertIsInstance(ax2, matplotlib.axes.Axes)
        self.assertEqual(ax2.get_title(), '')
        self.assertEqual(ax2.get_xlabel(), '')
        self.assertEqual(ax2.get_ylabel(), '')
        self.assertEqual(len(ax2.patches), 10)
        self.assertTupleEqual(ax2.patches[0].get_facecolor(), (1., 0, 0, 1.))  # r
        self.assertTupleEqual(ax2.patches[1].get_facecolor(), (.75, .75, 0, 1.))  # y
        self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.))  # g
        self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.))  # b

        self.assertRaises(ValueError, lgb.plot_split_value_histogram, gbm0, 0)  # was not used in splitting

    @unittest.skipIf(not MATPLOTLIB_INSTALLED or not GRAPHVIZ_INSTALLED, 'matplotlib or graphviz is not installed')
    def test_plot_tree(self):
        gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
        gbm.fit(self.X_train, self.y_train, verbose=False)

        # Only 10 trees were trained, so index 83 is out of range.
        self.assertRaises(IndexError, lgb.plot_tree, gbm, tree_index=83)

        ax = lgb.plot_tree(gbm, tree_index=3, figsize=(15, 8), show_info=['split_gain'])
        self.assertIsInstance(ax, matplotlib.axes.Axes)
        w, h = ax.axes.get_figure().get_size_inches()
        self.assertEqual(int(w), 15)
        self.assertEqual(int(h), 8)

    @unittest.skipIf(not GRAPHVIZ_INSTALLED, 'graphviz is not installed')
    def test_create_tree_digraph(self):
        import shutil
        import tempfile

        # Alternate decreasing (-1) / increasing (+1) monotone constraints over the
        # features; the '#ffdddd' / '#ddffdd' assertions below presumably rely on
        # nodes of both constraint kinds being highlighted.
        constraints = [-1, 1] * (self.X_train.shape[1] // 2)
        gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True, monotone_constraints=constraints)
        gbm.fit(self.X_train, self.y_train, verbose=False)

        # Only 10 trees were trained, so index 83 is out of range.
        self.assertRaises(IndexError, lgb.create_tree_digraph, gbm, tree_index=83)

        graph = lgb.create_tree_digraph(gbm, tree_index=3,
                                        show_info=['split_gain', 'internal_value', 'internal_weight'],
                                        name='Tree4', node_attr={'color': 'red'})
        # Render into a throw-away directory rather than the current working
        # directory, so the test does not leave the Tree4.gv source file and its
        # rendered output behind.
        render_dir = tempfile.mkdtemp()
        try:
            graph.render(directory=render_dir, view=False)
        finally:
            shutil.rmtree(render_dir, ignore_errors=True)
        self.assertIsInstance(graph, graphviz.Digraph)
        self.assertEqual(graph.name, 'Tree4')
        self.assertEqual(graph.filename, 'Tree4.gv')
        self.assertEqual(len(graph.node_attr), 1)
        self.assertEqual(graph.node_attr['color'], 'red')
        self.assertEqual(len(graph.graph_attr), 0)
        self.assertEqual(len(graph.edge_attr), 0)
        graph_body = ''.join(graph.body)
        self.assertIn('leaf', graph_body)
        self.assertIn('gain', graph_body)
        self.assertIn('value', graph_body)
        self.assertIn('weight', graph_body)
        self.assertIn('#ffdddd', graph_body)
        self.assertIn('#ddffdd', graph_body)
        # Info that was not requested via ``show_info`` must not appear.
        self.assertNotIn('data', graph_body)
        self.assertNotIn('count', graph_body)

    @unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
    def test_plot_metrics(self):
        test_data = lgb.Dataset(self.X_test, self.y_test, reference=self.train_data)
        # Safe to mutate: setUp rebuilds ``self.params`` for every test.
        self.params.update({"metric": {"binary_logloss", "binary_error"}})

        # Two metrics recorded on two named validation sets.
        evals_result0 = {}
        lgb.train(self.params, self.train_data,
                  valid_sets=[self.train_data, test_data],
                  valid_names=['v1', 'v2'],
                  num_boost_round=10,
                  evals_result=evals_result0,
                  verbose_eval=False)
        ax0 = lgb.plot_metric(evals_result0)
        self.assertIsInstance(ax0, matplotlib.axes.Axes)
        self.assertEqual(ax0.get_title(), 'Metric during training')
        self.assertEqual(ax0.get_xlabel(), 'Iterations')
        self.assertIn(ax0.get_ylabel(), {'binary_logloss', 'binary_error'})

        # Explicitly chosen metric (and dataset subset) is used as the y label.
        ax0 = lgb.plot_metric(evals_result0, metric='binary_error')
        self.assertIsInstance(ax0, matplotlib.axes.Axes)
        self.assertEqual(ax0.get_ylabel(), 'binary_error')
        ax0 = lgb.plot_metric(evals_result0, metric='binary_logloss', dataset_names=['v2'])
        self.assertIsInstance(ax0, matplotlib.axes.Axes)
        self.assertEqual(ax0.get_ylabel(), 'binary_logloss')

        # Training without validation sets records nothing, which plot_metric rejects.
        evals_result1 = {}
        lgb.train(self.params, self.train_data,
                  num_boost_round=10,
                  evals_result=evals_result1,
                  verbose_eval=False)
        self.assertRaises(ValueError, lgb.plot_metric, evals_result1)

        # A fitted scikit-learn-API model can be passed directly.
        gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
        gbm2.fit(self.X_train, self.y_train, eval_set=[(self.X_test, self.y_test)], verbose=False)
        ax2 = lgb.plot_metric(gbm2, title=None, xlabel=None, ylabel=None)
        self.assertIsInstance(ax2, matplotlib.axes.Axes)
        self.assertEqual(ax2.get_title(), '')
        self.assertEqual(ax2.get_xlabel(), '')
        self.assertEqual(ax2.get_ylabel(), '')