Commit 9d375069 authored by Guolin Ke's avatar Guolin Ke
Browse files

add LGBM_DatasetCreateFromSampledMat api.

parent 9f69165b
...@@ -52,9 +52,29 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename, ...@@ -52,9 +52,29 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out); DatasetHandle* out);
/*!
* \brief Create an empty dataset from a sampled dense matrix; if num_sample_row == num_total_row, the dataset is constructed directly from the matrix.
* LGBM_DatasetPushRows/LGBM_DatasetPushRowsByCSR needs to be called after calling this function.
* \param data pointer to the data space holding the sampled rows (read as row-major)
* \param data_type type of data pointer, can be C_API_DTYPE_FLOAT32 or C_API_DTYPE_FLOAT64
* \param num_sample_row number of sampled rows stored in data
* \param ncol number of columns
* \param num_total_row number of total rows
* \param parameters additional parameters
* \param out created dataset
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledMat(const void* data,
int data_type,
int32_t num_sample_row,
int32_t ncol,
int32_t num_total_row,
const char* parameters,
DatasetHandle* out);
/*! /*!
* \brief create a empty dataset by sampling csc data, if num_sample_row == num_total_row, will construct this dataset. * \brief create a empty dataset by sampling CSR data, if num_sample_row == num_total_row, will construct this dataset.
* Need call LGBM_DatasetPushRows/LGBM_DatasetPushRowsByCSR after calling this function.
* \param indptr pointer to row headers * \param indptr pointer to row headers
* \param indptr_type type of indptr, can be C_API_DTYPE_INT32 or C_API_DTYPE_INT64 * \param indptr_type type of indptr, can be C_API_DTYPE_INT32 or C_API_DTYPE_INT64
* \param indices findex * \param indices findex
......
...@@ -310,6 +310,40 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename, ...@@ -310,6 +310,40 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
API_END(); API_END();
} }
/*!
 * \brief Create a dataset handle from a sampled dense matrix (read as row-major).
 *
 * If the sample already contains every row (num_sample_row == num_total_row),
 * the call is forwarded to LGBM_DatasetCreateFromMat. Otherwise the sampled
 * values are grouped per column and handed to
 * DatasetLoader::CostructFromSampleData together with the total row count.
 *
 * \param data pointer to the sampled matrix
 * \param data_type type of data pointer, forwarded to RowFunctionFromDenseMatric
 * \param num_sample_row number of rows stored in data
 * \param ncol number of columns
 * \param num_total_row number of total rows of the dataset
 * \param parameters additional parameters, parsed with ConfigBase::Str2Map
 * \param out receives the created dataset
 * \return 0 when succeed, -1 when failure happens
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledMat(const void* data,
                                                       int data_type,
                                                       int32_t num_sample_row,
                                                       int32_t ncol,
                                                       int32_t num_total_row,
                                                       const char* parameters,
                                                       DatasetHandle* out) {
  // The sample is the complete data: let the dense-matrix entry point do all the work.
  if (num_sample_row == num_total_row) {
    return LGBM_DatasetCreateFromMat(data, data_type, num_total_row, ncol, 1, parameters, nullptr, out);
  }

  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);

  // Row accessor over the sampled matrix; the trailing 1 selects row-major layout.
  auto read_row = RowFunctionFromDenseMatric(data, num_sample_row, ncol, data_type, 1);

  // One bucket per column: the kept values and the sample-row index of each value.
  std::vector<std::vector<double>> sample_values(ncol);
  std::vector<std::vector<int>> sample_idx(ncol);

  for (int sample_row = 0; sample_row < num_sample_row; ++sample_row) {
    auto row_values = read_row(sample_row);
    const size_t row_width = row_values.size();
    for (size_t col = 0; col < row_width; ++col) {
      const double value = row_values[col];
      // Written as a negated '>' so that any value for which the comparison
      // is false (including NaN) is skipped.
      if (!(std::fabs(value) > kEpsilon)) {
        continue;
      }
      sample_values[col].emplace_back(value);
      sample_idx[col].emplace_back(sample_row);
    }
  }

  DatasetLoader loader(io_config, nullptr, 1, nullptr);
  *out = loader.CostructFromSampleData(sample_values,
                                       sample_idx,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
}
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
...@@ -547,7 +581,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, ...@@ -547,7 +581,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
auto sample_indices = rand.Sample(nrow, sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt);
std::vector<std::vector<double>> sample_values(ncol_ptr - 1); std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
std::vector<std::vector<int>> sample_idx(ncol_ptr - 1); std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) { for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i); CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
for (int j = 0; j < sample_cnt; j++) { for (int j = 0; j < sample_cnt; j++) {
...@@ -1096,7 +1130,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1096,7 +1130,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
if (data_type == C_API_DTYPE_FLOAT32) { if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data); const float* data_ptr = reinterpret_cast<const float*>(data);
if (is_row_major) { if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) { return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + num_col * row_idx; auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
...@@ -1105,7 +1139,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1105,7 +1139,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
return ret; return ret;
}; };
} else { } else {
return [data_ptr, num_col, num_row](int row_idx) { return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx)); ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
...@@ -1116,7 +1150,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1116,7 +1150,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
} else if (data_type == C_API_DTYPE_FLOAT64) { } else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data); const double* data_ptr = reinterpret_cast<const double*>(data);
if (is_row_major) { if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) { return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + num_col * row_idx; auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
...@@ -1125,7 +1159,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1125,7 +1159,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
return ret; return ret;
}; };
} else { } else {
return [data_ptr, num_col, num_row](int row_idx) { return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx)); ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
...@@ -1141,7 +1175,7 @@ std::function<std::vector<std::pair<int, double>>(int row_idx)> ...@@ -1141,7 +1175,7 @@ std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) { RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major); auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
if (inner_function != nullptr) { if (inner_function != nullptr) {
return [inner_function](int row_idx) { return [inner_function] (int row_idx) {
auto raw_values = inner_function(row_idx); auto raw_values = inner_function(row_idx);
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) { for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
...@@ -1161,7 +1195,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1161,7 +1195,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const float* data_ptr = reinterpret_cast<const float*>(data); const float* data_ptr = reinterpret_cast<const float*>(data);
if (indptr_type == C_API_DTYPE_INT32) { if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr); const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1172,7 +1206,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1172,7 +1206,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
}; };
} else if (indptr_type == C_API_DTYPE_INT64) { } else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr); const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1186,7 +1220,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1186,7 +1220,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const double* data_ptr = reinterpret_cast<const double*>(data); const double* data_ptr = reinterpret_cast<const double*>(data);
if (indptr_type == C_API_DTYPE_INT32) { if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr); const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1197,7 +1231,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1197,7 +1231,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
}; };
} else if (indptr_type == C_API_DTYPE_INT64) { } else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr); const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1220,7 +1254,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1220,7 +1254,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr); const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
...@@ -1233,7 +1267,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1233,7 +1267,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr); const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
...@@ -1249,7 +1283,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1249,7 +1283,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr); const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
...@@ -1262,7 +1296,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1262,7 +1296,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr); const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment