test_plotting.py 2.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# coding: utf-8
# pylint: skip-file
import unittest

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

try:
    from matplotlib.axes import Axes
    MATPLOTLIB_INSTALLED = True
except ImportError:
    MATPLOTLIB_INSTALLED = False


class TestBasic(unittest.TestCase):
    """Tests for LightGBM's plotting helpers (currently ``lgb.plot_importance``)."""

    def _assert_axes_labels(self, ax, title, xlabel, ylabel):
        """Assert that ``ax`` is a matplotlib ``Axes`` with the given title and axis labels.

        Passing ``''`` for a label checks that the label was left empty.
        """
        self.assertIsInstance(ax, Axes)
        self.assertEqual(ax.get_title(), title)
        self.assertEqual(ax.get_xlabel(), xlabel)
        self.assertEqual(ax.get_ylabel(), ylabel)

    @unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib not installed')
    def test_plot_importance(self):
        """Plot feature importance from a Booster, an LGBMClassifier and a raw importance array.

        Checks the returned Axes, its default / custom / empty labels, the
        upper bound on the number of bars, and the bar colours.
        """
        # Imported locally because matplotlib is optional (see MATPLOTLIB_INSTALLED).
        import matplotlib.pyplot as plt
        # plot_importance opens a new figure when no ``ax`` is given; close them
        # when the test finishes so figures do not accumulate across runs.
        self.addCleanup(plt.close, 'all')

        # ``return_X_y`` must be passed by keyword in recent scikit-learn releases.
        X, y = load_breast_cancer(return_X_y=True)
        X_train, _, y_train, _ = train_test_split(X, y, test_size=0.1, random_state=1)
        # At most one bar can be drawn per input feature.
        n_features = X.shape[1]

        params = {
            "objective": "binary",
            "verbose": -1,
            "num_leaves": 3
        }
        gbm0 = lgb.train(params, lgb.Dataset(X_train, y_train), num_boost_round=10)

        # Booster input with the default labels.
        ax0 = lgb.plot_importance(gbm0)
        self._assert_axes_labels(ax0, 'Feature importance', 'Feature importance', 'Features')
        self.assertLessEqual(len(ax0.patches), n_features)

        # scikit-learn estimator input with a single colour and custom labels.
        # ``silent`` was removed from LightGBM's scikit-learn API in 4.0;
        # ``verbose=-1`` silences logging, matching ``params`` above.
        gbm1 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, verbose=-1)
        gbm1.fit(X_train, y_train)

        ax1 = lgb.plot_importance(gbm1, color='r', title='t', xlabel='x', ylabel='y')
        self._assert_axes_labels(ax1, 't', 'x', 'y')
        self.assertLessEqual(len(ax1.patches), n_features)
        for patch in ax1.patches:
            self.assertTupleEqual(patch.get_facecolor(), (1., 0, 0, 1.))  # red

        # Output of ``gbm0.feature_importance()`` with a per-bar colour list and
        # every label disabled via None.
        ax2 = lgb.plot_importance(gbm0.feature_importance(),
                                  color=['r', 'y', 'g', 'b'],
                                  title=None, xlabel=None, ylabel=None)
        self._assert_axes_labels(ax2, '', '', '')
        self.assertLessEqual(len(ax2.patches), n_features)

        expected_colors = [
            (1., 0, 0, 1.),     # r
            (.75, .75, 0, 1.),  # y
            (0, .5, 0, 1.),     # g
            (0, 0, 1., 1.),     # b
        ]
        # Fail with an assertion message (not an IndexError) if too few bars were drawn.
        self.assertGreaterEqual(len(ax2.patches), len(expected_colors))
        for patch, expected in zip(ax2.patches, expected_colors):
            self.assertTupleEqual(patch.get_facecolor(), expected)


# Run the tests only when executed as a script. Calling unittest.main() at import
# time would parse sys.argv and call sys.exit(), breaking test discovery by
# pytest / ``python -m unittest``.
if __name__ == '__main__':
    print("----------------------------------------------------------------------")
    print("running test_plotting.py")
    unittest.main()