test_plotting.py 4.47 KB
Newer Older
1
2
3
4
5
6
7
8
9
# coding: utf-8
# pylint: skip-file
import unittest

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Probe for matplotlib. When it is available, switch to the non-interactive
# Agg backend so the plotting tests can run on machines without a display.
try:
    import matplotlib
    matplotlib.use('Agg')
except ImportError:
    matplotlib_installed = False
else:
    matplotlib_installed = True


class TestBasic(unittest.TestCase):
    """Smoke tests for LightGBM's matplotlib-based plotting helpers."""

    def tearDown(self):
        # The plot_* helpers are called below without an explicit ``ax``, so
        # they presumably open a fresh figure each time; close them so figures
        # do not accumulate across tests.
        if matplotlib_installed:
            import matplotlib.pyplot as plt
            plt.close('all')

    @unittest.skipIf(not matplotlib_installed, 'matplotlib is not installed')
    def test_plot_importance(self):
        """Check default and customized output of ``lgb.plot_importance``."""
        # ``return_X_y`` is passed by keyword: newer scikit-learn releases make
        # the arguments of ``load_breast_cancer`` keyword-only.
        X_train, _, y_train, _ = train_test_split(*load_breast_cancer(return_X_y=True),
                                                  test_size=0.1, random_state=1)
        train_data = lgb.Dataset(X_train, y_train)

        params = {
            "objective": "binary",
            "verbose": -1,
            "num_leaves": 3
        }
        gbm0 = lgb.train(params, train_data, num_boost_round=10)

        # Default title and axis labels on a Booster.
        ax0 = lgb.plot_importance(gbm0)
        self.assertIsInstance(ax0, matplotlib.axes.Axes)
        self.assertEqual(ax0.get_title(), 'Feature importance')
        self.assertEqual(ax0.get_xlabel(), 'Feature importance')
        self.assertEqual(ax0.get_ylabel(), 'Features')
        # At most one bar per feature (the breast-cancer data has 30 features).
        self.assertLessEqual(len(ax0.patches), 30)

        # Custom color, title and labels on a scikit-learn style model.
        gbm1 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
        gbm1.fit(X_train, y_train)

        ax1 = lgb.plot_importance(gbm1, color='r', title='t', xlabel='x', ylabel='y')
        self.assertIsInstance(ax1, matplotlib.axes.Axes)
        self.assertEqual(ax1.get_title(), 't')
        self.assertEqual(ax1.get_xlabel(), 'x')
        self.assertEqual(ax1.get_ylabel(), 'y')
        self.assertLessEqual(len(ax1.patches), 30)
        for patch in ax1.patches:
            self.assertTupleEqual(patch.get_facecolor(), (1., 0, 0, 1.))  # red

        # A list of colors is applied bar by bar; None disables title/labels.
        ax2 = lgb.plot_importance(gbm0, color=['r', 'y', 'g', 'b'],
                                  title=None, xlabel=None, ylabel=None)
        self.assertIsInstance(ax2, matplotlib.axes.Axes)
        self.assertEqual(ax2.get_title(), '')
        self.assertEqual(ax2.get_xlabel(), '')
        self.assertEqual(ax2.get_ylabel(), '')
        self.assertLessEqual(len(ax2.patches), 30)
        self.assertTupleEqual(ax2.patches[0].get_facecolor(), (1., 0, 0, 1.))  # r
        self.assertTupleEqual(ax2.patches[1].get_facecolor(), (.75, .75, 0, 1.))  # y
        self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.))  # g
        self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.))  # b

    @unittest.skip('Graphviz are not executables on Travis')
    def test_plot_tree(self):
        """Placeholder: tree plotting needs Graphviz executables, unavailable on CI."""
        pass

    @unittest.skipIf(not matplotlib_installed, 'matplotlib is not installed')
    def test_plot_metrics(self):
        """Check ``lgb.plot_metric`` on an evals_result dict and on a fitted model."""
        X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True),
                                                            test_size=0.1, random_state=1)
        train_data = lgb.Dataset(X_train, y_train)
        test_data = lgb.Dataset(X_test, y_test, reference=train_data)

        params = {
            "objective": "binary",
            "metric": {"binary_logloss", "binary_error"},
            "verbose": -1,
            "num_leaves": 3
        }

        # Record both metrics on two named validation sets.
        evals_result0 = {}
        lgb.train(params, train_data,
                  valid_sets=[train_data, test_data],
                  valid_names=['v1', 'v2'],
                  num_boost_round=10,
                  evals_result=evals_result0,
                  verbose_eval=False)
        ax0 = lgb.plot_metric(evals_result0)
        self.assertIsInstance(ax0, matplotlib.axes.Axes)
        self.assertEqual(ax0.get_title(), 'Metric during training')
        self.assertEqual(ax0.get_xlabel(), 'Iterations')
        # Without an explicit metric, the y-label is one of the recorded metric names.
        self.assertIn(ax0.get_ylabel(), {'binary_logloss', 'binary_error'})

        # Explicitly selected metrics (optionally restricted to one dataset)
        # must also produce an Axes labelled with that metric.
        ax0 = lgb.plot_metric(evals_result0, metric='binary_error')
        self.assertIsInstance(ax0, matplotlib.axes.Axes)
        self.assertEqual(ax0.get_ylabel(), 'binary_error')
        ax0 = lgb.plot_metric(evals_result0, metric='binary_logloss', dataset_names=['v2'])
        self.assertIsInstance(ax0, matplotlib.axes.Axes)
        self.assertEqual(ax0.get_ylabel(), 'binary_logloss')

        # No valid_sets are given, so nothing is expected to be recorded in
        # evals_result1 and plot_metric must reject it.
        evals_result1 = {}
        lgb.train(params, train_data,
                  num_boost_round=10,
                  evals_result=evals_result1,
                  verbose_eval=False)
        self.assertRaises(ValueError, lgb.plot_metric, evals_result1)

        # A fitted scikit-learn style model can be plotted directly;
        # None disables the title and axis labels.
        gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
        gbm2.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=False)
        ax2 = lgb.plot_metric(gbm2, title=None, xlabel=None, ylabel=None)
        self.assertIsInstance(ax2, matplotlib.axes.Axes)
        self.assertEqual(ax2.get_title(), '')
        self.assertEqual(ax2.get_xlabel(), '')
        self.assertEqual(ax2.get_ylabel(), '')