Commit 8aef4bf7 authored by Guolin Ke
Browse files

more stable sampling K from N

parent 8aeceeb4
......@@ -5,6 +5,7 @@
#include <random>
#include <vector>
#include <set>
namespace LightGBM {
......@@ -65,13 +66,13 @@ public:
inline std::vector<int> Sample(int N, int K) {
std::vector<int> ret;
ret.reserve(K);
if (K > N || K < 0) {
if (K > N || K <= 0) {
return ret;
} else if (K == N) {
for (int i = 0; i < N; ++i) {
ret.push_back(i);
}
} else if (K > N / 2) {
} else if (K > 1 && K > (N / std::log2(K))) {
for (int i = 0; i < N; ++i) {
double prob = (K - ret.size()) / static_cast<double>(N - i);
if (NextFloat() < prob) {
......@@ -79,15 +80,15 @@ public:
}
}
} else {
int min_step = 1;
int avg_step = N / K;
int max_step = 2 * avg_step - min_step;
int start = -1;
for (int i = 0; i < K; ++i) {
int step = NextShort(min_step, max_step + 1);
start += step;
if (start >= N) { break; }
ret.push_back(start);
std::set<int> sample_set;
while (sample_set.size() < K) {
int next = RandInt32() % N;
if (sample_set.count(next) == 0) {
sample_set.insert(next);
}
}
for (auto iter = sample_set.begin(); iter != sample_set.end(); ++iter) {
ret.push_back(*iter);
}
}
return ret;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment