"...tests/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "c676a7ea19aa7518159b3fbc5035092bd4d0aa39"
Unverified commit 3be611e7, authored by Chen Yufei, committed by GitHub
Browse files

[refactor] Use `CreateSampleIndices()` in `c_api.cpp` (#4478)

This removes code duplication for creating sample indices.
parent 0012fc28
...@@ -1112,10 +1112,8 @@ int LGBM_DatasetCreateFromMats(int32_t nmat, ...@@ -1112,10 +1112,8 @@ int LGBM_DatasetCreateFromMats(int32_t nmat,
if (reference == nullptr) { if (reference == nullptr) {
// sample data first // sample data first
Random rand(config.data_random_seed); auto sample_indices = CreateSampleIndices(total_nrow, config);
int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt); int sample_cnt = static_cast<int>(sample_indices.size());
auto sample_indices = rand.Sample(total_nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values(ncol); std::vector<std::vector<double>> sample_values(ncol);
std::vector<std::vector<int>> sample_idx(ncol); std::vector<std::vector<int>> sample_idx(ncol);
...@@ -1198,10 +1196,8 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -1198,10 +1196,8 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
int32_t nrow = static_cast<int32_t>(nindptr - 1); int32_t nrow = static_cast<int32_t>(nindptr - 1);
if (reference == nullptr) { if (reference == nullptr) {
// sample data first // sample data first
Random rand(config.data_random_seed); auto sample_indices = CreateSampleIndices(nrow, config);
int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt); int sample_cnt = static_cast<int>(sample_indices.size());
auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values(num_col); std::vector<std::vector<double>> sample_values(num_col);
std::vector<std::vector<int>> sample_idx(num_col); std::vector<std::vector<int>> sample_idx(num_col);
for (size_t i = 0; i < sample_indices.size(); ++i) { for (size_t i = 0; i < sample_indices.size(); ++i) {
...@@ -1267,10 +1263,8 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, ...@@ -1267,10 +1263,8 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
int32_t nrow = num_rows; int32_t nrow = num_rows;
if (reference == nullptr) { if (reference == nullptr) {
// sample data first // sample data first
Random rand(config.data_random_seed); auto sample_indices = CreateSampleIndices(nrow, config);
int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt); int sample_cnt = static_cast<int>(sample_indices.size());
auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values(num_col); std::vector<std::vector<double>> sample_values(num_col);
std::vector<std::vector<int>> sample_idx(num_col); std::vector<std::vector<int>> sample_idx(num_col);
// local buffer to re-use memory // local buffer to re-use memory
...@@ -1341,10 +1335,8 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr, ...@@ -1341,10 +1335,8 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr,
int32_t nrow = static_cast<int32_t>(num_row); int32_t nrow = static_cast<int32_t>(num_row);
if (reference == nullptr) { if (reference == nullptr) {
// sample data first // sample data first
Random rand(config.data_random_seed); auto sample_indices = CreateSampleIndices(nrow, config);
int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt); int sample_cnt = static_cast<int>(sample_indices.size());
auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values(ncol_ptr - 1); std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
std::vector<std::vector<int>> sample_idx(ncol_ptr - 1); std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
OMP_INIT_EX(); OMP_INIT_EX();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment