Unverified Commit 25eabbb4 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Setting manual seed also sets the seed now. (#6900)

parent 107b4347
...@@ -28,11 +28,13 @@ inline uint32_t GetThreadId() { ...@@ -28,11 +28,13 @@ inline uint32_t GetThreadId() {
}; // namespace }; // namespace
std::mutex RandomEngine::manual_seed_mutex;
std::optional<uint64_t> RandomEngine::manual_seed; std::optional<uint64_t> RandomEngine::manual_seed;
/** @brief Constructor with default seed. */ /** @brief Constructor with default seed. */
RandomEngine::RandomEngine() { RandomEngine::RandomEngine() {
std::random_device rd; std::random_device rd;
std::lock_guard lock(manual_seed_mutex);
uint64_t seed = manual_seed.value_or(rd()); uint64_t seed = manual_seed.value_or(rd());
SetSeed(seed); SetSeed(seed);
} }
...@@ -59,6 +61,11 @@ void RandomEngine::SetSeed(uint64_t seed, uint64_t stream) { ...@@ -59,6 +61,11 @@ void RandomEngine::SetSeed(uint64_t seed, uint64_t stream) {
} }
/** @brief Manually fix the seed. */ /** @brief Manually fix the seed. */
void RandomEngine::SetManualSeed(int64_t seed) { manual_seed = seed; } void RandomEngine::SetManualSeed(int64_t seed) {
// Intentionally set the seed for current thread also.
RandomEngine::ThreadLocal()->SetSeed(seed);
std::lock_guard lock(manual_seed_mutex);
manual_seed = seed;
}
} // namespace graphbolt } // namespace graphbolt
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <mutex>
#include <optional> #include <optional>
#include <pcg_random.hpp> #include <pcg_random.hpp>
#include <random> #include <random>
...@@ -36,6 +37,9 @@ class RandomEngine { ...@@ -36,6 +37,9 @@ class RandomEngine {
void SetSeed(uint64_t seed); void SetSeed(uint64_t seed);
void SetSeed(uint64_t seed, uint64_t stream); void SetSeed(uint64_t seed, uint64_t stream);
/** @brief Protect manual seed accesses. */
static std::mutex manual_seed_mutex;
/** @brief Manually fix the seed. */ /** @brief Manually fix the seed. */
static std::optional<uint64_t> manual_seed; static std::optional<uint64_t> manual_seed;
static void SetManualSeed(int64_t seed); static void SetManualSeed(int64_t seed);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment