Commit 8a0f07ae authored by Guolin Ke
Browse files

more efficient algorithm to sample k from n.

parent 8dbc48fd
......@@ -64,14 +64,31 @@ public:
*/
// Draws a sorted sample of distinct indices from [0, N).
//
// N: size of the population; candidate indices are 0 .. N-1.
// K: requested number of indices.
//
// Returns the chosen indices in ascending order. The result is empty when
// K <= 0 or K > N. The sparse branch (K <= N / 2) may return FEWER than K
// indices, because it stops as soon as a step would leave [0, N); callers
// must use the size of the returned vector rather than K.
inline std::vector<int> Sample(int N, int K) {
  std::vector<int> ret;
  // Validate before reserving: a negative K converted to size_t would make
  // reserve() throw, and K == 0 would divide by zero in the sparse branch.
  if (K > N || K <= 0) {
    return ret;
  }
  ret.reserve(K);
  if (K == N) {
    // Everything is selected; no random numbers are consumed.
    for (int i = 0; i < N; ++i) {
      ret.push_back(i);
    }
  } else if (K > N / 2) {
    // Dense case: visit every index once and keep it with probability
    // (still needed) / (still available). Once K indices are held the
    // probability is 0, so no more than K are ever kept.
    // NOTE(review): presumably NextFloat() yields values in [0, 1), which
    // would make this return exactly K indices -- confirm against its definition.
    for (int i = 0; i < N; ++i) {
      const int remaining = K - static_cast<int>(ret.size());
      const double prob = remaining / static_cast<double>(N - i);
      if (NextFloat() < prob) {
        ret.push_back(i);
      }
    }
  } else {
    // Sparse case: jump forward by a random step instead of visiting every
    // index. Steps are drawn between min_step and max_step, a range centred
    // on N / K.
    // NOTE(review): assumes NextShort(lo, hi) returns a value in [lo, hi)
    // and supports the full step range -- confirm against its definition.
    const int min_step = 1;
    const int avg_step = N / K;
    const int max_step = 2 * avg_step - min_step;
    int start = -1;
    for (int i = 0; i < K; ++i) {
      const int step = NextShort(min_step, max_step + 1);
      start += step;
      // Ran off the end of the population: return what was collected so far.
      if (start >= N) { break; }
      ret.push_back(start);
    }
  }
  return ret;
}
......
......@@ -407,8 +407,9 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
if (reference == nullptr) {
// sample data first
Random rand(io_config.data_random_seed);
const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values(ncol);
std::vector<std::vector<int>> sample_idx(ncol);
for (size_t i = 0; i < sample_indices.size(); ++i) {
......@@ -465,8 +466,9 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
if (reference == nullptr) {
// sample data first
Random rand(io_config.data_random_seed);
const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values;
std::vector<std::vector<int>> sample_idx;
for (size_t i = 0; i < sample_indices.size(); ++i) {
......@@ -527,8 +529,9 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
if (reference == nullptr) {
// sample data first
Random rand(io_config.data_random_seed);
const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
auto sample_indices = rand.Sample(nrow, sample_cnt);
sample_cnt = static_cast<int>(sample_indices.size());
std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
#pragma omp parallel for schedule(static)
......
Markdown is supported
Attaching a file: 0%. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment