Commit 8aef4bf7 authored by Guolin Ke
Browse files

more stable sampling K from N

parent 8aeceeb4
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <random> #include <random>
#include <vector> #include <vector>
#include <set>
namespace LightGBM { namespace LightGBM {
...@@ -65,13 +66,13 @@ public: ...@@ -65,13 +66,13 @@ public:
inline std::vector<int> Sample(int N, int K) { inline std::vector<int> Sample(int N, int K) {
std::vector<int> ret; std::vector<int> ret;
ret.reserve(K); ret.reserve(K);
if (K > N || K < 0) { if (K > N || K <= 0) {
return ret; return ret;
} else if (K == N) { } else if (K == N) {
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
ret.push_back(i); ret.push_back(i);
} }
} else if (K > N / 2) { } else if (K > 1 && K > (N / std::log2(K))) {
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
double prob = (K - ret.size()) / static_cast<double>(N - i); double prob = (K - ret.size()) / static_cast<double>(N - i);
if (NextFloat() < prob) { if (NextFloat() < prob) {
...@@ -79,15 +80,15 @@ public: ...@@ -79,15 +80,15 @@ public:
} }
} }
} else { } else {
int min_step = 1; std::set<int> sample_set;
int avg_step = N / K; while (sample_set.size() < K) {
int max_step = 2 * avg_step - min_step; int next = RandInt32() % N;
int start = -1; if (sample_set.count(next) == 0) {
for (int i = 0; i < K; ++i) { sample_set.insert(next);
int step = NextShort(min_step, max_step + 1); }
start += step; }
if (start >= N) { break; } for (auto iter = sample_set.begin(); iter != sample_set.end(); ++iter) {
ret.push_back(start); ret.push_back(*iter);
} }
} }
return ret; return ret;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment