Commit a44bc1ca authored by Guolin Ke's avatar Guolin Ke
Browse files

update travis, clean code

parent 41c0370b
...@@ -5,14 +5,22 @@ dist: trusty ...@@ -5,14 +5,22 @@ dist: trusty
before_install: before_install:
- test -n $CC && unset CC - test -n $CC && unset CC
- test -n $CXX && unset CXX - test -n $CXX && unset CXX
- wget -O conda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
- chmod +x conda.sh
- bash conda.sh -b -p $HOME/miniconda
- export PATH="$HOME/miniconda/bin:$PATH"
- conda config --set always_yes yes --set changeps1 no
- conda update -q conda
install: install:
- sudo apt-get install -y libopenmpi-dev openmpi-bin build-essential - sudo apt-get install -y libopenmpi-dev openmpi-bin build-essential
- conda install --yes atlas numpy scipy scikit-learn
script: script:
- cd $TRAVIS_BUILD_DIR - cd $TRAVIS_BUILD_DIR
- mkdir build && cd build && cmake .. && make -j - mkdir build && cd build && cmake .. && make -j
- cd $TRAVIS_BUILD_DIR/tests/c_api_test && python test.py
- cd $TRAVIS_BUILD_DIR - cd $TRAVIS_BUILD_DIR
- rm -rf build && mkdir build && cd build && cmake -DUSE_MPI=ON ..&& make -j - rm -rf build && mkdir build && cd build && cmake -DUSE_MPI=ON ..&& make -j
......
...@@ -26,10 +26,14 @@ ...@@ -26,10 +26,14 @@
typedef void* DatesetHandle; typedef void* DatesetHandle;
typedef void* BoosterHandle; typedef void* BoosterHandle;
#define dtype_float32 (0) #define C_API_DTYPE_FLOAT32 (0)
#define dtype_float64 (1) #define C_API_DTYPE_FLOAT64 (1)
#define dtype_int32 (2) #define C_API_DTYPE_INT32 (2)
#define dtype_int64 (3) #define C_API_DTYPE_INT64 (3)
#define C_API_PREDICT_NORMAL (0)
#define C_API_PREDICT_RAW_SCORE (1)
#define C_API_PREDICT_LEAF_INDEX (2)
/*! /*!
* \brief get string message of the last error * \brief get string message of the last error
......
...@@ -45,7 +45,7 @@ public: ...@@ -45,7 +45,7 @@ public:
* \brief Training logic * \brief Training logic
* \param gradient nullptr for using default objective, otherwise use self-defined boosting * \param gradient nullptr for using default objective, otherwise use self-defined boosting
* \param hessian nullptr for using default objective, otherwise use self-defined boosting * \param hessian nullptr for using default objective, otherwise use self-defined boosting
* \param is_eval true if need evalulation or early stop * \param is_eval true if need evaluation or early stop
* \return True if meet early stopping or cannot boosting * \return True if meet early stopping or cannot boosting
*/ */
bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override; bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override;
......
...@@ -106,12 +106,12 @@ public: ...@@ -106,12 +106,12 @@ public:
if (predictor_ != nullptr) { delete predictor_; } if (predictor_ != nullptr) { delete predictor_; }
bool is_predict_leaf = false; bool is_predict_leaf = false;
bool is_raw_score = false; bool is_raw_score = false;
if (predict_type == 2) { if (predict_type == C_API_PREDICT_LEAF_INDEX) {
is_predict_leaf = true; is_predict_leaf = true;
} else if (predict_type == 1) { } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
is_raw_score = false;
} else {
is_raw_score = true; is_raw_score = true;
} else {
is_raw_score = false;
} }
predictor_ = new Predictor(boosting_, is_raw_score, is_predict_leaf); predictor_ = new Predictor(boosting_, is_raw_score, is_predict_leaf);
} }
...@@ -362,9 +362,9 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle, ...@@ -362,9 +362,9 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle,
int type) { int type) {
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
bool is_success = false; bool is_success = false;
if (type == dtype_float32) { if (type == C_API_DTYPE_FLOAT32) {
is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element)); is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
} else if (type == dtype_int32) { } else if (type == C_API_DTYPE_INT32) {
is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element)); is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
} }
if (is_success) { return 0; } if (is_success) { return 0; }
...@@ -378,10 +378,10 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle, ...@@ -378,10 +378,10 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle,
int* out_type) { int* out_type) {
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) { if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
*out_type = dtype_float32; *out_type = C_API_DTYPE_FLOAT32;
return 0; return 0;
} else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) { } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
*out_type = dtype_int32; *out_type = C_API_DTYPE_INT32;
return 0; return 0;
} }
return -1; return -1;
...@@ -582,7 +582,7 @@ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle, ...@@ -582,7 +582,7 @@ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
std::function<std::vector<double>(int row_idx)> std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) { RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
if (data_type == dtype_float32) { if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data); const float* data_ptr = reinterpret_cast<const float*>(data);
if (is_row_major) { if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) { return [data_ptr, num_col, num_row](int row_idx) {
...@@ -604,7 +604,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -604,7 +604,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
return ret; return ret;
}; };
} }
} else if (data_type == dtype_float64) { } else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data); const double* data_ptr = reinterpret_cast<const double*>(data);
if (is_row_major) { if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) { return [data_ptr, num_col, num_row](int row_idx) {
...@@ -634,61 +634,27 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -634,61 +634,27 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
std::function<std::vector<std::pair<int, double>>(int row_idx)> std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) { RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
if (data_type == dtype_float32) { auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
const float* data_ptr = reinterpret_cast<const float*>(data); if (inner_function != nullptr) {
if (is_row_major) { return [inner_function](int row_idx) {
return [data_ptr, num_col, num_row](int row_idx) { auto raw_values = inner_function(row_idx);
CHECK(row_idx < num_row);
std::vector<std::pair<int, double>> ret;
auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
ret.emplace_back(i, static_cast<double>(*(tmp_ptr + i)));
}
return ret;
};
} else {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<std::pair<int, double>> ret;
for (int i = 0; i < num_col; ++i) {
ret.emplace_back(i, static_cast<double>(*(data_ptr + num_row * i + row_idx)));
}
return ret;
};
}
} else if (data_type == dtype_float64) {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
auto tmp_ptr = data_ptr + num_col * row_idx; for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
for (int i = 0; i < num_col; ++i) { if (std::fabs(raw_values[i]) > 1e-15) {
ret.emplace_back(i, static_cast<double>(*(tmp_ptr + i))); ret.emplace_back(i, raw_values[i]);
} }
return ret;
};
} else {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<std::pair<int, double>> ret;
for (int i = 0; i < num_col; ++i) {
ret.emplace_back(i, static_cast<double>(*(data_ptr + num_row * i + row_idx)));
} }
return ret; return ret;
}; };
} }
} else {
Log::Fatal("unknown data type in RowPairFunctionFromDenseMatric");
}
return nullptr; return nullptr;
} }
std::function<std::vector<std::pair<int, double>>(int idx)> std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) { RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
if (data_type == dtype_float32) { if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data); const float* data_ptr = reinterpret_cast<const float*>(data);
if (indptr_type == dtype_int32) { if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr); const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
CHECK(idx + 1 < nindptr); CHECK(idx + 1 < nindptr);
...@@ -701,7 +667,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -701,7 +667,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
} }
return ret; return ret;
}; };
} else if (indptr_type == dtype_int64) { } else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr); const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
CHECK(idx + 1 < nindptr); CHECK(idx + 1 < nindptr);
...@@ -717,9 +683,9 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -717,9 +683,9 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
} else { } else {
Log::Fatal("unknown data type in RowFunctionFromCSR"); Log::Fatal("unknown data type in RowFunctionFromCSR");
} }
} else if (data_type == dtype_float64) { } else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data); const double* data_ptr = reinterpret_cast<const double*>(data);
if (indptr_type == dtype_int32) { if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr); const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
CHECK(idx + 1 < nindptr); CHECK(idx + 1 < nindptr);
...@@ -732,7 +698,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -732,7 +698,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
} }
return ret; return ret;
}; };
} else if (indptr_type == dtype_int64) { } else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr); const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
CHECK(idx + 1 < nindptr); CHECK(idx + 1 < nindptr);
...@@ -756,9 +722,9 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -756,9 +722,9 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
std::function<std::vector<std::pair<int, double>>(int idx)> std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) { ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
if (data_type == dtype_float32) { if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data); const float* data_ptr = reinterpret_cast<const float*>(data);
if (col_ptr_type == dtype_int32) { if (col_ptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr); const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
CHECK(idx + 1 < ncol_ptr); CHECK(idx + 1 < ncol_ptr);
...@@ -771,7 +737,7 @@ ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indi ...@@ -771,7 +737,7 @@ ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indi
} }
return ret; return ret;
}; };
} else if (col_ptr_type == dtype_int64) { } else if (col_ptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr); const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
CHECK(idx + 1 < ncol_ptr); CHECK(idx + 1 < ncol_ptr);
...@@ -787,9 +753,9 @@ ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indi ...@@ -787,9 +753,9 @@ ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indi
} else { } else {
Log::Fatal("unknown data type in ColumnFunctionFromCSC"); Log::Fatal("unknown data type in ColumnFunctionFromCSC");
} }
} else if (data_type == dtype_float64) { } else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data); const double* data_ptr = reinterpret_cast<const double*>(data);
if (col_ptr_type == dtype_int32) { if (col_ptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr); const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
CHECK(idx + 1 < ncol_ptr); CHECK(idx + 1 < ncol_ptr);
...@@ -802,7 +768,7 @@ ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indi ...@@ -802,7 +768,7 @@ ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indi
} }
return ret; return ret;
}; };
} else if (col_ptr_type == dtype_int64) { } else if (col_ptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr); const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
CHECK(idx + 1 < ncol_ptr); CHECK(idx + 1 < ncol_ptr);
......
...@@ -7,7 +7,10 @@ import numpy as np ...@@ -7,7 +7,10 @@ import numpy as np
from scipy import sparse from scipy import sparse
def LoadDll(): def LoadDll():
if os.name == 'nt':
lib_path = '../../windows/x64/DLL/lib_lightgbm.dll' lib_path = '../../windows/x64/DLL/lib_lightgbm.dll'
else:
lib_path = '../../lib_lightgbm.so'
lib = ctypes.cdll.LoadLibrary(lib_path) lib = ctypes.cdll.LoadLibrary(lib_path)
return lib return lib
...@@ -23,7 +26,7 @@ def c_array(ctype, values): ...@@ -23,7 +26,7 @@ def c_array(ctype, values):
return (ctype * len(values))(*values) return (ctype * len(values))(*values)
def c_str(string): def c_str(string):
return ctypes.c_char_p(string.encode('utf-8')) return ctypes.c_char_p(string.encode('ascii'))
def test_load_from_file(filename, reference): def test_load_from_file(filename, reference):
ref = None ref = None
...@@ -37,7 +40,7 @@ def test_load_from_file(filename, reference): ...@@ -37,7 +40,7 @@ def test_load_from_file(filename, reference):
LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data) ) LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data) )
num_feature = ctypes.c_long() num_feature = ctypes.c_long()
LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature) ) LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature) )
print '#data:%d #feature:%d' %(num_data.value, num_feature.value) print ('#data:%d #feature:%d' %(num_data.value, num_feature.value) )
return handle return handle
def test_save_to_binary(handle, filename): def test_save_to_binary(handle, filename):
...@@ -50,7 +53,7 @@ def test_load_from_binary(filename): ...@@ -50,7 +53,7 @@ def test_load_from_binary(filename):
LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data) ) LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data) )
num_feature = ctypes.c_long() num_feature = ctypes.c_long()
LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature) ) LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature) )
print '#data:%d #feature:%d' %(num_data.value, num_feature.value) print ('#data:%d #feature:%d' %(num_data.value, num_feature.value) )
return handle return handle
def test_load_from_csr(filename, reference): def test_load_from_csr(filename, reference):
...@@ -77,7 +80,7 @@ def test_load_from_csr(filename, reference): ...@@ -77,7 +80,7 @@ def test_load_from_csr(filename, reference):
len(csr.indptr), len(csr.indptr),
len(csr.data), len(csr.data),
csr.shape[1], csr.shape[1],
ctypes.c_char_p('max_bin=15'), c_str('max_bin=15'),
ref, ref,
ctypes.byref(handle) ) ctypes.byref(handle) )
num_data = ctypes.c_long() num_data = ctypes.c_long()
...@@ -85,7 +88,7 @@ def test_load_from_csr(filename, reference): ...@@ -85,7 +88,7 @@ def test_load_from_csr(filename, reference):
num_feature = ctypes.c_long() num_feature = ctypes.c_long()
LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature) ) LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature) )
LIB.LGBM_DatasetSetField(handle, c_str('label'), c_array(ctypes.c_float, label), len(label), 0) LIB.LGBM_DatasetSetField(handle, c_str('label'), c_array(ctypes.c_float, label), len(label), 0)
print '#data:%d #feature:%d' %(num_data.value, num_feature.value) print ('#data:%d #feature:%d' %(num_data.value, num_feature.value) )
return handle return handle
def test_load_from_csc(filename, reference): def test_load_from_csc(filename, reference):
...@@ -112,7 +115,7 @@ def test_load_from_csc(filename, reference): ...@@ -112,7 +115,7 @@ def test_load_from_csc(filename, reference):
len(csr.indptr), len(csr.indptr),
len(csr.data), len(csr.data),
csr.shape[0], csr.shape[0],
ctypes.c_char_p('max_bin=15'), c_str('max_bin=15'),
ref, ref,
ctypes.byref(handle) ) ctypes.byref(handle) )
num_data = ctypes.c_long() num_data = ctypes.c_long()
...@@ -120,7 +123,7 @@ def test_load_from_csc(filename, reference): ...@@ -120,7 +123,7 @@ def test_load_from_csc(filename, reference):
num_feature = ctypes.c_long() num_feature = ctypes.c_long()
LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature) ) LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature) )
LIB.LGBM_DatasetSetField(handle, c_str('label'), c_array(ctypes.c_float, label), len(label), 0) LIB.LGBM_DatasetSetField(handle, c_str('label'), c_array(ctypes.c_float, label), len(label), 0)
print '#data:%d #feature:%d' %(num_data.value, num_feature.value) print ('#data:%d #feature:%d' %(num_data.value, num_feature.value) )
return handle return handle
def test_load_from_mat(filename, reference): def test_load_from_mat(filename, reference):
...@@ -144,7 +147,7 @@ def test_load_from_mat(filename, reference): ...@@ -144,7 +147,7 @@ def test_load_from_mat(filename, reference):
mat.shape[0], mat.shape[0],
mat.shape[1], mat.shape[1],
1, 1,
ctypes.c_char_p('max_bin=15'), c_str('max_bin=15'),
ref, ref,
ctypes.byref(handle) ) ctypes.byref(handle) )
num_data = ctypes.c_long() num_data = ctypes.c_long()
...@@ -152,7 +155,7 @@ def test_load_from_mat(filename, reference): ...@@ -152,7 +155,7 @@ def test_load_from_mat(filename, reference):
num_feature = ctypes.c_long() num_feature = ctypes.c_long()
LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature) ) LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature) )
LIB.LGBM_DatasetSetField(handle, c_str('label'), c_array(ctypes.c_float, label), len(label), 0) LIB.LGBM_DatasetSetField(handle, c_str('label'), c_array(ctypes.c_float, label), len(label), 0)
print '#data:%d #feature:%d' %(num_data.value, num_feature.value) print ('#data:%d #feature:%d' %(num_data.value, num_feature.value) )
return handle return handle
def test_free_dataset(handle): def test_free_dataset(handle):
LIB.LGBM_DatasetFree(handle) LIB.LGBM_DatasetFree(handle)
...@@ -175,14 +178,14 @@ def test_booster(): ...@@ -175,14 +178,14 @@ def test_booster():
name = [c_str('test')] name = [c_str('test')]
booster = ctypes.c_void_p() booster = ctypes.c_void_p()
LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test), c_array(ctypes.c_char_p, name), LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test), c_array(ctypes.c_char_p, name),
len(test), "app=binary metric=auc num_leaves=31 verbose=0", ctypes.byref(booster)) len(test), c_str("app=binary metric=auc num_leaves=31 verbose=0"), ctypes.byref(booster))
is_finished = ctypes.c_int(0) is_finished = ctypes.c_int(0)
for i in xrange(100): for i in range(100):
LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished)) LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished))
result = np.array([0.0], dtype=np.float32) result = np.array([0.0], dtype=np.float32)
out_len = ctypes.c_ulong(0) out_len = ctypes.c_ulong(0)
LIB.LGBM_BoosterEval(booster, 1, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_float))) LIB.LGBM_BoosterEval(booster, 0, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))
print '%d Iteration test AUC %f' %(i, result[0]) print ('%d Iteration test AUC %f' %(i, result[0]))
LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt')) LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt'))
LIB.LGBM_BoosterFree(booster) LIB.LGBM_BoosterFree(booster)
test_free_dataset(train) test_free_dataset(train)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment