Commit 9d375069 authored by Guolin Ke's avatar Guolin Ke
Browse files

add LGBM_DatasetCreateFromSampledMat api.

parent 9f69165b
......@@ -52,9 +52,29 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
const DatasetHandle reference,
DatasetHandle* out);
/*!
* \brief create a empty dataset by sampling matrix, if num_sample_row == num_total_row, will construct this dataset.
* Need call LGBM_DatasetPushRows/LGBM_DatasetPushRowsByCSR after calling this function.
* \param data pointer to the data space
* \param data_type type of data pointer, can be C_API_DTYPE_FLOAT32 or C_API_DTYPE_FLOAT64
* \param num_sample_row number of rows
* \param ncol number columns
* \param num_total_row number of total rows
* \param parameters additional parameters
* \param out created dataset
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledMat(const void* data,
int data_type,
int32_t num_sample_row,
int32_t ncol,
int32_t num_total_row,
const char* parameters,
DatasetHandle* out);
/*!
* \brief create a empty dataset by sampling csc data, if num_sample_row == num_total_row, will construct this dataset.
* \brief create a empty dataset by sampling CSR data, if num_sample_row == num_total_row, will construct this dataset.
* Need call LGBM_DatasetPushRows/LGBM_DatasetPushRowsByCSR after calling this function.
* \param indptr pointer to row headers
* \param indptr_type type of indptr, can be C_API_DTYPE_INT32 or C_API_DTYPE_INT64
* \param indices findex
......
......@@ -310,6 +310,40 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
API_END();
}
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledMat(const void* data,
int data_type,
int32_t num_sample_row,
int32_t ncol,
int32_t num_total_row,
const char* parameters,
DatasetHandle* out) {
if (num_sample_row == num_total_row) {
return LGBM_DatasetCreateFromMat(data, data_type, num_total_row, ncol, 1, parameters, nullptr, out);
} else {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
io_config.Set(param);
auto get_row_fun = RowFunctionFromDenseMatric(data, num_sample_row, ncol, data_type, 1);
std::vector<std::vector<double>> sample_values(ncol);
std::vector<std::vector<int>> sample_idx(ncol);
for (int i = 0; i < num_sample_row; ++i) {
auto row = get_row_fun(i);
for (size_t idx = 0; idx < row.size(); ++idx) {
if (std::fabs(row[idx]) > kEpsilon) {
sample_values[idx].emplace_back(row[idx]);
sample_idx[idx].emplace_back(i);
}
}
}
DatasetLoader loader(io_config, nullptr, 1, nullptr);
*out = loader.CostructFromSampleData(sample_values, sample_idx,
num_sample_row,
static_cast<data_size_t>(num_total_row));
API_END();
}
}
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr,
int indptr_type,
const int32_t* indices,
......@@ -547,7 +581,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
auto sample_indices = rand.Sample(nrow, sample_cnt);
std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
#pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static)
for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
for (int j = 0; j < sample_cnt; j++) {
......@@ -1096,7 +1130,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data);
if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) {
return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
......@@ -1105,7 +1139,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
return ret;
};
} else {
return [data_ptr, num_col, num_row](int row_idx) {
return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
......@@ -1116,7 +1150,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
} else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data);
if (is_row_major) {
return [data_ptr, num_col, num_row](int row_idx) {
return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + num_col * row_idx;
for (int i = 0; i < num_col; ++i) {
......@@ -1125,7 +1159,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
return ret;
};
} else {
return [data_ptr, num_col, num_row](int row_idx) {
return [data_ptr, num_col, num_row] (int row_idx) {
std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
......@@ -1141,7 +1175,7 @@ std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
if (inner_function != nullptr) {
return [inner_function](int row_idx) {
return [inner_function] (int row_idx) {
auto raw_values = inner_function(row_idx);
std::vector<std::pair<int, double>> ret;
for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
......@@ -1161,7 +1195,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const float* data_ptr = reinterpret_cast<const float*>(data);
if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
......@@ -1172,7 +1206,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
};
} else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
......@@ -1186,7 +1220,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const double* data_ptr = reinterpret_cast<const double*>(data);
if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
......@@ -1197,7 +1231,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
};
} else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
......@@ -1220,7 +1254,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) {
return std::make_pair(-1, 0.0);
......@@ -1233,7 +1267,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) {
return std::make_pair(-1, 0.0);
......@@ -1249,7 +1283,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) {
return std::make_pair(-1, 0.0);
......@@ -1262,7 +1296,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) {
return std::make_pair(-1, 0.0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment