# test_plotting.py
# coding: utf-8
# pylint: skip-file
import unittest

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

try:
    import matplotlib
    # Non-interactive backend so the plotting tests can run without a display.
    matplotlib.use('Agg')
except ImportError:
    # Plotting tests are skipped when matplotlib is unavailable.
    matplotlib_installed = False
else:
    matplotlib_installed = True


class TestBasic(unittest.TestCase):
    """Smoke tests for LightGBM's matplotlib-based plotting helpers."""

    @staticmethod
    def _split_breast_cancer():
        """Return ``X_train, X_test, y_train, y_test`` for the breast-cancer data.

        ``return_X_y`` is passed by keyword: newer scikit-learn releases make
        it keyword-only, so the positional form ``load_breast_cancer(True)``
        raises a TypeError there.
        """
        X, y = load_breast_cancer(return_X_y=True)
        return train_test_split(X, y, test_size=0.1, random_state=1)

    def tearDown(self):
        # Every plot_* call below opens a new figure; close them so figures do
        # not pile up across tests (matplotlib warns once more than 20 are open).
        if matplotlib_installed:
            import matplotlib.pyplot as plt
            plt.close('all')

    @unittest.skipIf(not matplotlib_installed, 'matplotlib not installed')
    def test_plot_importance(self):
        """plot_importance works for a Booster, an LGBMClassifier and a raw importance array."""
        X_train, _, y_train, _ = self._split_breast_cancer()
        train_data = lgb.Dataset(X_train, y_train)

        params = {
            "objective": "binary",
            "verbose": -1,
            "num_leaves": 3
        }
        gbm0 = lgb.train(params, train_data, num_boost_round=10)

        # Default labels on a Booster.
        ax0 = lgb.plot_importance(gbm0)
        self.assertIsInstance(ax0, matplotlib.axes.Axes)
        self.assertEqual(ax0.get_title(), 'Feature importance')
        self.assertEqual(ax0.get_xlabel(), 'Feature importance')
        self.assertEqual(ax0.get_ylabel(), 'Features')
        # The breast-cancer data has 30 features, so at most 30 bars.
        self.assertLessEqual(len(ax0.patches), 30)

        # Custom labels and a single bar colour on the sklearn wrapper.
        gbm1 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
        gbm1.fit(X_train, y_train)

        ax1 = lgb.plot_importance(gbm1, color='r', title='t', xlabel='x', ylabel='y')
        self.assertIsInstance(ax1, matplotlib.axes.Axes)
        self.assertEqual(ax1.get_title(), 't')
        self.assertEqual(ax1.get_xlabel(), 'x')
        self.assertEqual(ax1.get_ylabel(), 'y')
        self.assertLessEqual(len(ax1.patches), 30)
        for patch in ax1.patches:
            self.assertTupleEqual(patch.get_facecolor(), (1., 0, 0, 1.))  # red

        # Raw importance array, labels disabled, per-bar colour cycle.
        ax2 = lgb.plot_importance(gbm0.feature_importance(),
                                  color=['r', 'y', 'g', 'b'],
                                  title=None, xlabel=None, ylabel=None)
        self.assertIsInstance(ax2, matplotlib.axes.Axes)
        self.assertEqual(ax2.get_title(), '')
        self.assertEqual(ax2.get_xlabel(), '')
        self.assertEqual(ax2.get_ylabel(), '')
        self.assertLessEqual(len(ax2.patches), 30)
        self.assertTupleEqual(ax2.patches[0].get_facecolor(), (1., 0, 0, 1.))  # r
        self.assertTupleEqual(ax2.patches[1].get_facecolor(), (.75, .75, 0, 1.))  # y
        self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.))  # g
        self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.))  # b

    @unittest.skip('Graphviz are not executables on Travis')
    def test_plot_tree(self):
        pass

    @unittest.skipIf(not matplotlib_installed, 'matplotlib not installed')
    def test_plot_metrics(self):
        """plot_metric works on recorded eval results and on a fitted LGBMClassifier."""
        X_train, X_test, y_train, y_test = self._split_breast_cancer()
        train_data = lgb.Dataset(X_train, y_train)
        test_data = lgb.Dataset(X_test, y_test, reference=train_data)

        params = {
            "objective": "binary",
            "metric": {"binary_logloss", "binary_error"},
            "verbose": -1,
            "num_leaves": 3
        }

        evals_result0 = {}
        lgb.train(params, train_data,
                  valid_sets=[train_data, test_data],
                  valid_names=['v1', 'v2'],
                  num_boost_round=10,
                  evals_result=evals_result0,
                  verbose_eval=False)
        ax0 = lgb.plot_metric(evals_result0)
        self.assertIsInstance(ax0, matplotlib.axes.Axes)
        self.assertEqual(ax0.get_title(), 'Metric during training')
        self.assertEqual(ax0.get_xlabel(), 'Iterations')
        # With no metric given, one of the recorded metrics is picked.
        self.assertIn(ax0.get_ylabel(), {'binary_logloss', 'binary_error'})

        # An explicitly selected metric labels the y axis.
        ax1 = lgb.plot_metric(evals_result0, metric='binary_error')
        self.assertIsInstance(ax1, matplotlib.axes.Axes)
        self.assertEqual(ax1.get_ylabel(), 'binary_error')

        # Restricting to a subset of datasets still yields an Axes.
        ax1 = lgb.plot_metric(evals_result0, metric='binary_logloss', dataset_names=['v2'])
        self.assertIsInstance(ax1, matplotlib.axes.Axes)
        self.assertEqual(ax1.get_ylabel(), 'binary_logloss')

        # Without validation sets nothing is recorded, so plotting must fail.
        evals_result1 = {}
        lgb.train(params, train_data,
                  num_boost_round=10,
                  evals_result=evals_result1,
                  verbose_eval=False)
        self.assertRaises(ValueError, lgb.plot_metric, evals_result1)

        gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
        gbm2.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=False)
        ax2 = lgb.plot_metric(gbm2, title=None, xlabel=None, ylabel=None)
        self.assertIsInstance(ax2, matplotlib.axes.Axes)
        self.assertEqual(ax2.get_title(), '')
        self.assertEqual(ax2.get_xlabel(), '')
        self.assertEqual(ax2.get_ylabel(), '')

# Only run the suite when executed as a script. unittest.main() calls
# sys.exit(), so running it at import time would abort any importer
# (including pytest's test collection).
if __name__ == "__main__":
    print("----------------------------------------------------------------------")
    print("running test_plotting.py")
    unittest.main()