Commit 8a0f07ae authored by Guolin Ke's avatar Guolin Ke
Browse files

more efficient algorithm to sample k from n.

parent 8dbc48fd
...@@ -64,15 +64,32 @@ public: ...@@ -64,15 +64,32 @@ public:
*/ */
inline std::vector<int> Sample(int N, int K) { inline std::vector<int> Sample(int N, int K) {
std::vector<int> ret; std::vector<int> ret;
ret.reserve(K);
if (K > N || K < 0) { if (K > N || K < 0) {
return ret; return ret;
} else if (K == N) {
for (int i = 0; i < N; ++i) {
ret.push_back(i);
} }
} else if (K > N / 2) {
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
double prob = (K - ret.size()) / static_cast<double>(N - i); double prob = (K - ret.size()) / static_cast<double>(N - i);
if (NextFloat() < prob) { if (NextFloat() < prob) {
ret.push_back(i); ret.push_back(i);
} }
} }
} else {
int min_step = 1;
int avg_step = N / K;
int max_step = 2 * avg_step - min_step;
int start = -1;
for (int i = 0; i < K; ++i) {
int step = NextShort(min_step, max_step + 1);
start += step;
if (start >= N) { break; }
ret.push_back(start);
}
}
return ret; return ret;
} }
private: private:
......
...@@ -407,8 +407,9 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data, ...@@ -407,8 +407,9 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
if (reference == nullptr) { if (reference == nullptr) {
// sample data first // sample data first
Random rand(io_config.data_random_seed); Random rand(io_config.data_random_seed);
const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt); int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values(ncol); std::vector<std::vector<double>> sample_values(ncol);
std::vector<std::vector<int>> sample_idx(ncol); std::vector<std::vector<int>> sample_idx(ncol);
for (size_t i = 0; i < sample_indices.size(); ++i) { for (size_t i = 0; i < sample_indices.size(); ++i) {
...@@ -465,8 +466,9 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -465,8 +466,9 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
if (reference == nullptr) { if (reference == nullptr) {
// sample data first // sample data first
Random rand(io_config.data_random_seed); Random rand(io_config.data_random_seed);
const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt); int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values; std::vector<std::vector<double>> sample_values;
std::vector<std::vector<int>> sample_idx; std::vector<std::vector<int>> sample_idx;
for (size_t i = 0; i < sample_indices.size(); ++i) { for (size_t i = 0; i < sample_indices.size(); ++i) {
...@@ -527,8 +529,9 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, ...@@ -527,8 +529,9 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
if (reference == nullptr) { if (reference == nullptr) {
// sample data first // sample data first
Random rand(io_config.data_random_seed); Random rand(io_config.data_random_seed);
const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt); int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt); auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values(ncol_ptr - 1); std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
std::vector<std::vector<int>> sample_idx(ncol_ptr - 1); std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment