Commit 044f79aa authored by Guolin Ke's avatar Guolin Ke
Browse files

support int64_t for CSR/CSR's indptr

parent 31b19afc
......@@ -26,6 +26,11 @@
typedef void* DatesetHandle;
typedef void* BoosterHandle;
#define dtype_float32 (0)
#define dtype_float64 (1)
#define dtype_int32 (2)
#define dtype_int64 (3)
/*!
* \brief get string message of the last error
* all function in this file will return 0 when success
......@@ -62,9 +67,10 @@ DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
/*!
* \brief create a dataset from CSR format
* \param indptr pointer to row headers
* \param indptr_type type of indptr: dtype_int32 (2) or dtype_int64 (3)
* \param indices findex
* \param data fvalue
* \param float_type 0 for float_32 1 for float_64
* \param data_type 0 for float_32 1 for float_64
* \param nindptr number of rows in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_col number of columns; when it's set to 0, then guess from data
......@@ -73,13 +79,14 @@ DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
* \param out created dataset
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int float_type,
uint64_t nindptr,
uint64_t nelem,
uint64_t num_col,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t num_col,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out);
......@@ -87,9 +94,10 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
/*!
* \brief create a dataset from CSC format
* \param col_ptr pointer to col headers
* \param col_ptr_type type of col_ptr: dtype_int32 (2) or dtype_int64 (3)
* \param indices findex
* \param data fvalue
* \param float_type 0 for float_32 1 for float_64
* \param data_type 0 for float_32 1 for float_64
* \param ncol_ptr number of columns in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_row number of rows; when it's set to 0, then guess from data
......@@ -98,13 +106,14 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
* \param out created dataset
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr,
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
const void* data,
int float_type,
uint64_t ncol_ptr,
uint64_t nelem,
uint64_t num_row,
int data_type,
int64_t ncol_ptr,
int64_t nelem,
int64_t num_row,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out);
......@@ -112,7 +121,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr,
/*!
* \brief create dataset from dense matrix
* \param data pointer to the data space
* \param float_type 0 for float_32 1 for float_64
* \param data_type 0 for float_32 1 for float_64
* \param nrow number of rows
* \param ncol number columns
* \param is_row_major 1 for row major, 0 for column major
......@@ -122,7 +131,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr,
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromMat(const void* data,
int float_type,
int data_type,
int32_t nrow,
int32_t ncol,
int is_row_major,
......@@ -151,13 +160,13 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
* \param field_name field name, can be label, weight, group
* \param field_data pointer to vector
* \param num_element number of element in field_data
* \param type float_32:0, uint32_t:1
* \param type type of field_data: dtype_float32 (0) or dtype_int32 (2)
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
const char* field_name,
const void* field_data,
uint64_t num_element,
int64_t num_element,
int type);
/*!
......@@ -166,12 +175,12 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle,
* \param field_name field name
* \param out_len used to set result length
* \param out_ptr pointer to the result
* \param out_type float_32:0, uint32_t:1
* \param out_type type of the result: dtype_float32 (0) or dtype_int32 (2)
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
const char* field_name,
uint64_t* out_len,
int64_t* out_len,
const void** out_ptr,
int* out_type);
......@@ -182,7 +191,7 @@ DllExport int LGBM_DatasetGetField(DatesetHandle handle,
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
uint64_t* out);
int64_t* out);
/*!
* \brief get number of features
......@@ -191,7 +200,7 @@ DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
uint64_t* out);
int64_t* out);
// --- start Booster interfaces
......@@ -261,7 +270,7 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
*/
DllExport int LGBM_BoosterEval(BoosterHandle handle,
int data,
uint64_t* out_len,
int64_t* out_len,
float* out_results);
/*!
......@@ -272,7 +281,7 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle,
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
uint64_t* out_len,
int64_t* out_len,
const float** out_result);
/*!
......@@ -286,16 +295,17 @@ this can be used to support customized eval function
*/
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
int data,
uint64_t* out_len,
int64_t* out_len,
float* out_result);
/*!
* \brief make prediction for an new data set
* \param handle handle
* \param indptr pointer to row headers
* \param indptr_type type of indptr: dtype_int32 (2) or dtype_int64 (3)
* \param indices findex
* \param data fvalue
* \param float_type 0:float_32 1:float64
* \param data_type 0:float_32 1:float64
* \param nindptr number of rows in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_col number of columns; when it's set to 0, then guess from data
......@@ -308,22 +318,23 @@ DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
const int32_t* indptr,
const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int float_type,
uint64_t nindptr,
uint64_t nelem,
uint64_t num_col,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t num_col,
int predict_type,
uint64_t n_used_trees,
int64_t n_used_trees,
double* out_result);
/*!
* \brief make prediction for an new data set
* \param handle handle
* \param data pointer to the data space
* \param float_type 0:float_32 1:float64
* \param data_type 0:float_32 1:float64
* \param nrow number of rows
* \param ncol number columns
* \param is_row_major 1 for row major, 0 for column major
......@@ -337,12 +348,12 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
*/
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
const void* data,
int float_type,
int data_type,
int32_t nrow,
int32_t ncol,
int is_row_major,
int predict_type,
uint64_t n_used_trees,
int64_t n_used_trees,
double* out_result);
/*!
......@@ -356,4 +367,23 @@ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
int num_used_model,
const char* filename);
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const void* data, int data_type, int64_t nindptr, int64_t nelem);
std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices,
const void* data, int data_type, int64_t ncol_ptr, int64_t nelem);
std::vector<double>
SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<size_t>& indices);
#endif // LIGHTGBM_C_API_H_
......@@ -277,9 +277,13 @@ public:
void FinishLoad();
void SetField(const char* field_name, const void* field_data, data_size_t num_element, int type);
bool SetFloatField(const char* field_name, const float* field_data, data_size_t num_element);
void GetField(const char* field_name, uint64_t* out_len, const void** out_ptr, int* out_type);
bool SetIntField(const char* field_name, const int* field_data, data_size_t num_element);
bool GetFloatField(const char* field_name, int64_t* out_len, const float** out_ptr);
bool GetIntField(const char* field_name, int64_t* out_len, const int** out_ptr);
/*!
* \brief Save current dataset into binary file, will save to "filename.bin"
......
......@@ -393,181 +393,6 @@ inline void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t s
}
inline std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int float_type, int is_row_major) {
if (float_type == 0) {
const float* data_ptr = reinterpret_cast<const float*>(data);
if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret;
auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(tmp_ptr + i)));
}
return ret;
};
} else {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret;
for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(data_ptr + num_row * i + row_idx)));
}
return ret;
};
}
} else {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret;
auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(tmp_ptr + i)));
}
return ret;
};
} else {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret;
for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(data_ptr + num_row * i + row_idx)));
}
return ret;
};
}
}
}
inline std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int float_type, int is_row_major) {
if (float_type == 0) {
const float* data_ptr = reinterpret_cast<const float*>(data);
if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<std::pair<int, double>> ret;
auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
ret.emplace_back(i, static_cast<double>(*(tmp_ptr + i)));
}
return ret;
};
} else {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<std::pair<int, double>> ret;
for (int i = 0; i < num_col; ++i) {
ret.emplace_back(i, static_cast<double>(*(data_ptr + num_row * i + row_idx)));
}
return ret;
};
}
} else {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<std::pair<int, double>> ret;
auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
ret.emplace_back(i, static_cast<double>(*(tmp_ptr + i)));
}
return ret;
};
} else {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<std::pair<int, double>> ret;
for (int i = 0; i < num_col; ++i) {
ret.emplace_back(i, static_cast<double>(*(data_ptr + num_row * i + row_idx)));
}
return ret;
};
}
}
}
inline std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const int32_t* indptr, const int32_t* indices, const void* data, int float_type, uint64_t nindptr, uint64_t nelem) {
if (float_type == 0) {
const float* data_ptr = reinterpret_cast<const float*>(data);
return [indptr, indices, data_ptr, nindptr, nelem](int idx) {
CHECK(idx + 1 < nindptr);
std::vector<std::pair<int, double>> ret;
int32_t start = indptr[idx];
int32_t end = indptr[idx + 1];
CHECK(start >= 0 && end < nelem);
for (int32_t i = start; i <= end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} else {
const double* data_ptr = reinterpret_cast<const double*>(data);
return [indptr, indices, data_ptr, nindptr, nelem](int idx) {
CHECK(idx + 1 < nindptr);
std::vector<std::pair<int, double>> ret;
int32_t start = indptr[idx];
int32_t end = indptr[idx + 1];
CHECK(start >= 0 && end <= nelem);
for (int32_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
}
}
inline std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const int32_t* col_ptr, const int32_t* indices, const void* data, int float_type, uint64_t ncol_ptr, uint64_t nelem) {
if (float_type == 0) {
const float* data_ptr = reinterpret_cast<const float*>(data);
return [col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
CHECK(idx + 1 < ncol_ptr);
std::vector<std::pair<int, double>> ret;
int32_t start = col_ptr[idx];
int32_t end = col_ptr[idx + 1];
CHECK(start >= 0 && end <= nelem);
for (int32_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} else {
const double* data_ptr = reinterpret_cast<const double*>(data);
return [col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
CHECK(idx + 1 < ncol_ptr);
std::vector<std::pair<int, double>> ret;
int32_t start = col_ptr[idx];
int32_t end = col_ptr[idx + 1];
CHECK(start >= 0 && end <= nelem);
for (int32_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
}
}
/*!
* \brief Pick the values of selected rows out of one sparse column
* \param data (row index, value) pairs of the column; the single forward
*        scan below only finds matches when both data and indices are in
*        ascending row order
* \param indices rows to sample
* \return one value per entry of indices; 0 where the column has no entry
*/
inline std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<size_t>& indices) {
  size_t j = 0;
  std::vector<double> ret;
  ret.reserve(indices.size());
  for (auto row_idx : indices) {
    // Compare as int: comparing the signed pair index directly against a
    // size_t promotes a negative index to a huge unsigned value, which
    // stops the scan from ever advancing past it.
    const int target = static_cast<int>(row_idx);
    while (j < data.size() && data[j].first < target) {
      ++j;
    }
    if (j < data.size() && data[j].first == target) {
      ret.push_back(data[j].second);
    } else {
      ret.push_back(0);
    }
  }
  return ret;
}
} // namespace Common
......
This diff is collapsed.
......@@ -183,7 +183,7 @@ DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
}
DllExport int LGBM_CreateDatasetFromMat(const void* data,
int float_type,
int data_type,
int32_t nrow,
int32_t ncol,
int is_row_major,
......@@ -195,7 +195,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr);
Dataset* ret = nullptr;
auto get_row_fun = Common::RowFunctionFromDenseMatric(data, nrow, ncol, float_type, is_row_major);
auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
if (reference == nullptr) {
// sample data first
Random rand(config.io_config.data_random_seed);
......@@ -226,13 +226,14 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
return 0;
}
DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int float_type,
uint64_t nindptr,
uint64_t nelem,
uint64_t num_col,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t num_col,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out) {
......@@ -241,7 +242,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr);
Dataset* ret = nullptr;
auto get_row_fun = Common::RowFunctionFromCSR(indptr, indices, data, float_type, nindptr, nelem);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int32_t nrow = static_cast<int32_t>(nindptr - 1);
if (reference == nullptr) {
// sample data first
......@@ -269,7 +270,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
sample_values[inner_data.first][i] = inner_data.second;
}
}
CHECK(num_col >= sample_values.size());
CHECK(num_col >= static_cast<int>(sample_values.size()));
ret = loader.CostructFromSampleData(sample_values, nrow);
} else {
ret = new Dataset(nrow, config.io_config.num_class);
......@@ -289,13 +290,14 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
}
DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr,
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
const void* data,
int float_type,
uint64_t ncol_ptr,
uint64_t nelem,
uint64_t num_row,
int data_type,
int64_t ncol_ptr,
int64_t nelem,
int64_t num_row,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out) {
......@@ -303,7 +305,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr,
config.LoadFromString(parameters);
DatasetLoader loader(config.io_config, nullptr);
Dataset* ret = nullptr;
auto get_col_fun = Common::ColumnFunctionFromCSC(col_ptr, indices, data, float_type, ncol_ptr, nelem);
auto get_col_fun = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
int32_t nrow = static_cast<int32_t>(num_row);
if (reference == nullptr) {
Log::Warning("Construct from CSC format is not efficient");
......@@ -315,7 +317,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr,
#pragma omp parallel for schedule(guided)
for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
auto cur_col = get_col_fun(i);
sample_values[i] = Common::SampleFromOneColumn(cur_col, sample_indices);
sample_values[i] = SampleFromOneColumn(cur_col, sample_indices);
}
ret = loader.CostructFromSampleData(sample_values, nrow);
} else {
......@@ -350,34 +352,44 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
const char* field_name,
const void* field_data,
uint64_t num_element,
int64_t num_element,
int type) {
auto dataset = reinterpret_cast<Dataset*>(handle);
dataset->SetField(field_name, field_data, static_cast<int32_t>(num_element), type);
return 0;
bool is_success = false;
if (type == dtype_float32) {
is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
} else if (type == dtype_int32) {
is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
}
if (is_success) { return 0; }
return -1;
}
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
const char* field_name,
uint64_t* out_len,
int64_t* out_len,
const void** out_ptr,
int* out_type) {
auto dataset = reinterpret_cast<Dataset*>(handle);
dataset->GetField(field_name, out_len, out_ptr, out_type);
return 0;
if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
*out_type = dtype_float32;
return 0;
} else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
*out_type = dtype_int32;
return 0;
}
return -1;
}
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
uint64_t* out) {
int64_t* out) {
auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_data();
return 0;
}
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
uint64_t* out) {
int64_t* out) {
auto dataset = reinterpret_cast<Dataset*>(handle);
*out = dataset->num_total_features();
return 0;
......@@ -442,13 +454,13 @@ DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
DllExport int LGBM_BoosterEval(BoosterHandle handle,
int data,
uint64_t* out_len,
int64_t* out_len,
float* out_results) {
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting();
auto result_buf = boosting->GetEvalAt(data);
*out_len = static_cast<uint64_t>(result_buf.size());
*out_len = static_cast<int64_t>(result_buf.size());
for (size_t i = 0; i < result_buf.size(); ++i) {
(out_results)[i] = static_cast<float>(result_buf[i]);
}
......@@ -456,47 +468,48 @@ DllExport int LGBM_BoosterEval(BoosterHandle handle,
}
DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
uint64_t* out_len,
int64_t* out_len,
const float** out_result) {
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting();
int len = 0;
*out_result = boosting->GetTrainingScore(&len);
*out_len = static_cast<uint64_t>(len);
*out_len = static_cast<int64_t>(len);
return 0;
}
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
int data,
uint64_t* out_len,
int64_t* out_len,
float* out_result) {
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto boosting = ref_booster->GetBoosting();
int len = 0;
boosting->GetPredictAt(data, out_result, &len);
*out_len = static_cast<uint64_t>(len);
*out_len = static_cast<int64_t>(len);
return 0;
}
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
const int32_t* indptr,
const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int float_type,
uint64_t nindptr,
uint64_t nelem,
uint64_t,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t,
int predict_type,
uint64_t n_used_trees,
int64_t n_used_trees,
double* out_result) {
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
auto get_row_fun = Common::RowFunctionFromCSR(indptr, indices, data, float_type, nindptr, nelem);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int num_class = ref_booster->NumberOfClasses();
int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
......@@ -512,18 +525,18 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
const void* data,
int float_type,
int data_type,
int32_t nrow,
int32_t ncol,
int is_row_major,
int predict_type,
uint64_t n_used_trees,
int64_t n_used_trees,
double* out_result) {
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
auto get_row_fun = Common::RowPairFunctionFromDenseMatric(data, nrow, ncol, float_type, is_row_major);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
int num_class = ref_booster->NumberOfClasses();
#pragma omp parallel for schedule(guided)
for (int i = 0; i < nrow; ++i) {
......@@ -544,3 +557,261 @@ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
ref_booster->SaveModelToFile(num_used_model, filename);
return 0;
}
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
if (data_type == dtype_float32) {
const float* data_ptr = reinterpret_cast<const float*>(data);
if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret;
auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(tmp_ptr + i)));
}
return ret;
};
} else {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret;
for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(data_ptr + num_row * i + row_idx)));
}
return ret;
};
}
} else if (data_type == dtype_float64) {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret;
auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(tmp_ptr + i)));
}
return ret;
};
} else {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<double> ret;
for (int i = 0; i < num_col; ++i) {
ret.push_back(static_cast<double>(*(data_ptr + num_row * i + row_idx)));
}
return ret;
};
}
} else {
Log::Fatal("unknown data type in RowFunctionFromDenseMatric");
}
}
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
if (data_type == dtype_float32) {
const float* data_ptr = reinterpret_cast<const float*>(data);
if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<std::pair<int, double>> ret;
auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
ret.emplace_back(i, static_cast<double>(*(tmp_ptr + i)));
}
return ret;
};
} else {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<std::pair<int, double>> ret;
for (int i = 0; i < num_col; ++i) {
ret.emplace_back(i, static_cast<double>(*(data_ptr + num_row * i + row_idx)));
}
return ret;
};
}
} else if (data_type == dtype_float64) {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<std::pair<int, double>> ret;
auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
ret.emplace_back(i, static_cast<double>(*(tmp_ptr + i)));
}
return ret;
};
} else {
return [data_ptr, num_col, num_row](int row_idx) {
CHECK(row_idx < num_row);
std::vector<std::pair<int, double>> ret;
for (int i = 0; i < num_col; ++i) {
ret.emplace_back(i, static_cast<double>(*(data_ptr + num_row * i + row_idx)));
}
return ret;
};
}
} else {
Log::Fatal("unknown data type in RowPairFunctionFromDenseMatric");
}
}
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
if (data_type == dtype_float32) {
const float* data_ptr = reinterpret_cast<const float*>(data);
if (indptr_type == dtype_int32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
CHECK(idx + 1 < nindptr);
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
CHECK(start >= 0 && end < nelem);
for (int64_t i = start; i <= end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} else if (indptr_type == dtype_int64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
CHECK(idx + 1 < nindptr);
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
CHECK(start >= 0 && end < nelem);
for (int64_t i = start; i <= end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} else {
Log::Fatal("unknown data type in RowFunctionFromCSR");
}
} else if (data_type == dtype_float64) {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (indptr_type == dtype_int32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
CHECK(idx + 1 < nindptr);
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
CHECK(start >= 0 && end < nelem);
for (int64_t i = start; i <= end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} else if (indptr_type == dtype_int64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
CHECK(idx + 1 < nindptr);
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
CHECK(start >= 0 && end < nelem);
for (int64_t i = start; i <= end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} else {
Log::Fatal("unknown data type in RowFunctionFromCSR");
}
} else {
Log::Fatal("unknown data type in RowFunctionFromCSR");
}
}
std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
if (data_type == dtype_float32) {
const float* data_ptr = reinterpret_cast<const float*>(data);
if (col_ptr_type == dtype_int32) {
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
CHECK(idx + 1 < ncol_ptr);
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_col_ptr[idx];
int64_t end = ptr_col_ptr[idx + 1];
CHECK(start >= 0 && end <= nelem);
for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} else if (col_ptr_type == dtype_int64) {
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
CHECK(idx + 1 < ncol_ptr);
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_col_ptr[idx];
int64_t end = ptr_col_ptr[idx + 1];
CHECK(start >= 0 && end <= nelem);
for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} else {
Log::Fatal("unknown data type in ColumnFunctionFromCSC");
}
} else if (data_type == dtype_float64) {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (col_ptr_type == dtype_int32) {
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
CHECK(idx + 1 < ncol_ptr);
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_col_ptr[idx];
int64_t end = ptr_col_ptr[idx + 1];
CHECK(start >= 0 && end <= nelem);
for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} else if (col_ptr_type == dtype_int64) {
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
CHECK(idx + 1 < ncol_ptr);
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_col_ptr[idx];
int64_t end = ptr_col_ptr[idx + 1];
CHECK(start >= 0 && end <= nelem);
for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]);
}
return ret;
};
} else {
Log::Fatal("unknown data type in ColumnFunctionFromCSC");
}
} else {
Log::Fatal("unknown data type in ColumnFunctionFromCSC");
}
}
/*!
* \brief Pick the values of selected rows out of one sparse column
* \param data (row index, value) pairs of the column; the single forward
*        scan below only finds matches when both data and indices are in
*        ascending row order
* \param indices rows to sample
* \return one value per entry of indices; 0 where the column has no entry
*/
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<size_t>& indices) {
  std::vector<double> sampled;
  auto cursor = data.begin();
  const auto finish = data.end();
  for (const size_t wanted : indices) {
    const int target = static_cast<int>(wanted);
    // The cursor never moves backwards across requested rows.
    for (; cursor != finish && cursor->first < target; ++cursor) {
    }
    const bool found = (cursor != finish) && (cursor->first == target);
    sampled.push_back(found ? cursor->second : 0.0);
  }
  return sampled;
}
\ No newline at end of file
......@@ -64,57 +64,60 @@ std::vector<const BinMapper*> Dataset::GetBinMappers() const {
return ret;
}
void Dataset::SetField(const char* field_name, const void* field_data, data_size_t num_element, int type) {
bool Dataset::SetFloatField(const char* field_name, const float* field_data, data_size_t num_element) {
std::string name(field_name);
name = Common::Trim(name);
if (name == std::string("label") || name == std::string("target")) {
if (type != 0) {
Log::Fatal("type of label should be float");
}
metadata_.SetLabel(reinterpret_cast<const float*>(field_data), num_element);
metadata_.SetLabel(field_data, num_element);
} else if (name == std::string("weight") || name == std::string("weights")) {
if (type != 0) {
Log::Fatal("type of weights should be float");
}
metadata_.SetWeights(reinterpret_cast<const float*>(field_data), num_element);
metadata_.SetWeights(field_data, num_element);
} else if (name == std::string("init_score")) {
if (type != 0) {
Log::Fatal("type of init_score should be float");
}
metadata_.SetInitScore(reinterpret_cast<const float*>(field_data), num_element);
} else if (name == std::string("query") || name == std::string("group")) {
if (type != 1) {
Log::Fatal("type of init_score should be int");
}
metadata_.SetQueryBoundaries(reinterpret_cast<const data_size_t*>(field_data), num_element);
metadata_.SetInitScore(field_data, num_element);
} else {
Log::Fatal("unknow field name: %s", field_name);
return false;
}
return true;
}
void Dataset::GetField(const char* field_name, uint64_t* out_len, const void** out_ptr, int* out_type) {
/*!
 * \brief Store an integer-valued metadata field selected by name.
 *
 * The name is trimmed before comparison. Only "query" and "group" are
 * accepted; both forward the array to the metadata's query boundaries.
 *
 * \param field_name name of the field to set
 * \param field_data pointer to the integer values
 * \param num_element number of values behind \p field_data
 * \return true when the field name was recognized and the data forwarded,
 *         false otherwise
 */
bool Dataset::SetIntField(const char* field_name, const int* field_data, data_size_t num_element) {
  std::string key(field_name);
  key = Common::Trim(key);
  const bool is_query_field =
      (key == std::string("query")) || (key == std::string("group"));
  if (!is_query_field) {
    return false;
  }
  metadata_.SetQueryBoundaries(field_data, num_element);
  return true;
}
// Looks up a float-valued metadata field by (trimmed) name and hands back a
// pointer to it through out_ptr together with a length through out_len.
// "label"/"target", "weight"/"weights" and "init_score" are served from
// metadata_, and each reports num_data_ as its length. Returns true for a
// recognized name; the final else branch returns false.
// NOTE(review): this span looks like a flattened diff. The `*out_type = 0;`
// assignments and the empty "query"/"group" branch refer to a parameter
// this signature does not declare, so they are presumably leftovers of the
// replaced GetField — confirm against the real file.
bool Dataset::GetFloatField(const char* field_name, int64_t* out_len, const float** out_ptr) {
std::string name(field_name);
name = Common::Trim(name);
if (name == std::string("label") || name == std::string("target")) {
*out_ptr = metadata_.label();
*out_len = num_data_;
*out_type = 0;
} else if (name == std::string("weight") || name == std::string("weights")) {
*out_ptr = metadata_.weights();
*out_len = num_data_;
*out_type = 0;
} else if (name == std::string("init_score")) {
*out_ptr = metadata_.init_score();
*out_len = num_data_;
*out_type = 0;
} else if (name == std::string("query") || name == std::string("group")) {
} else {
return false;
}
return true;
}
// Looks up an integer-valued metadata field by (trimmed) name. Only
// "query"/"group" is recognized: out_ptr receives
// metadata_.query_boundaries() and out_len receives num_data_. Returns true
// on the success path; the else branch covers unrecognized names.
// NOTE(review): out_len is set to num_data_ here; whether that matches the
// number of entries in the query-boundary array is not visible from this
// function — confirm against Metadata.
// NOTE(review): this span looks like a flattened diff. `*out_type = 1;` uses
// a parameter this signature does not declare, and the else branch holds
// both a Log::Fatal call and `return false`; presumably the former lines
// are leftovers of the replaced GetField — confirm against the real file.
bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int** out_ptr) {
std::string name(field_name);
name = Common::Trim(name);
if (name == std::string("query") || name == std::string("group")) {
*out_ptr = metadata_.query_boundaries();
*out_len = num_data_;
*out_type = 1;
} else {
Log::Fatal("unknow field name: %s", field_name);
return false;
}
return true;
}
void Dataset::SaveBinaryFile(const char* bin_filename) {
......
......@@ -511,7 +511,10 @@ std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filenam
}
std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
const size_t sample_cnt = static_cast<size_t>(data.size() < io_config_.bin_construct_sample_cnt ? data.size() : io_config_.bin_construct_sample_cnt);
size_t sample_cnt = static_cast<size_t>(io_config_.bin_construct_sample_cnt);
if (sample_cnt > data.size()) {
sample_cnt = data.size();
}
std::vector<size_t> sample_indices = random_.Sample(data.size(), sample_cnt);
std::vector<std::string> out;
for (size_t i = 0; i < sample_indices.size(); ++i) {
......
from __future__ import absolute_import
import sys
import os
import ctypes
import collections
import re
import numpy as np
from scipy import sparse
def _load_lib():
    """Load the LightGBM shared library.

    The DLL is looked up at a hard-coded path relative to the current
    working directory, so the script has to be started from the directory
    that contains ``windows/``.

    Returns
    -------
    ctypes.CDLL
        Handle through which the ``LGBM_*`` C functions are called.
    """
    lib_path = './windows/x64/DLL/lib_lightgbm.dll'
    # The path is a non-empty literal, so the former empty-path early return
    # could never trigger and has been dropped.
    return ctypes.cdll.LoadLibrary(lib_path)
# Module-wide handle to the shared library; every test below calls through it.
LIB = _load_lib()
# Declare the result as a C string so ctypes returns the message bytes
# instead of applying its default int return conversion.
LIB.LGBM_GetLastError.restype = ctypes.c_char_p
def test_load_from_file():
    """Create a dataset from the binary-classification example text file.

    Builds the dataset with ``max_bin=15`` and no reference dataset, prints
    the row count reported by the library, queries the feature count
    (without printing it) and returns the dataset handle.

    Return codes of the library calls are not inspected.
    """
    train_path = ctypes.c_char_p('./examples/binary_classification/binary.train')
    params = ctypes.c_char_p('max_bin=15')
    no_reference = ctypes.c_void_p(None)
    handle = ctypes.c_void_p()
    LIB.LGBM_CreateDatasetFromFile(train_path, params, no_reference,
                                   ctypes.byref(handle))

    # Row count is written into num_data by the library.
    num_data = ctypes.c_ulong()
    LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data))
    print(num_data)

    num_feature = ctypes.c_ulong()
    LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature))
    return handle
def c_array(ctype, values):
    """Copy a sequence of Python values into a new ctypes array.

    Parameters
    ----------
    ctype : ctypes type
        Element type of the array, e.g. ``ctypes.c_int``.
    values : iterable
        Values to copy. Sized containers are used as-is; iterables without
        a length (generators, iterators) are materialized first.

    Returns
    -------
    ctypes array of ``ctype`` with one element per input value.
    """
    if not hasattr(values, '__len__'):
        # The array type needs the element count before construction.
        values = list(values)
    return (ctype * len(values))(*values)
def c_str(string):
    """Convert a Python string to a ctypes C string (``c_char_p``).

    Text is encoded as UTF-8. A byte string is passed through unchanged,
    which avoids the implicit ASCII decode that ``.encode`` performs on a
    Python 2 ``str`` and also accepts ``bytes`` on Python 3.
    """
    if isinstance(string, bytes):
        return ctypes.c_char_p(string)
    return ctypes.c_char_p(string.encode('utf-8'))
def test_load_from_matric():
    """Create a dataset from a dense matrix built from the example text file.

    Each line of the tab-separated file contributes one row; the first
    column is skipped and the remaining columns become float features. The
    matrix is flattened and passed to ``LGBM_CreateDatasetFromMat`` with
    ``max_bin=15 is_sparse=false`` and no reference dataset.

    NOTE(review): the handle is passed to ``LGBM_DatasetFree`` and then
    returned, so callers presumably receive a handle that is no longer
    usable -- confirm the intent.
    """
    inp = open('./examples/binary_classification/binary.train', 'r')
    rows = [[float(tok) for tok in line.split('\t')[1:]]
            for line in inp.readlines()]
    inp.close()

    mat = np.array(rows)
    print(mat.shape)

    # One-dimensional view of the matrix contents for the C call.
    flat = np.array(mat.reshape(mat.size), copy=False)
    flat_ptr = flat.ctypes.data_as(ctypes.POINTER(ctypes.c_double))

    handle = ctypes.c_void_p()
    LIB.LGBM_CreateDatasetFromMat(flat_ptr, 1,
                                  mat.shape[0],
                                  mat.shape[1], 1,
                                  ctypes.c_char_p('max_bin=15 is_sparse=false'),
                                  None, ctypes.byref(handle))
    LIB.LGBM_DatasetFree(ctypes.byref(handle))
    # num_data = ctypes.c_ulong()
    # LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data) )
    # print num_data
    # num_feature = ctypes.c_ulong()
    # LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature) )
    # print num_feature
    return handle
def test_load_from_csr(filename, reference):
    """Create a dataset from a CSR matrix built from a tab-separated file.

    Parameters
    ----------
    filename : str
        Text file whose first column is the label and whose remaining
        columns are float features.
    reference : ctypes.c_void_p or None
        Handle of an existing dataset to pass as the reference, or None.

    Returns
    -------
    ctypes.c_void_p
        Handle of the created dataset, with its ``label`` field set.

    Return codes of the library calls are not inspected.

    NOTE(review): the C header changed alongside this script declares
    ``LGBM_CreateDatasetFromCSR`` with an additional ``indptr_type``
    argument after ``indptr``; the call below passes ten arguments --
    confirm it matches the library build under test.
    """
    features = []
    labels = []
    inp = open(filename, 'r')
    for line in inp.readlines():
        features.append([float(x) for x in line.split('\t')[1:]])
        labels.append(float(line.split('\t')[0]))
    inp.close()

    mat = np.array(features)
    label = np.array(labels, dtype=np.float32)
    print(mat.shape)
    csr = sparse.csr_matrix(mat)

    handle = ctypes.c_void_p()
    ref = ctypes.byref(reference) if reference != None else None

    indptr = c_array(ctypes.c_int, csr.indptr)
    indices = c_array(ctypes.c_int, csr.indices)
    values_ptr = csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p))
    LIB.LGBM_CreateDatasetFromCSR(indptr,
                                  indices,
                                  values_ptr,
                                  1, len(csr.indptr), len(csr.data),
                                  csr.shape[1], ctypes.c_char_p('max_bin=15'),
                                  ref, ctypes.byref(handle))

    num_data = ctypes.c_ulong()
    LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data))
    print(num_data)

    num_feature = ctypes.c_ulong()
    LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature))

    LIB.LGBM_DatasetSetField(handle, c_str('label'),
                             c_array(ctypes.c_float, label), len(label), 0)
    return handle
# Driver: build a training and a validation dataset, create a booster, run
# boosting updates, evaluate, then save the model and load it back.
# NOTE(review): indentation is missing in this view, so how many of the
# statements after the `for` line belong to the loop body cannot be told
# from here -- restore it from the real file before running.
# Training data; the validation set is created with `train` as its reference.
train = test_load_from_csr('./examples/binary_classification/binary.train', None)
test = [test_load_from_csr('./examples/binary_classification/binary.test', train)]
name = [c_str('test')]
booster = ctypes.c_void_p()
# One validation dataset (count argument 1), passed with its display name.
LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test), c_array(ctypes.c_char_p, name), 1, "app=binary metric=auc num_leaves=31", ctypes.byref(booster))
is_finished = ctypes.c_int(0)
for i in xrange(100):
LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished))
# Single-element float32 buffer that receives the evaluation result.
result = np.array([0.0], dtype=np.float32)
out_len = ctypes.c_ulong(0)
LIB.LGBM_BoosterEval(booster, 0, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_float)))
print result
# Write the model to model.txt, then load it into a second booster handle.
LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt'))
booster2 = ctypes.c_void_p()
LIB.LGBM_BoosterLoadFromModelfile(c_str('model.txt'), ctypes.byref(booster2))
print type(len([0,0]))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment