Unverified Commit 25eabbb4 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Setting manual seed also sets the seed now. (#6900)

parent 107b4347
......@@ -28,11 +28,13 @@ inline uint32_t GetThreadId() {
}; // namespace
std::mutex RandomEngine::manual_seed_mutex;
std::optional<uint64_t> RandomEngine::manual_seed;
/** @brief Constructor with default seed. */
RandomEngine::RandomEngine() {
std::random_device rd;
std::lock_guard lock(manual_seed_mutex);
uint64_t seed = manual_seed.value_or(rd());
SetSeed(seed);
}
......@@ -59,6 +61,11 @@ void RandomEngine::SetSeed(uint64_t seed, uint64_t stream) {
}
/** @brief Manually fix the seed. */
void RandomEngine::SetManualSeed(int64_t seed) { manual_seed = seed; }
void RandomEngine::SetManualSeed(int64_t seed) {
// Intentionally set the seed for current thread also.
RandomEngine::ThreadLocal()->SetSeed(seed);
std::lock_guard lock(manual_seed_mutex);
manual_seed = seed;
}
} // namespace graphbolt
......@@ -10,6 +10,7 @@
#include <dmlc/thread_local.h>
#include <mutex>
#include <optional>
#include <pcg_random.hpp>
#include <random>
......@@ -36,6 +37,9 @@ class RandomEngine {
void SetSeed(uint64_t seed);
void SetSeed(uint64_t seed, uint64_t stream);
/** @brief Protect manual seed accesses. */
static std::mutex manual_seed_mutex;
/** @brief Manually fix the seed. */
static std::optional<uint64_t> manual_seed;
static void SetManualSeed(int64_t seed);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment