Commit 4a451b89 authored by Guolin Ke's avatar Guolin Ke
Browse files

support predict from file. add more tests

parent c6512e01
......@@ -143,7 +143,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
* \brief free space for dataset
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_DatasetFree(DatesetHandle* handle);
DllExport int LGBM_DatasetFree(DatesetHandle handle);
/*!
* \brief save dataset to binary file
......@@ -298,6 +298,26 @@ DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
int64_t* out_len,
float* out_result);
/*!
* \brief make prediction for the data stored in a file
* \param handle handle of the booster used for prediction
* \param predict_type
* 0:raw score
* 1:with transform(if needed)
* 2:leaf index
* \param n_used_trees number of trees used for prediction (narrowed to int by the implementation)
* \param data_has_header whether the data file has a header; any value greater than 0 means it has one
* \param data_filename filename of data file
* \param result_filename filename of result file
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
int predict_type,
int64_t n_used_trees,
int data_has_header,
const char* data_filename,
const char* result_filename);
/*!
* \brief make prediction for a new data set
* \param handle handle
......
......@@ -227,14 +227,16 @@ public:
bool predict_leaf_index = false;
IOConfig io_config;
BoostingType boosting_type = BoostingType::kGBDT;
BoostingConfig* boosting_config;
BoostingConfig* boosting_config = nullptr;
std::string objective_type = "regression";
ObjectiveConfig objective_config;
std::vector<std::string> metric_types;
MetricConfig metric_config;
~OverallConfig() {
  // This object releases boosting_config on destruction. Applying `delete`
  // to a null pointer is a no-op, so no explicit null check is required.
  delete boosting_config;
}
void Set(const std::unordered_map<std::string, std::string>& params) override;
void LoadFromString(const char* str);
private:
......
......@@ -23,7 +23,9 @@ namespace LightGBM {
class Booster {
public:
explicit Booster(const char* filename):
boosting_(Boosting::CreateBoosting(filename)), predictor_(nullptr) {
boosting_(Boosting::CreateBoosting(filename)),
objective_fun_(nullptr),
predictor_(nullptr) {
}
Booster(const Dataset* train_data,
......@@ -118,6 +120,10 @@ public:
return predictor_->GetPredictFunction()(features);
}
/*!
* \brief run prediction for a data file by delegating to the current predictor
* \param data_filename filename of data file
* \param result_filename filename of result file
* \param data_has_header data file has header or not
* \note predictor_ is dereferenced without a null check.
*       NOTE(review): presumably PrepareForPrediction must have been called first - confirm.
*/
void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
predictor_->Predict(data_filename, result_filename, data_has_header);
}
/*!
* \brief save the model to a file by delegating to boosting_
* \param num_used_model forwarded unchanged as the first argument of boosting_->SaveModelToFile
* \param filename filename of the model file
* \note NOTE(review): the meaning of the hard-coded `true` second argument is defined by
*       boosting_->SaveModelToFile, which is not visible here - confirm.
*/
void SaveModelToFile(int num_used_model, const char* filename) {
boosting_->SaveModelToFile(num_used_model, true, filename);
}
......@@ -164,6 +170,7 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename,
OverallConfig config;
config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr);
loader.SetHeader(filename);
if (reference == nullptr) {
*out = loader.LoadFromFile(filename);
} else {
......@@ -335,8 +342,8 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
return 0;
}
DllExport int LGBM_DatasetFree(DatesetHandle* handle) {
auto dataset = reinterpret_cast<Dataset*>(*handle);
/*!
* \brief free space for dataset
* \param handle dataset handle; it is reinterpreted as a Dataset* and deleted
* \return always 0
*/
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  delete reinterpret_cast<Dataset*>(handle);
  return 0;
}
......@@ -492,6 +499,20 @@ DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
return 0;
}
/*!
* \brief make prediction for a data file and hand the result filename to the booster
* \param handle booster handle; reinterpreted as a Booster*
* \param predict_type forwarded to Booster::PrepareForPrediction
* \param n_used_trees number of used trees; narrowed to int before use
* \param data_has_header any value greater than 0 is treated as "has header"
* \param data_filename filename of data file
* \param result_filename filename of result file
* \return always 0
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
                                         int predict_type,
                                         int64_t n_used_trees,
                                         int data_has_header,
                                         const char* data_filename,
                                         const char* result_filename) {
  auto* booster = reinterpret_cast<Booster*>(handle);
  // The predictor must be set up before the file can be processed.
  booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
  const bool has_header = data_has_header > 0;
  booster->PredictForFile(data_filename, result_filename, has_header);
  return 0;
}
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
const void* indptr,
int indptr_type,
......
......@@ -123,9 +123,9 @@ bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int**
void Dataset::SaveBinaryFile(const char* bin_filename) {
if (!is_loading_from_binfile_) {
std::string bin_filename_str(data_filename_);
// if no filename is passed, just append ".bin" to the original data filename
if (bin_filename == nullptr || bin_filename[0] == '\0') {
std::string bin_filename_str(data_filename_);
bin_filename_str.append(".bin");
bin_filename = bin_filename_str.c_str();
}
......@@ -138,8 +138,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
if (file == NULL) {
Log::Fatal("Cannot write binary data to %s ", bin_filename);
}
Log::Info("Saving data to binary file %s", data_filename_);
Log::Info("Saving data to binary file %s", bin_filename);
// get size of header
size_t size_of_header = sizeof(num_data_) + sizeof(num_class_) + sizeof(num_features_) + sizeof(num_total_features_)
......
......@@ -8,19 +8,51 @@ from scipy import sparse
def LoadDll():
    """Load the LightGBM shared library through ctypes.

    Returns the loaded library object, or None when the configured
    library path is empty.
    """
    lib_path = '../../windows/x64/DLL/lib_lightgbm.dll'
    if not lib_path:
        return None
    return ctypes.cdll.LoadLibrary(lib_path)
# Library object used by every helper in this script; loaded once at import time.
LIB = LoadDll()
# Element-type codes handed to the C API next to raw buffers
# (e.g. dtype_int32 for the index-pointer array, dtype_float64 for the values).
dtype_float32 = 0
dtype_float64 = 1
dtype_int32 = 2
dtype_int64 = 3
def c_array(ctype, values):
    """Return a ctypes array of element type `ctype` initialised from `values`."""
    array_type = ctype * len(values)
    return array_type(*values)
def c_str(string):
    """Encode `string` as UTF-8 and wrap it in a ctypes.c_char_p."""
    encoded = string.encode('utf-8')
    return ctypes.c_char_p(encoded)
def test_load_from_file(filename, reference):
    """Create a dataset from a data file and print its dimensions.

    `reference` is an existing dataset handle, or None; when given it is
    passed to the C API by reference. Returns the new dataset handle.
    """
    ref = ctypes.byref(reference) if reference is not None else None
    handle = ctypes.c_void_p()
    LIB.LGBM_CreateDatasetFromFile(
        c_str(filename),
        c_str('max_bin=15'),
        ref,
        ctypes.byref(handle))
    # NOTE(review): ctypes.c_long is a C long (32-bit on 64-bit Windows);
    # confirm it matches the width of the C API's out-parameters.
    num_data = ctypes.c_long()
    LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data))
    num_feature = ctypes.c_long()
    LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature))
    print('#data:%d #feature:%d' % (num_data.value, num_feature.value))
    return handle
def test_save_to_binary(handle, filename):
    """Ask the library to save the dataset behind `handle` to `filename`."""
    target = c_str(filename)
    LIB.LGBM_DatasetSaveBinary(handle, target)
def test_load_from_binary(filename):
    """Create a dataset from a binary file and print its dimensions.

    Returns the new dataset handle.
    """
    handle = ctypes.c_void_p()
    LIB.LGBM_CreateDatasetFromBinaryFile(c_str(filename), ctypes.byref(handle))
    # NOTE(review): ctypes.c_long is a C long (32-bit on 64-bit Windows);
    # confirm it matches the width of the C API's out-parameters.
    num_data = ctypes.c_long()
    LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data))
    num_feature = ctypes.c_long()
    LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature))
    print('#data:%d #feature:%d' % (num_data.value, num_feature.value))
    return handle
def test_load_from_csr(filename, reference):
data = []
label = []
......@@ -37,35 +69,146 @@ def test_load_from_csr(filename, reference):
if reference != None:
ref = ctypes.byref(reference)
LIB.LGBM_CreateDatasetFromCSR(c_array(ctypes.c_int, csr.indptr), 2,
LIB.LGBM_CreateDatasetFromCSR(c_array(ctypes.c_int, csr.indptr),
dtype_int32,
c_array(ctypes.c_int, csr.indices),
csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
1, len(csr.indptr), len(csr.data),
csr.shape[1], ctypes.c_char_p('max_bin=15'), ref, ctypes.byref(handle) )
num_data = ctypes.c_ulong()
dtype_float64,
len(csr.indptr),
len(csr.data),
csr.shape[1],
ctypes.c_char_p('max_bin=15'),
ref,
ctypes.byref(handle) )
num_data = ctypes.c_long()
LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data) )
num_feature = ctypes.c_ulong()
num_feature = ctypes.c_long()
LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature) )
LIB.LGBM_DatasetSetField(handle, c_str('label'), c_array(ctypes.c_float, label), len(label), 0)
print '#data:%d #feature:%d' %(num_data.value, num_feature.value)
return handle
def test_load_from_csc(filename, reference):
    """Create a dataset through the CSC entry point and print its dimensions.

    `filename` is a tab-separated text file whose first column is the label
    and whose remaining columns are feature values. `reference` is an
    existing dataset handle, or None; when given it is passed to the C API
    by reference. The labels are attached through the 'label' field.
    Returns the new dataset handle.
    """
    rows = []
    labels = []
    # `with` guarantees the file is closed even if a line fails to parse.
    with open(filename, 'r') as inp:
        for line in inp:
            fields = line.split('\t')
            rows.append([float(x) for x in fields[1:]])
            labels.append(float(fields[0]))
    mat = np.array(rows)
    label = np.array(labels, dtype=np.float32)
    csc = sparse.csc_matrix(mat)
    handle = ctypes.c_void_p()
    ref = None
    if reference is not None:
        ref = ctypes.byref(reference)
    LIB.LGBM_CreateDatasetFromCSC(
        c_array(ctypes.c_int, csc.indptr),
        dtype_int32,
        c_array(ctypes.c_int, csc.indices),
        csc.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
        dtype_float64,
        len(csc.indptr),
        len(csc.data),
        csc.shape[0],  # number of rows of the matrix
        # c_str keeps the parameter string consistent with the other loaders
        # and yields bytes, which c_char_p requires on Python 3.
        c_str('max_bin=15'),
        ref,
        ctypes.byref(handle))
    # NOTE(review): ctypes.c_long is a C long (32-bit on 64-bit Windows);
    # confirm it matches the width of the C API's out-parameters.
    num_data = ctypes.c_long()
    LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data))
    num_feature = ctypes.c_long()
    LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature))
    LIB.LGBM_DatasetSetField(handle, c_str('label'), c_array(ctypes.c_float, label), len(label), 0)
    print('#data:%d #feature:%d' % (num_data.value, num_feature.value))
    return handle
def test_load_from_mat(filename, reference):
    """Create a dataset through the dense-matrix entry point and print its dimensions.

    `filename` is a tab-separated text file whose first column is the label
    and whose remaining columns are feature values. `reference` is an
    existing dataset handle, or None; when given it is passed to the C API
    by reference. The labels are attached through the 'label' field.
    Returns the new dataset handle.
    """
    rows = []
    labels = []
    # `with` guarantees the file is closed even if a line fails to parse.
    with open(filename, 'r') as inp:
        for line in inp:
            fields = line.split('\t')
            rows.append([float(x) for x in fields[1:]])
            labels.append(float(fields[0]))
    mat = np.array(rows)
    # Flattened 1-D form of the matrix whose buffer is handed to the library.
    flat = np.array(mat.reshape(mat.size), copy=False)
    label = np.array(labels, dtype=np.float32)
    handle = ctypes.c_void_p()
    ref = None
    if reference is not None:
        ref = ctypes.byref(reference)
    LIB.LGBM_CreateDatasetFromMat(
        flat.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
        dtype_float64,
        mat.shape[0],
        mat.shape[1],
        1,  # NOTE(review): presumably the row-major flag - confirm against the C API
        # c_str keeps the parameter string consistent with the other loaders
        # and yields bytes, which c_char_p requires on Python 3.
        c_str('max_bin=15'),
        ref,
        ctypes.byref(handle))
    # NOTE(review): ctypes.c_long is a C long (32-bit on 64-bit Windows);
    # confirm it matches the width of the C API's out-parameters.
    num_data = ctypes.c_long()
    LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data))
    num_feature = ctypes.c_long()
    LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature))
    LIB.LGBM_DatasetSetField(handle, c_str('label'), c_array(ctypes.c_float, label), len(label), 0)
    print('#data:%d #feature:%d' % (num_data.value, num_feature.value))
    return handle
def test_free_dataset(handle):
    """Release the dataset behind `handle`; the handle is passed by value."""
    free_dataset = LIB.LGBM_DatasetFree
    free_dataset(handle)
train = test_load_from_csr('../../examples/binary_classification/binary.train', None)
test = [test_load_from_csr('../../examples/binary_classification/binary.test', train)]
name = [c_str('test')]
booster = ctypes.c_void_p()
LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test), c_array(ctypes.c_char_p, name),
def test_dataset():
    """Exercise each dataset constructor, the binary save/load round trip, and freeing."""
    train_path = '../../examples/binary_classification/binary.train'
    test_path = '../../examples/binary_classification/binary.test'
    binary_path = 'train.binary.bin'

    train = test_load_from_file(train_path, None)
    # Build the test set against `train` with each in-memory loader in turn,
    # releasing it before the next one is created.
    for loader in (test_load_from_mat, test_load_from_csr, test_load_from_csc):
        test = loader(test_path, train)
        test_free_dataset(test)
    test_save_to_binary(train, binary_path)
    test_free_dataset(train)
    train = test_load_from_binary(binary_path)
    test_free_dataset(train)
def test_booster():
train = test_load_from_mat('../../examples/binary_classification/binary.train', None)
test = [test_load_from_mat('../../examples/binary_classification/binary.test', train)]
name = [c_str('test')]
booster = ctypes.c_void_p()
LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test), c_array(ctypes.c_char_p, name),
len(test), "app=binary metric=auc num_leaves=31 verbose=0", ctypes.byref(booster))
is_finished = ctypes.c_int(0)
for i in xrange(100):
is_finished = ctypes.c_int(0)
for i in xrange(100):
LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished))
result = np.array([0.0], dtype=np.float32)
out_len = ctypes.c_ulong(0)
LIB.LGBM_BoosterEval(booster, 1, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))
print '%d Iteration test AUC %f' %(i, result[0])
LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt'))
LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt'))
LIB.LGBM_BoosterFree(booster)
test_free_dataset(train)
test_free_dataset(test[0])
booster2 = ctypes.c_void_p()
LIB.LGBM_BoosterLoadFromModelfile(c_str('model.txt'), ctypes.byref(booster2))
data = []
inp = open('../../examples/binary_classification/binary.test', 'r')
for line in inp.readlines():
data.append( [float(x) for x in line.split('\t')[1:]] )
inp.close()
mat = np.array(data)
preb = np.zeros(( mat.shape[0],1 ), dtype=np.float64)
data = np.array(mat.reshape(mat.size), copy=False)
LIB.LGBM_BoosterPredictForMat(booster2,
data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
dtype_float64,
mat.shape[0],
mat.shape[1],
1,
1,
50,
preb.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
LIB.LGBM_BoosterPredictForFile(booster2, 1, 50, 0, c_str('../../examples/binary_classification/binary.test'), c_str('preb.txt'))
LIB.LGBM_BoosterFree(booster2)
# Script entry: run the dataset checks first, then the booster checks.
test_dataset()
test_booster()
booster2 = ctypes.c_void_p()
LIB.LGBM_BoosterLoadFromModelfile(c_str('model.txt'), ctypes.byref(booster2))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment