"vscode:/vscode.git/clone" did not exist on "236490719327418acc2fa3423f690c725b88b6bb"
Unverified Commit a40672a4 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Graphbolt] Support set seed for graphbolt. (#6613)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 969276eb
......@@ -10,6 +10,7 @@
#include <graphbolt/unique_and_compact.h>
#include "./index_select.h"
#include "./random.h"
namespace graphbolt {
namespace sampling {
......@@ -68,6 +69,7 @@ TORCH_LIBRARY(graphbolt, m) {
m.def("unique_and_compact", &UniqueAndCompact);
m.def("isin", &IsIn);
m.def("index_select", &ops::IndexSelect);
m.def("set_seed", &RandomEngine::SetManualSeed);
}
} // namespace sampling
......
/**
* Copyright (c) 2023 by Contributors
* @file random.cc
* @brief Random Engine.
*/
#include "./random.h"
#include <torch/torch.h>
namespace graphbolt {
namespace {
// Get a unique integer ID representing this thread.
inline uint32_t GetThreadId() {
static int num_threads = 0;
static std::mutex mutex;
static thread_local int id = -1;
if (id == -1) {
std::lock_guard<std::mutex> guard(mutex);
id = num_threads;
num_threads++;
}
return id;
}
}; // namespace
std::optional<uint64_t> RandomEngine::manual_seed;
/** @brief Constructor with default seed. */
RandomEngine::RandomEngine() {
std::random_device rd;
uint64_t seed = manual_seed.value_or(rd());
SetSeed(seed);
}
/** @brief Constructor with given seed. */
RandomEngine::RandomEngine(uint64_t seed) { RandomEngine(seed, GetThreadId()); }
/** @brief Constructor with given seed. */
RandomEngine::RandomEngine(uint64_t seed, uint64_t stream) {
SetSeed(seed, stream);
}
/** @brief Get the thread-local random number generator instance. */
RandomEngine* RandomEngine::ThreadLocal() {
return dmlc::ThreadLocalStore<RandomEngine>::Get();
}
/** @brief Set the seed. */
void RandomEngine::SetSeed(uint64_t seed) { SetSeed(seed, GetThreadId()); }
/** @brief Set the seed. */
void RandomEngine::SetSeed(uint64_t seed, uint64_t stream) {
rng_.seed(seed, stream);
}
/** @brief Manually fix the seed. */
void RandomEngine::SetManualSeed(int64_t seed) { manual_seed = seed; }
} // namespace graphbolt
......@@ -10,55 +10,35 @@
#include <dmlc/thread_local.h>
#include <optional>
#include <pcg_random.hpp>
#include <random>
#include <thread>
namespace graphbolt {
namespace {
// Get a unique integer ID representing this thread.
inline uint32_t GetThreadId() {
static int num_threads = 0;
static std::mutex mutex;
static thread_local int id = -1;
if (id == -1) {
std::lock_guard<std::mutex> guard(mutex);
id = num_threads;
num_threads++;
}
return id;
}
}; // namespace
/**
* @brief Thread-local Random Number Generator class.
*/
class RandomEngine {
public:
/** @brief Constructor with default seed. */
RandomEngine() {
std::random_device rd;
SetSeed(rd());
}
RandomEngine();
/** @brief Constructor with given seed. */
explicit RandomEngine(uint64_t seed, uint64_t stream = GetThreadId()) {
SetSeed(seed, stream);
}
explicit RandomEngine(uint64_t seed);
explicit RandomEngine(uint64_t seed, uint64_t stream);
/** @brief Get the thread-local random number generator instance. */
static RandomEngine* ThreadLocal() {
return dmlc::ThreadLocalStore<RandomEngine>::Get();
}
static RandomEngine* ThreadLocal();
/** @brief Set the seed. */
void SetSeed(uint64_t seed, uint64_t stream = GetThreadId()) {
rng_.seed(seed, stream);
}
void SetSeed(uint64_t seed);
void SetSeed(uint64_t seed, uint64_t stream);
/** @brief Manually fix the seed. */
static std::optional<uint64_t> manual_seed;
static void SetManualSeed(int64_t seed);
/**
* @brief Generate a uniform random integer in [low, high).
......
......@@ -16,12 +16,24 @@ __all__ = [
"CopyTo",
"isin",
"CSCFormatBase",
"seed",
]
CANONICAL_ETYPE_DELIMITER = ":"
ORIGINAL_EDGE_ID = "_ORIGINAL_EDGE_ID"
def seed(val):
"""Set the random seed of Graphbolt.
Parameters
----------
val : int
The seed.
"""
torch.ops.graphbolt.set_seed(val)
def isin(elements, test_elements):
"""Tests if each element of elements is in test_elements. Returns a boolean
tensor of the same shape as elements that is True for elements in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment