Commit 4306b22c authored by Guolin Ke's avatar Guolin Ke
Browse files

speed up the random generate

parent dd316895
......@@ -16,16 +16,17 @@ public:
/*!
* \brief Constructor, with random seed
*/
Random()
:distribution_zero_to_one_(0.0, 1.0) {
Random() {
std::random_device rd;
generator_ = std::mt19937(rd());
auto genrator = std::mt19937(rd());
std::uniform_int_distribution<int> distribution(0, x);
x = static_cast<unsigned int>(distribution(genrator));
}
/*!
* \brief Constructor, with specific seed
*/
Random(int seed)
:generator_(seed), distribution_zero_to_one_(0.0, 1.0) {
Random(int seed) {
x = static_cast<unsigned int>(seed);
}
/*!
* \brief Generate random integer
......@@ -33,10 +34,8 @@ public:
* \param upper_bound upper bound
* \return The random integer between [lower_bound, upper_bound)
*/
inline int64_t NextInt(int64_t lower_bound, int64_t upper_bound) {
// get random interge in [a,b)
std::uniform_int_distribution<int64_t> distribution(lower_bound, upper_bound - 1);
return distribution(generator_);
inline int NextInt(int lower_bound, int upper_bound) {
return (next()) % (upper_bound - lower_bound + 1) + lower_bound;
}
/*!
* \brief Generate random float data
......@@ -44,7 +43,7 @@ public:
*/
inline double NextDouble() {
// get random float in [0,1)
return distribution_zero_to_one_(generator_);
return static_cast<double>(next() % 2047) / 2047.0f;
}
/*!
* \brief Sample K data from {0,1,...,N-1}
......@@ -66,12 +65,22 @@ public:
return ret;
}
private:
/*! \brief Random generator */
std::mt19937 generator_;
/*! \brief Cache distribution for [0.0, 1.0) */
std::uniform_real_distribution<double> distribution_zero_to_one_;
unsigned next() {
x ^= x << 16;
x ^= x >> 5;
x ^= x << 1;
auto t = x;
x = y;
y = z;
z = t ^ x ^ y;
return z;
}
unsigned int x = 123456789;
unsigned int y = 362436069;
unsigned int z = 521288629;
};
} // namespace LightGBM
#endif // LightGBM_UTILS_RANDOM_H_
......@@ -158,7 +158,7 @@ public:
++cur_sample_cnt;
}
else {
const size_t idx = static_cast<size_t>(random.NextInt(0, line_idx + 1));
const size_t idx = static_cast<size_t>(random.NextInt(0, static_cast<int>(line_idx + 1)));
if (idx < static_cast<size_t>(sample_cnt)) {
out_sampled_data->operator[](idx) = std::string(buffer, size);
}
......@@ -198,7 +198,7 @@ public:
++cur_sample_cnt;
}
else {
const size_t idx = static_cast<size_t>(random.NextInt(0, out_used_data_indices->size()));
const size_t idx = static_cast<size_t>(random.NextInt(0, static_cast<int>(out_used_data_indices->size())));
if (idx < static_cast<size_t>(sample_cnt) ) {
out_sampled_data->operator[](idx) = std::string(buffer, size);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment