Unverified Commit a40672a4 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Graphbolt] Support set seed for graphbolt. (#6613)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 969276eb
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <graphbolt/unique_and_compact.h> #include <graphbolt/unique_and_compact.h>
#include "./index_select.h" #include "./index_select.h"
#include "./random.h"
namespace graphbolt { namespace graphbolt {
namespace sampling { namespace sampling {
...@@ -68,6 +69,7 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -68,6 +69,7 @@ TORCH_LIBRARY(graphbolt, m) {
m.def("unique_and_compact", &UniqueAndCompact); m.def("unique_and_compact", &UniqueAndCompact);
m.def("isin", &IsIn); m.def("isin", &IsIn);
m.def("index_select", &ops::IndexSelect); m.def("index_select", &ops::IndexSelect);
m.def("set_seed", &RandomEngine::SetManualSeed);
} }
} // namespace sampling } // namespace sampling
......
/**
* Copyright (c) 2023 by Contributors
* @file random.cc
* @brief Random Engine.
*/
#include "./random.h"
#include <torch/torch.h>
namespace graphbolt {
namespace {
// Get a unique integer ID representing this thread.
inline uint32_t GetThreadId() {
static int num_threads = 0;
static std::mutex mutex;
static thread_local int id = -1;
if (id == -1) {
std::lock_guard<std::mutex> guard(mutex);
id = num_threads;
num_threads++;
}
return id;
}
}; // namespace
std::optional<uint64_t> RandomEngine::manual_seed;
/** @brief Constructor with default seed. */
RandomEngine::RandomEngine() {
std::random_device rd;
uint64_t seed = manual_seed.value_or(rd());
SetSeed(seed);
}
/** @brief Constructor with given seed. */
RandomEngine::RandomEngine(uint64_t seed) { RandomEngine(seed, GetThreadId()); }
/** @brief Constructor with given seed. */
RandomEngine::RandomEngine(uint64_t seed, uint64_t stream) {
SetSeed(seed, stream);
}
/** @brief Get the thread-local random number generator instance. */
RandomEngine* RandomEngine::ThreadLocal() {
return dmlc::ThreadLocalStore<RandomEngine>::Get();
}
/** @brief Set the seed. */
void RandomEngine::SetSeed(uint64_t seed) { SetSeed(seed, GetThreadId()); }
/** @brief Set the seed. */
void RandomEngine::SetSeed(uint64_t seed, uint64_t stream) {
rng_.seed(seed, stream);
}
/** @brief Manually fix the seed. */
void RandomEngine::SetManualSeed(int64_t seed) { manual_seed = seed; }
} // namespace graphbolt
...@@ -10,55 +10,35 @@ ...@@ -10,55 +10,35 @@
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <optional>
#include <pcg_random.hpp> #include <pcg_random.hpp>
#include <random> #include <random>
#include <thread> #include <thread>
namespace graphbolt { namespace graphbolt {
namespace {
// Get a unique integer ID representing this thread.
inline uint32_t GetThreadId() {
static int num_threads = 0;
static std::mutex mutex;
static thread_local int id = -1;
if (id == -1) {
std::lock_guard<std::mutex> guard(mutex);
id = num_threads;
num_threads++;
}
return id;
}
}; // namespace
/** /**
* @brief Thread-local Random Number Generator class. * @brief Thread-local Random Number Generator class.
*/ */
class RandomEngine { class RandomEngine {
public: public:
/** @brief Constructor with default seed. */ /** @brief Constructor with default seed. */
RandomEngine() { RandomEngine();
std::random_device rd;
SetSeed(rd());
}
/** @brief Constructor with given seed. */ /** @brief Constructor with given seed. */
explicit RandomEngine(uint64_t seed, uint64_t stream = GetThreadId()) { explicit RandomEngine(uint64_t seed);
SetSeed(seed, stream); explicit RandomEngine(uint64_t seed, uint64_t stream);
}
/** @brief Get the thread-local random number generator instance. */ /** @brief Get the thread-local random number generator instance. */
static RandomEngine* ThreadLocal() { static RandomEngine* ThreadLocal();
return dmlc::ThreadLocalStore<RandomEngine>::Get();
}
/** @brief Set the seed. */ /** @brief Set the seed. */
void SetSeed(uint64_t seed, uint64_t stream = GetThreadId()) { void SetSeed(uint64_t seed);
rng_.seed(seed, stream); void SetSeed(uint64_t seed, uint64_t stream);
}
/** @brief Manually fix the seed. */
static std::optional<uint64_t> manual_seed;
static void SetManualSeed(int64_t seed);
/** /**
* @brief Generate a uniform random integer in [low, high). * @brief Generate a uniform random integer in [low, high).
......
...@@ -16,12 +16,24 @@ __all__ = [ ...@@ -16,12 +16,24 @@ __all__ = [
"CopyTo", "CopyTo",
"isin", "isin",
"CSCFormatBase", "CSCFormatBase",
"seed",
] ]
CANONICAL_ETYPE_DELIMITER = ":" CANONICAL_ETYPE_DELIMITER = ":"
ORIGINAL_EDGE_ID = "_ORIGINAL_EDGE_ID" ORIGINAL_EDGE_ID = "_ORIGINAL_EDGE_ID"
def seed(val):
"""Set the random seed of Graphbolt.
Parameters
----------
val : int
The seed.
"""
torch.ops.graphbolt.set_seed(val)
def isin(elements, test_elements): def isin(elements, test_elements):
"""Tests if each element of elements is in test_elements. Returns a boolean """Tests if each element of elements is in test_elements. Returns a boolean
tensor of the same shape as elements that is True for elements in tensor of the same shape as elements that is True for elements in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment