test_basic.py 1.85 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
wxchan's avatar
wxchan committed
2
# pylint: skip-file
3
import unittest, tempfile, os
wxchan's avatar
wxchan committed
4
import numpy as np
wxchan's avatar
wxchan committed
5
6
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
wxchan's avatar
wxchan committed
7
8
import lightgbm as lgb

wxchan's avatar
wxchan committed
9
class TestBasic(unittest.TestCase):
    """End-to-end smoke test of the low-level ``lgb.Booster`` API."""

    def test(self):
        """Train a small binary model and check prediction consistency.

        Predictions on ``X_test`` obtained three ways must match to 15
        decimal places:

        * directly from the in-memory matrix,
        * from the same data written to a comma-separated text file,
        * from a booster reloaded with ``model_file`` after ``save_model``.

        Every file the test writes is created with ``tempfile.mkstemp`` and
        removed in a ``finally`` block, so nothing is left behind in the
        working directory even when an assertion fails.
        """
        # return_X_y must be passed by keyword: it is keyword-only in current
        # scikit-learn, where the positional form raises TypeError.
        X, y = load_breast_cancer(return_X_y=True)
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.1, random_state=1)

        train_data = lgb.Dataset(X_train, max_bin=255, label=y_train)
        valid_data = train_data.create_valid(X_test, label=y_test)

        params = {
            "objective" : "binary",
            "metric" : "auc",
            "min_data" : 1,
            "num_leaves" : 15,
            "verbose" : -1
        }
        bst = lgb.Booster(params, train_data)
        bst.add_valid(valid_data, "valid_1")

        for i in range(30):
            bst.update()
            if i % 10 == 0:
                print(bst.eval_train(), bst.eval_valid())

        # mkstemp creates each file with a unique name atomically (unlike
        # reusing the name of an already-closed NamedTemporaryFile). Close the
        # OS-level handles right away so the paths can be reopened by name.
        data_fd, data_path = tempfile.mkstemp(suffix=".csv")
        os.close(data_fd)
        model_fd, model_path = tempfile.mkstemp(suffix=".txt")
        os.close(model_fd)
        try:
            bst.save_model(model_path)
            pred_from_matr = bst.predict(X_test)

            with open(data_path, "w+b") as f:
                np.savetxt(f, X_test, delimiter=',')
            pred_from_file = bst.predict(data_path)
            self._assert_preds_equal(pred_from_matr, pred_from_file)

            # check saved model persistence
            bst = lgb.Booster(params, model_file=model_path)
            pred_from_model_file = bst.predict(X_test)
            self._assert_preds_equal(pred_from_matr, pred_from_model_file)
        finally:
            for path in (data_path, model_path):
                if os.path.exists(path):
                    os.remove(path)

    def _assert_preds_equal(self, expected, actual):
        """Assert two prediction sequences have equal length and agree to 15 places."""
        self.assertEqual(len(expected), len(actual))
        for preds in zip(expected, actual):
            self.assertAlmostEqual(*preds, places=15)
wxchan's avatar
wxchan committed
47

wxchan's avatar
wxchan committed
48
49
50
# Only run the suite when executed as a script. Running it unconditionally at
# import time would make unittest.main() call sys.exit() whenever another tool
# (e.g. a test collector) merely imports this module.
if __name__ == "__main__":
    print("----------------------------------------------------------------------")
    print("running test_basic.py")
    unittest.main()