test_.py 9.2 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
Guolin Ke's avatar
Guolin Ke committed
2
import ctypes
3
4
from os import environ
from pathlib import Path
5
6
from platform import system

Guolin Ke's avatar
Guolin Ke committed
7
8
9
import numpy as np
from scipy import sparse

wxchan's avatar
wxchan committed
10

Guolin Ke's avatar
Guolin Ke committed
11
def find_lib_path():
    """Return the paths of the LightGBM shared library files that exist.

    Candidate directories are this file's directory plus several directories
    below the path two levels above it. On Windows a few extra build output
    directories are searched and the file is ``lib_lightgbm.dll``; everywhere
    else it is ``lib_lightgbm.so``.

    Returns a list of path strings in search order. The list is empty when
    the ``LIGHTGBM_BUILD_DOC`` environment variable is set to a non-empty
    value.

    Raises ``Exception`` when none of the candidate files exists.
    """
    if environ.get('LIGHTGBM_BUILD_DOC', False):
        # we don't need lib_lightgbm while building docs
        return []

    curr_path = Path(__file__).parent.absolute()
    # Two directory levels above this file; most candidates hang off it.
    base_path = curr_path.parents[1]
    search_dirs = [
        curr_path,
        base_path,
        base_path / 'python-package' / 'lightgbm' / 'compile',
        base_path / 'python-package' / 'compile',
        base_path / 'lib',
    ]
    if system() in ('Windows', 'Microsoft'):
        search_dirs += [
            base_path / 'python-package' / 'compile' / 'Release/',
            base_path / 'python-package' / 'compile' / 'windows' / 'x64' / 'DLL',
            base_path / 'Release',
            base_path / 'windows' / 'x64' / 'DLL',
        ]
        lib_name = 'lib_lightgbm.dll'
    else:
        lib_name = 'lib_lightgbm.so'

    dll_path = [directory / lib_name for directory in search_dirs]
    lib_path = [str(p) for p in dll_path if p.is_file()]
    if not lib_path:
        # List every location that was tried so the failure is actionable.
        dll_path_joined = '\n'.join(map(str, dll_path))
        raise Exception(f'Cannot find lightgbm library file in following paths:\n{dll_path_joined}')
    return lib_path


def LoadDll():
    """Load the LightGBM shared library with ctypes.

    Uses the first path reported by ``find_lib_path()``. Returns the loaded
    library object, or ``None`` when ``find_lib_path()`` returns no paths.
    """
    lib_path = find_lib_path()
    if not lib_path:
        return None
    return ctypes.cdll.LoadLibrary(lib_path[0])

wxchan's avatar
wxchan committed
44

Guolin Ke's avatar
Guolin Ke committed
45
46
# Handle to the native LightGBM library. LoadDll() returns None when no
# library path is reported (e.g. while building docs).
LIB = LoadDll()

if LIB is not None:
    # Declare the return type so ctypes returns the error text as bytes
    # rather than treating the result as a C int.
    LIB.LGBM_GetLastError.restype = ctypes.c_char_p

# Integer type codes passed to the C API next to raw data buffers.
dtype_float32 = 0
dtype_float64 = 1
dtype_int32 = 2
dtype_int64 = 3


Guolin Ke's avatar
Guolin Ke committed
55
def c_str(string):
    """Encode ``string`` as UTF-8 and wrap the bytes in a ``ctypes.c_char_p``."""
    encoded = string.encode('utf-8')
    return ctypes.c_char_p(encoded)
Guolin Ke's avatar
Guolin Ke committed
57

wxchan's avatar
wxchan committed
58

59
def load_from_file(filename, reference):
    """Create a dataset from a file with ``LGBM_DatasetCreateFromFile``.

    ``filename`` is converted with ``str()``, so both strings and ``Path``
    objects are accepted. ``reference`` is passed to the C call unchanged; it
    is either an existing dataset handle or ``None``.

    Prints the library's last error message and the dataset's row and feature
    counts, then returns the new dataset handle (a ``ctypes.c_void_p``).
    """
    handle = ctypes.c_void_p()
    LIB.LGBM_DatasetCreateFromFile(
        c_str(str(filename)),
        c_str('max_bin=15'),
        reference,
        ctypes.byref(handle))
    print(LIB.LGBM_GetLastError())

    # Query the sizes back from the library through out-parameters.
    num_data = ctypes.c_int(0)
    LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data))
    num_feature = ctypes.c_int(0)
    LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature))
    print(f'#data: {num_data.value} #feature: {num_feature.value}')
    return handle

wxchan's avatar
wxchan committed
77

78
def save_to_binary(handle, filename):
    """Write the dataset behind ``handle`` to ``filename`` with ``LGBM_DatasetSaveBinary``.

    ``filename`` is converted with ``str()`` first, so a ``Path`` is accepted
    as well as a string (the same convention ``load_from_file`` uses).
    """
    LIB.LGBM_DatasetSaveBinary(handle, c_str(str(filename)))


82
def load_from_csr(filename, reference):
    """Create a dataset from a text file through the CSR entry point.

    The file is read with ``np.loadtxt``; column 0 becomes the float32 label
    and the remaining columns are converted to a SciPy CSR matrix, which is
    handed to ``LGBM_DatasetCreateFromCSR``. ``reference`` is passed to the C
    call unchanged; it is either an existing dataset handle or ``None``.

    Prints the dataset's row and feature counts and returns the new dataset
    handle (a ``ctypes.c_void_p``).
    """
    data = np.loadtxt(str(filename), dtype=np.float64)
    csr = sparse.csr_matrix(data[:, 1:])
    label = data[:, 0].astype(np.float32)

    handle = ctypes.c_void_p()
    LIB.LGBM_DatasetCreateFromCSR(
        csr.indptr.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
        ctypes.c_int(dtype_int32),
        csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
        csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        ctypes.c_int(dtype_float64),
        ctypes.c_int64(len(csr.indptr)),
        ctypes.c_int64(len(csr.data)),
        ctypes.c_int64(csr.shape[1]),  # number of columns
        c_str('max_bin=15'),
        reference,
        ctypes.byref(handle))

    # Query the sizes back from the library through out-parameters.
    num_data = ctypes.c_int(0)
    LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data))
    num_feature = ctypes.c_int(0)
    LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature))

    # Attach the labels taken from the first column of the file.
    LIB.LGBM_DatasetSetField(
        handle,
        c_str('label'),
        label.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        ctypes.c_int(len(label)),
        ctypes.c_int(dtype_float32))
    print(f'#data: {num_data.value} #feature: {num_feature.value}')
    return handle

wxchan's avatar
wxchan committed
116

117
def load_from_csc(filename, reference):
    """Create a dataset from a text file through the CSC entry point.

    The file is read with ``np.loadtxt``; column 0 becomes the float32 label
    and the remaining columns are converted to a SciPy CSC matrix, which is
    handed to ``LGBM_DatasetCreateFromCSC``. ``reference`` is passed to the C
    call unchanged; it is either an existing dataset handle or ``None``.

    Prints the dataset's row and feature counts and returns the new dataset
    handle (a ``ctypes.c_void_p``).
    """
    data = np.loadtxt(str(filename), dtype=np.float64)
    csc = sparse.csc_matrix(data[:, 1:])
    label = data[:, 0].astype(np.float32)

    handle = ctypes.c_void_p()
    LIB.LGBM_DatasetCreateFromCSC(
        csc.indptr.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
        ctypes.c_int(dtype_int32),
        csc.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
        csc.data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        ctypes.c_int(dtype_float64),
        ctypes.c_int64(len(csc.indptr)),
        ctypes.c_int64(len(csc.data)),
        ctypes.c_int64(csc.shape[0]),  # number of rows
        c_str('max_bin=15'),
        reference,
        ctypes.byref(handle))

    # Query the sizes back from the library through out-parameters.
    num_data = ctypes.c_int(0)
    LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data))
    num_feature = ctypes.c_int(0)
    LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature))

    # Attach the labels taken from the first column of the file.
    LIB.LGBM_DatasetSetField(
        handle,
        c_str('label'),
        label.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        ctypes.c_int(len(label)),
        ctypes.c_int(dtype_float32))
    print(f'#data: {num_data.value} #feature: {num_feature.value}')
    return handle

wxchan's avatar
wxchan committed
151

152
def load_from_mat(filename, reference):
    """Create a dataset from a text file through the dense-matrix entry point.

    The file is read with ``np.loadtxt``; column 0 becomes the float32 label
    and the remaining columns are flattened into a 1-D float64 buffer that is
    handed to ``LGBM_DatasetCreateFromMat``. ``reference`` is passed to the C
    call unchanged; it is either an existing dataset handle or ``None``.

    Prints the dataset's row and feature counts and returns the new dataset
    handle (a ``ctypes.c_void_p``).
    """
    mat = np.loadtxt(str(filename), dtype=np.float64)
    label = mat[:, 0].astype(np.float32)
    mat = mat[:, 1:]
    # np.asarray avoids a copy when none is needed, without relying on
    # np.array(..., copy=False), whose meaning changed in NumPy 2.
    data = np.asarray(mat.reshape(mat.size), dtype=np.float64)

    handle = ctypes.c_void_p()
    LIB.LGBM_DatasetCreateFromMat(
        data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        ctypes.c_int(dtype_float64),
        ctypes.c_int32(mat.shape[0]),
        ctypes.c_int32(mat.shape[1]),
        ctypes.c_int(1),  # presumably the row-major flag -- confirm against the C API
        c_str('max_bin=15'),
        reference,
        ctypes.byref(handle))

    # Query the sizes back from the library through out-parameters.
    num_data = ctypes.c_int(0)
    LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data))
    num_feature = ctypes.c_int(0)
    LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature))

    # Attach the labels taken from the first column of the file.
    LIB.LGBM_DatasetSetField(
        handle,
        c_str('label'),
        label.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        ctypes.c_int(len(label)),
        ctypes.c_int(dtype_float32))
    print(f'#data: {num_data.value} #feature: {num_feature.value}')
    return handle
