Commit 4306b22c authored by Guolin Ke
Browse files

speed up the random generate

parent dd316895
...@@ -16,16 +16,17 @@ public: ...@@ -16,16 +16,17 @@ public:
/*! /*!
* \brief Constructor, with random seed * \brief Constructor, with random seed
*/ */
Random() Random() {
:distribution_zero_to_one_(0.0, 1.0) {
std::random_device rd; std::random_device rd;
generator_ = std::mt19937(rd()); auto genrator = std::mt19937(rd());
std::uniform_int_distribution<int> distribution(0, x);
x = static_cast<unsigned int>(distribution(genrator));
} }
/*! /*!
* \brief Constructor, with specific seed * \brief Constructor, with specific seed
*/ */
Random(int seed) Random(int seed) {
:generator_(seed), distribution_zero_to_one_(0.0, 1.0) { x = static_cast<unsigned int>(seed);
} }
/*! /*!
* \brief Generate random integer * \brief Generate random integer
...@@ -33,10 +34,8 @@ public: ...@@ -33,10 +34,8 @@ public:
* \param upper_bound upper bound * \param upper_bound upper bound
* \return The random integer between [lower_bound, upper_bound) * \return The random integer between [lower_bound, upper_bound)
*/ */
inline int64_t NextInt(int64_t lower_bound, int64_t upper_bound) { inline int NextInt(int lower_bound, int upper_bound) {
// get random integer in [a,b) return (next()) % (upper_bound - lower_bound + 1) + lower_bound;
std::uniform_int_distribution<int64_t> distribution(lower_bound, upper_bound - 1);
return distribution(generator_);
} }
/*! /*!
* \brief Generate random float data * \brief Generate random float data
...@@ -44,7 +43,7 @@ public: ...@@ -44,7 +43,7 @@ public:
*/ */
inline double NextDouble() { inline double NextDouble() {
// get random float in [0,1) // get random float in [0,1)
return distribution_zero_to_one_(generator_); return static_cast<double>(next() % 2047) / 2047.0f;
} }
/*! /*!
* \brief Sample K data from {0,1,...,N-1} * \brief Sample K data from {0,1,...,N-1}
...@@ -66,12 +65,22 @@ public: ...@@ -66,12 +65,22 @@ public:
return ret; return ret;
} }
private: private:
/*! \brief Random generator */ unsigned next() {
std::mt19937 generator_; x ^= x << 16;
/*! \brief Cache distribution for [0.0, 1.0) */ x ^= x >> 5;
std::uniform_real_distribution<double> distribution_zero_to_one_; x ^= x << 1;
auto t = x;
x = y;
y = z;
z = t ^ x ^ y;
return z;
}
unsigned int x = 123456789;
unsigned int y = 362436069;
unsigned int z = 521288629;
}; };
} // namespace LightGBM } // namespace LightGBM
#endif // LightGBM_UTILS_RANDOM_H_ #endif // LightGBM_UTILS_RANDOM_H_
...@@ -158,7 +158,7 @@ public: ...@@ -158,7 +158,7 @@ public:
++cur_sample_cnt; ++cur_sample_cnt;
} }
else { else {
const size_t idx = static_cast<size_t>(random.NextInt(0, line_idx + 1)); const size_t idx = static_cast<size_t>(random.NextInt(0, static_cast<int>(line_idx + 1)));
if (idx < static_cast<size_t>(sample_cnt)) { if (idx < static_cast<size_t>(sample_cnt)) {
out_sampled_data->operator[](idx) = std::string(buffer, size); out_sampled_data->operator[](idx) = std::string(buffer, size);
} }
...@@ -198,7 +198,7 @@ public: ...@@ -198,7 +198,7 @@ public:
++cur_sample_cnt; ++cur_sample_cnt;
} }
else { else {
const size_t idx = static_cast<size_t>(random.NextInt(0, out_used_data_indices->size())); const size_t idx = static_cast<size_t>(random.NextInt(0, static_cast<int>(out_used_data_indices->size())));
if (idx < static_cast<size_t>(sample_cnt) ) { if (idx < static_cast<size_t>(sample_cnt) ) {
out_sampled_data->operator[](idx) = std::string(buffer, size); out_sampled_data->operator[](idx) = std::string(buffer, size);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment