test_sklearn.py 4.95 KB
Newer Older
wxchan's avatar
wxchan committed
1
2
3
# coding: utf-8
# pylint: skip-file
import os, unittest
Guolin Ke's avatar
Guolin Ke committed
4
5
import numpy as np
import lightgbm as lgb
wxchan's avatar
wxchan committed
6
7
from sklearn.metrics import log_loss, mean_squared_error
from sklearn.datasets import load_breast_cancer, load_boston, load_digits, load_svmlight_file
wxchan's avatar
wxchan committed
8
9
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.base import clone
wxchan's avatar
wxchan committed
10
from sklearn.externals import joblib
wxchan's avatar
wxchan committed
11
12

def test_template(X_y=None, model=lgb.LGBMRegressor,
                  feval=mean_squared_error, num_round=100,
                  custom_obj=None, predict_proba=False,
                  return_data=False, return_model=False):
    """Fit ``model`` on a 90/10 split of ``X_y`` and score it on the held-out part.

    Args:
        X_y: ``(X, y)`` pair. Defaults to the Boston housing data, loaded
            lazily on each call that needs it.
        model: estimator class; instantiated with ``n_estimators`` and ``silent``.
        feval: metric called as ``feval(y_test, predictions)``.
        num_round: value passed as ``n_estimators``.
        custom_obj: optional ``objective`` passed to the estimator (only if truthy).
        predict_proba: score ``predict_proba`` output instead of ``predict``.
        return_data: return ``(X_train, X_test, y_train, y_test)`` without training.
        return_model: return the fitted estimator instead of the metric.

    Returns:
        The split data, the fitted model, or ``feval``'s result, depending on flags.

    NOTE(review): the ``test_`` prefix means test collectors such as pytest
    may also pick this helper up as a test and run it with defaults.
    """
    if X_y is None:
        # Load on demand instead of as a default argument, which would be
        # evaluated once at import time; ``return_X_y`` is passed by keyword
        # because positional use is deprecated in newer scikit-learn.
        X_y = load_boston(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(*X_y, test_size=0.1, random_state=42)
    if return_data:
        return X_train, X_test, y_train, y_test

    arguments = {'n_estimators': num_round, 'silent': True}
    if custom_obj:
        arguments['objective'] = custom_obj
    gbm = model(**arguments)
    # Early stopping watches the held-out split.
    gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=10, verbose=False)
    if return_model:
        return gbm

    y_pred = gbm.predict_proba(X_test) if predict_proba else gbm.predict(X_test)
    return feval(y_test, y_pred)

class TestSklearn(unittest.TestCase):
    """Tests for LightGBM's scikit-learn estimator wrappers."""

    def test_binary(self):
        # ``return_X_y`` by keyword: positional use is deprecated in newer scikit-learn.
        X_y = load_breast_cancer(return_X_y=True)
        ret = test_template(X_y, lgb.LGBMClassifier, log_loss, predict_proba=True)
        self.assertLess(ret, 0.15)

    def test_regreesion(self):
        # test_template's defaults give the MSE of an LGBMRegressor on Boston
        # housing; the square root is the RMSE. (Method name kept as-is.)
        self.assertLess(test_template() ** 0.5, 4)

    def test_multiclass(self):
        X_y = load_digits(n_class=10, return_X_y=True)

        def multi_error(y_true, y_pred):
            # Fraction of samples whose predicted label is wrong.
            return np.mean(y_true != y_pred)

        ret = test_template(X_y, lgb.LGBMClassifier, multi_error)
        self.assertLess(ret, 0.2)

    def test_lambdarank(self):
        # Resolve the example data relative to this file, not the current
        # working directory, so the test works from any launch directory.
        data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                '..', '..', 'examples', 'lambdarank')
        X_train, y_train = load_svmlight_file(os.path.join(data_dir, 'rank.train'))
        X_test, y_test = load_svmlight_file(os.path.join(data_dir, 'rank.test'))
        q_train = np.loadtxt(os.path.join(data_dir, 'rank.train.query'))
        q_test = np.loadtxt(os.path.join(data_dir, 'rank.test.query'))
        # NOTE(review): only checks that training with a learning-rate
        # schedule runs; no ranking-quality assertion is made.
        lgb_model = lgb.LGBMRanker().fit(X_train, y_train,
                                         group=q_train,
                                         eval_set=[(X_test, y_test)],
                                         eval_group=[q_test],
                                         eval_at=[1],
                                         verbose=False,
                                         callbacks=[lgb.reset_parameter(learning_rate=lambda x: 0.95 ** x * 0.1)])

    def test_regression_with_custom_objective(self):
        def objective_ls(y_true, y_pred):
            # Squared-error objective: gradient is the residual, Hessian is 1.
            grad = (y_pred - y_true)
            hess = np.ones(len(y_true))
            return grad, hess
        ret = test_template(custom_obj=objective_ls)
        self.assertLess(ret, 100)

    def test_binary_classification_with_custom_objective(self):
        def logregobj(y_true, y_pred):
            # Logistic loss on raw scores: squash, then gradient/Hessian.
            y_pred = 1.0 / (1.0 + np.exp(-y_pred))
            grad = y_pred - y_true
            hess = y_pred * (1.0 - y_pred)
            return grad, hess
        X_y = load_digits(n_class=2, return_X_y=True)

        def binary_error(y_test, y_pred):
            return np.mean([int(p > 0.5) != y for y, p in zip(y_test, y_pred)])

        ret = test_template(X_y, lgb.LGBMClassifier, feval=binary_error, custom_obj=logregobj)
        self.assertLess(ret, 0.1)

    def test_dart(self):
        X_train, X_test, y_train, y_test = test_template(return_data=True)
        gbm = lgb.LGBMRegressor(boosting_type='dart')
        gbm.fit(X_train, y_train)
        self.assertLessEqual(gbm.score(X_train, y_train), 1.)

    def test_grid_search(self):
        X_train, X_test, y_train, y_test = test_template(return_data=True)
        params = {'boosting_type': ['dart', 'gbdt'],
                  'n_estimators': [15, 20], 'drop_rate': [0.1, 0.2]}
        gbm = GridSearchCV(lgb.LGBMRegressor(), params, cv=3)
        gbm.fit(X_train, y_train)
        self.assertIn(gbm.best_params_['n_estimators'], [15, 20])

    def test_clone(self):
        gbm = test_template(return_model=True)
        gbm_clone = clone(gbm)
        # The original built the clone but never checked it.
        self.assertDictEqual(gbm.get_params(), gbm_clone.get_params())

    def test_joblib(self):
        gbm = test_template(num_round=10, return_model=True)
        pkl_path = 'lgb.pkl'
        try:
            joblib.dump(gbm, pkl_path)
            gbm_pickle = joblib.load(pkl_path)
        finally:
            # Don't leave the pickle behind in the working directory.
            if os.path.exists(pkl_path):
                os.remove(pkl_path)
        self.assertDictEqual(gbm.get_params(), gbm_pickle.get_params())
        X_train, X_test, y_train, y_test = test_template(return_data=True)
        gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=False)
        gbm_pickle.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=False)
        self.assertDictEqual(gbm.evals_result(), gbm_pickle.evals_result())
        pred_origin = gbm.predict(X_test)
        pred_pickle = gbm_pickle.predict(X_test)
        self.assertEqual(len(pred_origin), len(pred_pickle))
        for preds in zip(pred_origin, pred_pickle):
            self.assertAlmostEqual(*preds, places=5)
# Only run the suite when executed as a script; an unguarded
# unittest.main() would run (and sys.exit) whenever the module is imported,
# e.g. by a test collector.
if __name__ == "__main__":
    print("----------------------------------------------------------------------")
    print("running test_sklearn.py")
    unittest.main()