wxchan's avatar
wxchan committed
183
184


185
def free_dataset(handle):
    """Release the dataset behind ``handle`` by calling ``LGBM_DatasetFree``."""
    LIB.LGBM_DatasetFree(handle)

wxchan's avatar
wxchan committed
188

189
def test_dataset():
    """Exercise every dataset loader, then round-trip a dataset through a binary file.

    The test split is loaded through the dense, CSR and CSC entry points in
    turn, each time with the training dataset as reference. The training
    dataset is then saved to ``train.binary.bin`` in the current working
    directory and loaded back from that file.
    """
    binary_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'binary_classification'
    test_file = binary_example_dir / 'binary.test'

    train = load_from_file(binary_example_dir / 'binary.train', None)
    test = load_from_mat(test_file, train)
    free_dataset(test)
    test = load_from_csr(test_file, train)
    free_dataset(test)
    test = load_from_csc(test_file, train)
    free_dataset(test)

    # Save, release, and reload to check the binary format round-trips.
    save_to_binary(train, 'train.binary.bin')
    free_dataset(train)
    train = load_from_file('train.binary.bin', None)
    free_dataset(train)
wxchan's avatar
wxchan committed
202
203


204
def test_booster():
    """Train a booster, save it, reload it, and predict from memory and from file.

    Runs 50 update iterations on the binary classification example, printing
    an evaluation result every 10th iteration. The model is saved to
    ``model.txt``, reloaded into a second booster, and used to predict on the
    test split: once from an in-memory matrix and twice from the test file
    (both file predictions write ``preb.txt``).
    """
    binary_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'binary_classification'
    test_file = binary_example_dir / 'binary.test'

    train = load_from_mat(binary_example_dir / 'binary.train', None)
    test = load_from_mat(test_file, train)

    booster = ctypes.c_void_p()
    LIB.LGBM_BoosterCreate(
        train,
        c_str("app=binary metric=auc num_leaves=31 verbose=0"),
        ctypes.byref(booster))
    LIB.LGBM_BoosterAddValidData(booster, test)

    is_finished = ctypes.c_int(0)
    for i in range(1, 51):
        LIB.LGBM_BoosterUpdateOneIter(booster, ctypes.byref(is_finished))
        result = np.array([0.0], dtype=np.float64)
        out_len = ctypes.c_int(0)
        # NOTE(review): the value printed below is labelled "test AUC" but is
        # read with data index 0 -- confirm against the C API which dataset
        # that index selects.
        LIB.LGBM_BoosterGetEval(
            booster,
            ctypes.c_int(0),
            ctypes.byref(out_len),
            result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
        if i % 10 == 0:
            print(f'{i} iteration test AUC {result[0]:.6f}')

    LIB.LGBM_BoosterSaveModel(
        booster,
        ctypes.c_int(0),
        ctypes.c_int(-1),
        ctypes.c_int(0),
        c_str('model.txt'))
    LIB.LGBM_BoosterFree(booster)
    free_dataset(train)
    free_dataset(test)

    # Reload the saved model into a fresh booster handle.
    booster2 = ctypes.c_void_p()
    num_total_model = ctypes.c_int(0)
    LIB.LGBM_BoosterCreateFromModelfile(
        c_str('model.txt'),
        ctypes.byref(num_total_model),
        ctypes.byref(booster2))

    # Column 0 of the file is the label; the remaining columns are features.
    data = np.loadtxt(str(test_file), dtype=np.float64)
    mat = data[:, 1:]
    preb = np.empty(mat.shape[0], dtype=np.float64)
    num_preb = ctypes.c_int64(0)
    # np.asarray avoids a copy when none is needed, without relying on
    # np.array(..., copy=False), whose meaning changed in NumPy 2.
    data = np.asarray(mat.reshape(mat.size), dtype=np.float64)
    LIB.LGBM_BoosterPredictForMat(
        booster2,
        data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        ctypes.c_int(dtype_float64),
        ctypes.c_int32(mat.shape[0]),
        ctypes.c_int32(mat.shape[1]),
        ctypes.c_int(1),
        ctypes.c_int(1),
        ctypes.c_int(0),
        ctypes.c_int(25),
        c_str(''),
        ctypes.byref(num_preb),
        preb.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))

    # Two file-based predictions that differ only in the third integer
    # argument (0, then 10); the second call overwrites preb.txt.
    LIB.LGBM_BoosterPredictForFile(
        booster2,
        c_str(str(test_file)),
        ctypes.c_int(0),
        ctypes.c_int(0),
        ctypes.c_int(0),
        ctypes.c_int(25),
        c_str(''),
        c_str('preb.txt'))
    LIB.LGBM_BoosterPredictForFile(
        booster2,
        c_str(str(test_file)),
        ctypes.c_int(0),
        ctypes.c_int(0),
        ctypes.c_int(10),
        ctypes.c_int(25),
        c_str(''),
        c_str('preb.txt'))
    LIB.LGBM_BoosterFree(booster2)