"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "81d764877523e491af03d3e37f6b295f9d8a6cbb"
Unverified Commit 6ba6dd60 authored by keli-wen's avatar keli-wen Committed by GitHub
Browse files

[GraphBolt] Add Random Engine for Sampling (#5905)

parent 40dcc715
......@@ -40,7 +40,9 @@ file(GLOB BOLT_SRC ${BOLT_DIR}/*.cc)
add_library(${LIB_GRAPHBOLT_NAME} SHARED ${BOLT_SRC} ${BOLT_HEADERS})
target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE ${BOLT_DIR}
${BOLT_HEADERS})
${BOLT_HEADERS}
"../third_party/dmlc-core/include"
"../third_party/pcg/include")
target_link_libraries(${LIB_GRAPHBOLT_NAME} "${TORCH_LIBRARIES}")
# The Torch CMake configuration only sets up the path for the MKL library when
......
/**
* Copyright (c) 2023 by Contributors
*
* @file random.h
* @brief Random Engine class.
*/
#ifndef GRAPHBOLT_RANDOM_H_
#define GRAPHBOLT_RANDOM_H_
#include <dmlc/thread_local.h>
#include <pcg_random.hpp>
#include <random>
#include <thread>
namespace graphbolt {
namespace {
// Get a unique integer ID representing this thread.
inline uint32_t GetThreadId() {
static int num_threads = 0;
static std::mutex mutex;
static thread_local int id = -1;
if (id == -1) {
std::lock_guard<std::mutex> guard(mutex);
id = num_threads;
num_threads++;
}
return id;
}
}; // namespace
/**
* @brief Thread-local Random Number Generator class.
*/
class RandomEngine {
public:
/** @brief Constructor with default seed. */
RandomEngine() {
std::random_device rd;
SetSeed(rd());
}
/** @brief Constructor with given seed. */
explicit RandomEngine(uint64_t seed, uint64_t stream = GetThreadId()) {
SetSeed(seed, stream);
}
/** @brief Get the thread-local random number generator instance. */
static RandomEngine* ThreadLocal() {
return dmlc::ThreadLocalStore<RandomEngine>::Get();
}
/** @brief Set the seed. */
void SetSeed(uint64_t seed, uint64_t stream = GetThreadId()) {
rng_.seed(seed, stream);
}
/**
* @brief Generate a uniform random integer in [low, high).
*/
template <typename T>
T RandInt(T lower, T upper) {
std::uniform_int_distribution<T> dist(lower, upper - 1);
return dist(rng_);
}
private:
pcg32 rng_;
};
} // namespace graphbolt
#endif // GRAPHBOLT_RANDOM_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment