"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "2bbd532990643a2aea5f1b33ec2fe3d6e7def4b1"
Unverified Commit e9e587b6 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Refactor] C random number generator (#729)

* rng refactor

* fix bugs

* unit test

* remove setsize

* lint

* fix test

* use explicit instantiation instead of inlining

* stricter test

* use tvm solution

* moved python interface to dgl.random

* lint

* address comments

* make getthreadid an inline function
parent 34ac2ab4
.. _apirandom:
DGL Random Number Generator Controls
====================================
.. automodule:: dgl.random
.. autosummary::
:toctree: ../../generated
seed
/*!
* Copyright (c) 2017 by Contributors
* \file dgl/random.h
* \brief Random number generators
*/
#ifndef DGL_RANDOM_H_
#define DGL_RANDOM_H_
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#include <random>
#include <thread>
namespace dgl {
using namespace dgl::runtime;
namespace {
inline uint32_t GetThreadId() {
static std::hash<std::thread::id> kThreadIdHasher;
return kThreadIdHasher(std::this_thread::get_id());
}
}; // namespace
/*!
* \brief Thread-local Random Number Generator class
*/
class RandomEngine {
public:
/*! \brief Constructor with default seed */
RandomEngine() {
std::random_device rd;
SetSeed(rd());
}
/*! \brief Constructor with given seed */
explicit RandomEngine(uint32_t seed) {
SetSeed(seed);
}
/*! \brief Get the thread-local random number generator instance */
static RandomEngine *ThreadLocal() {
return dmlc::ThreadLocalStore<RandomEngine>::Get();
}
/*!
* \brief Set the seed of this random number generator
*/
void SetSeed(uint32_t seed) {
rng_.seed(seed + GetThreadId());
}
/*!
* \brief Generate a uniform random integer in [0, upper)
*/
template<typename T>
T RandInt(T upper) {
return RandInt<T>(0, upper);
}
/*!
* \brief Generate a uniform random integer in [lower, upper)
*/
template<typename T>
T RandInt(T lower, T upper) {
CHECK_LT(lower, upper);
std::uniform_int_distribution<T> dist(lower, upper - 1);
return dist(rng_);
}
/*!
* \brief Generate a uniform random float in [0, 1)
*/
template<typename T>
T Uniform() {
return Uniform<T>(0., 1.);
}
/*!
* \brief Generate a uniform random float in [lower, upper)
*/
template<typename T>
T Uniform(T lower, T upper) {
CHECK_LT(lower, upper);
std::uniform_real_distribution<T> dist(lower, upper);
return dist(rng_);
}
private:
std::mt19937 rng_;
};
}; // namespace dgl
#endif // DGL_RANDOM_H_
...@@ -13,24 +13,6 @@ ...@@ -13,24 +13,6 @@
#include "graph_interface.h" #include "graph_interface.h"
#include "nodeflow.h" #include "nodeflow.h"
#ifdef _MSC_VER
// rand in MS compiler works well in multi-threading.
inline int rand_r(unsigned *seed) {
return rand();
}
inline unsigned int randseed() {
unsigned int seed = time(nullptr);
srand(seed); // need to set seed manually since there's no rand_r
return seed;
}
#define _CRT_RAND_S
#else
inline unsigned int randseed() {
return time(nullptr);
}
#endif
namespace dgl { namespace dgl {
class ImmutableGraph; class ImmutableGraph;
......
...@@ -7,6 +7,7 @@ from . import function ...@@ -7,6 +7,7 @@ from . import function
from . import nn from . import nn
from . import contrib from . import contrib
from . import container from . import container
from . import random
from ._ffi.runtime_ctypes import TypeCode from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs
......
from .sampler import NeighborSampler, LayerSampler from .sampler import NeighborSampler, LayerSampler
from .randomwalk import * from .randomwalk import *
from .dis_sampler import SamplerSender, SamplerReceiver from .dis_sampler import SamplerSender, SamplerReceiver
from .dis_sampler import SamplerPool from .dis_sampler import SamplerPool
\ No newline at end of file
"""Pyhton interfaces to DGL random number generators."""
from ._ffi.function import _init_api
def seed(val):
"""Set the seed of randomized methods in DGL.
The randomized methods include various samplers and random walk routines.
Parameters
----------
val : int
The seed
"""
_CAPI_SetSeed(val)
_init_api('dgl.rng', __name__)
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/random.h>
#include <algorithm> #include <algorithm>
#include <cstdlib> #include <cstdlib>
#include <cmath> #include <cmath>
...@@ -19,8 +20,7 @@ using namespace dgl::runtime; ...@@ -19,8 +20,7 @@ using namespace dgl::runtime;
namespace dgl { namespace dgl {
using Walker = std::function<dgl_id_t( using Walker = std::function<dgl_id_t(const GraphInterface *, dgl_id_t)>;
const GraphInterface *, unsigned int *, dgl_id_t)>;
namespace { namespace {
...@@ -30,13 +30,12 @@ namespace { ...@@ -30,13 +30,12 @@ namespace {
*/ */
dgl_id_t WalkOneHop( dgl_id_t WalkOneHop(
const GraphInterface *gptr, const GraphInterface *gptr,
unsigned int *random_seed,
dgl_id_t cur) { dgl_id_t cur) {
const auto succ = gptr->SuccVec(cur); const auto succ = gptr->SuccVec(cur);
const size_t size = succ.size(); const size_t size = succ.size();
if (size == 0) if (size == 0)
return DGL_INVALID_ID; return DGL_INVALID_ID;
return succ[rand_r(random_seed) % size]; return succ[RandomEngine::ThreadLocal()->RandInt(size)];
} }
/*! /*!
...@@ -46,11 +45,10 @@ dgl_id_t WalkOneHop( ...@@ -46,11 +45,10 @@ dgl_id_t WalkOneHop(
template<int hops> template<int hops>
dgl_id_t WalkMultipleHops( dgl_id_t WalkMultipleHops(
const GraphInterface *gptr, const GraphInterface *gptr,
unsigned int *random_seed,
dgl_id_t cur) { dgl_id_t cur) {
dgl_id_t next; dgl_id_t next;
for (int i = 0; i < hops; ++i) { for (int i = 0; i < hops; ++i) {
if ((next = WalkOneHop(gptr, random_seed, cur)) == DGL_INVALID_ID) if ((next = WalkOneHop(gptr, cur)) == DGL_INVALID_ID)
return DGL_INVALID_ID; return DGL_INVALID_ID;
cur = next; cur = next;
} }
...@@ -72,7 +70,6 @@ IdArray GenericRandomWalk( ...@@ -72,7 +70,6 @@ IdArray GenericRandomWalk(
dgl_id_t *trace_data = static_cast<dgl_id_t *>(traces->data); dgl_id_t *trace_data = static_cast<dgl_id_t *>(traces->data);
// FIXME: does OpenMP work with exceptions? Especially without throwing SIGABRT? // FIXME: does OpenMP work with exceptions? Especially without throwing SIGABRT?
unsigned int random_seed = randseed();
dgl_id_t next; dgl_id_t next;
for (int64_t i = 0; i < num_nodes; ++i) { for (int64_t i = 0; i < num_nodes; ++i) {
...@@ -85,7 +82,7 @@ IdArray GenericRandomWalk( ...@@ -85,7 +82,7 @@ IdArray GenericRandomWalk(
for (int k = 0; k < kmax; ++k) { for (int k = 0; k < kmax; ++k) {
const int64_t offset = (i * num_traces + j) * kmax + k; const int64_t offset = (i * num_traces + j) * kmax + k;
trace_data[offset] = cur; trace_data[offset] = cur;
if ((next = walker(gptr, &random_seed, cur)) == DGL_INVALID_ID) if ((next = walker(gptr, cur)) == DGL_INVALID_ID)
LOG(FATAL) << "no successors from vertex " << cur; LOG(FATAL) << "no successors from vertex " << cur;
cur = next; cur = next;
} }
...@@ -107,12 +104,9 @@ RandomWalkTraces GenericRandomWalkWithRestart( ...@@ -107,12 +104,9 @@ RandomWalkTraces GenericRandomWalkWithRestart(
std::vector<size_t> trace_lengths, trace_counts, visit_counts; std::vector<size_t> trace_lengths, trace_counts, visit_counts;
const dgl_id_t *seed_ids = static_cast<dgl_id_t *>(seeds->data); const dgl_id_t *seed_ids = static_cast<dgl_id_t *>(seeds->data);
const uint64_t num_nodes = seeds->shape[0]; const uint64_t num_nodes = seeds->shape[0];
int64_t restart_bound = static_cast<int64_t>(restart_prob * RAND_MAX);
visit_counts.resize(gptr->NumVertices()); visit_counts.resize(gptr->NumVertices());
unsigned int random_seed = randseed();
for (uint64_t i = 0; i < num_nodes; ++i) { for (uint64_t i = 0; i < num_nodes; ++i) {
int stop = 0; int stop = 0;
size_t total_trace_length = 0; size_t total_trace_length = 0;
...@@ -130,10 +124,11 @@ RandomWalkTraces GenericRandomWalkWithRestart( ...@@ -130,10 +124,11 @@ RandomWalkTraces GenericRandomWalkWithRestart(
(++num_frequent_visited_nodes == max_frequent_visited_nodes)) (++num_frequent_visited_nodes == max_frequent_visited_nodes))
stop = 1; stop = 1;
if ((trace_length > 0) && (rand_r(&random_seed) < restart_bound)) if ((trace_length > 0) &&
(RandomEngine::ThreadLocal()->Uniform<double>() < restart_prob))
break; break;
if ((next = walker(gptr, &random_seed, cur)) == DGL_INVALID_ID) if ((next = walker(gptr, cur)) == DGL_INVALID_ID)
LOG(FATAL) << "no successors from vertex " << cur; LOG(FATAL) << "no successors from vertex " << cur;
cur = next; cur = next;
vertices.push_back(cur); vertices.push_back(cur);
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/random.h>
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <algorithm> #include <algorithm>
#include <cstdlib> #include <cstdlib>
...@@ -69,8 +70,8 @@ class ArrayHeap { ...@@ -69,8 +70,8 @@ class ArrayHeap {
/* /*
* Sample from arrayHeap * Sample from arrayHeap
*/ */
size_t Sample(unsigned int* seed) { size_t Sample() {
float xi = heap_[1] * (rand_r(seed)%100/101.0); float xi = heap_[1] * RandomEngine::ThreadLocal()->Uniform<float>();
int i = 1; int i = 1;
while (i < limit_) { while (i < limit_) {
i = i << 1; i = i << 1;
...@@ -85,10 +86,10 @@ class ArrayHeap { ...@@ -85,10 +86,10 @@ class ArrayHeap {
/* /*
* Sample a vector by given the size n * Sample a vector by given the size n
*/ */
void SampleWithoutReplacement(size_t n, std::vector<size_t>* samples, unsigned int* seed) { void SampleWithoutReplacement(size_t n, std::vector<size_t>* samples) {
// sample n elements // sample n elements
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
samples->at(i) = this->Sample(seed); samples->at(i) = this->Sample();
this->Delete(samples->at(i)); this->Delete(samples->at(i));
} }
} }
...@@ -103,10 +104,10 @@ class ArrayHeap { ...@@ -103,10 +104,10 @@ class ArrayHeap {
/* /*
* Uniformly sample integers from [0, set_size) without replacement. * Uniformly sample integers from [0, set_size) without replacement.
*/ */
void RandomSample(size_t set_size, size_t num, std::vector<size_t>* out, unsigned int* seed) { void RandomSample(size_t set_size, size_t num, std::vector<size_t>* out) {
std::unordered_set<size_t> sampled_idxs; std::unordered_set<size_t> sampled_idxs;
while (sampled_idxs.size() < num) { while (sampled_idxs.size() < num) {
sampled_idxs.insert(rand_r(seed) % set_size); sampled_idxs.insert(RandomEngine::ThreadLocal()->RandInt(set_size));
} }
out->clear(); out->clear();
out->insert(out->end(), sampled_idxs.begin(), sampled_idxs.end()); out->insert(out->end(), sampled_idxs.begin(), sampled_idxs.end());
...@@ -143,8 +144,7 @@ void GetUniformSample(const dgl_id_t* edge_id_list, ...@@ -143,8 +144,7 @@ void GetUniformSample(const dgl_id_t* edge_id_list,
const size_t ver_len, const size_t ver_len,
const size_t max_num_neighbor, const size_t max_num_neighbor,
std::vector<dgl_id_t>* out_ver, std::vector<dgl_id_t>* out_ver,
std::vector<dgl_id_t>* out_edge, std::vector<dgl_id_t>* out_edge) {
unsigned int* seed) {
// Copy vid_list to output // Copy vid_list to output
if (ver_len <= max_num_neighbor) { if (ver_len <= max_num_neighbor) {
out_ver->insert(out_ver->end(), vid_list, vid_list + ver_len); out_ver->insert(out_ver->end(), vid_list, vid_list + ver_len);
...@@ -155,13 +155,12 @@ void GetUniformSample(const dgl_id_t* edge_id_list, ...@@ -155,13 +155,12 @@ void GetUniformSample(const dgl_id_t* edge_id_list,
std::vector<size_t> sorted_idxs; std::vector<size_t> sorted_idxs;
if (ver_len > max_num_neighbor * 2) { if (ver_len > max_num_neighbor * 2) {
sorted_idxs.reserve(max_num_neighbor); sorted_idxs.reserve(max_num_neighbor);
RandomSample(ver_len, max_num_neighbor, &sorted_idxs, seed); RandomSample(ver_len, max_num_neighbor, &sorted_idxs);
std::sort(sorted_idxs.begin(), sorted_idxs.end()); std::sort(sorted_idxs.begin(), sorted_idxs.end());
} else { } else {
std::vector<size_t> negate; std::vector<size_t> negate;
negate.reserve(ver_len - max_num_neighbor); negate.reserve(ver_len - max_num_neighbor);
RandomSample(ver_len, ver_len - max_num_neighbor, RandomSample(ver_len, ver_len - max_num_neighbor, &negate);
&negate, seed);
std::sort(negate.begin(), negate.end()); std::sort(negate.begin(), negate.end());
NegateArray(negate, ver_len, &sorted_idxs); NegateArray(negate, ver_len, &sorted_idxs);
} }
...@@ -185,8 +184,7 @@ void GetNonUniformSample(const float* probability, ...@@ -185,8 +184,7 @@ void GetNonUniformSample(const float* probability,
const size_t ver_len, const size_t ver_len,
const size_t max_num_neighbor, const size_t max_num_neighbor,
std::vector<dgl_id_t>* out_ver, std::vector<dgl_id_t>* out_ver,
std::vector<dgl_id_t>* out_edge, std::vector<dgl_id_t>* out_edge) {
unsigned int* seed) {
// Copy vid_list to output // Copy vid_list to output
if (ver_len <= max_num_neighbor) { if (ver_len <= max_num_neighbor) {
out_ver->insert(out_ver->end(), vid_list, vid_list + ver_len); out_ver->insert(out_ver->end(), vid_list, vid_list + ver_len);
...@@ -200,7 +198,7 @@ void GetNonUniformSample(const float* probability, ...@@ -200,7 +198,7 @@ void GetNonUniformSample(const float* probability,
sp_prob[i] = probability[vid_list[i]]; sp_prob[i] = probability[vid_list[i]];
} }
ArrayHeap arrayHeap(sp_prob); ArrayHeap arrayHeap(sp_prob);
arrayHeap.SampleWithoutReplacement(max_num_neighbor, &sp_index, seed); arrayHeap.SampleWithoutReplacement(max_num_neighbor, &sp_index);
out_ver->resize(max_num_neighbor); out_ver->resize(max_num_neighbor);
out_edge->resize(max_num_neighbor); out_edge->resize(max_num_neighbor);
for (size_t i = 0; i < max_num_neighbor; ++i) { for (size_t i = 0; i < max_num_neighbor; ++i) {
...@@ -376,7 +374,6 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph, ...@@ -376,7 +374,6 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
size_t num_neighbor, size_t num_neighbor,
const bool add_self_loop) { const bool add_self_loop) {
CHECK_EQ(graph->NumBits(), 64) << "32 bit graph is not supported yet"; CHECK_EQ(graph->NumBits(), 64) << "32 bit graph is not supported yet";
unsigned int time_seed = randseed();
const size_t num_seeds = seeds.size(); const size_t num_seeds = seeds.size();
auto orig_csr = edge_type == "in" ? graph->GetInCSR() : graph->GetOutCSR(); auto orig_csr = edge_type == "in" ? graph->GetInCSR() : graph->GetOutCSR();
const dgl_id_t* val_list = static_cast<dgl_id_t*>(orig_csr->edge_ids()->data); const dgl_id_t* val_list = static_cast<dgl_id_t*>(orig_csr->edge_ids()->data);
...@@ -426,8 +423,7 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph, ...@@ -426,8 +423,7 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
ver_len, ver_len,
num_neighbor, num_neighbor,
&tmp_sampled_src_list, &tmp_sampled_src_list,
&tmp_sampled_edge_list, &tmp_sampled_edge_list);
&time_seed);
} else { // non-uniform-sample } else { // non-uniform-sample
GetNonUniformSample(probability, GetNonUniformSample(probability,
val_list + *(indptr + dst_id), val_list + *(indptr + dst_id),
...@@ -435,8 +431,7 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph, ...@@ -435,8 +431,7 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
ver_len, ver_len,
num_neighbor, num_neighbor,
&tmp_sampled_src_list, &tmp_sampled_src_list,
&tmp_sampled_edge_list, &tmp_sampled_edge_list);
&time_seed);
} }
// If we need to add self loop and it doesn't exist in the sampled neighbor list. // If we need to add self loop and it doesn't exist in the sampled neighbor list.
if (add_self_loop && std::find(tmp_sampled_src_list.begin(), tmp_sampled_src_list.end(), if (add_self_loop && std::find(tmp_sampled_src_list.begin(), tmp_sampled_src_list.end(),
...@@ -551,7 +546,6 @@ namespace { ...@@ -551,7 +546,6 @@ namespace {
size_t curr = 0; size_t curr = 0;
size_t next = node_mapping->size(); size_t next = node_mapping->size();
unsigned int rand_seed = randseed();
for (int64_t i = num_layers - 1; i >= 0; --i) { for (int64_t i = num_layers - 1; i >= 0; --i) {
const int64_t layer_size = layer_sizes_data[i]; const int64_t layer_size = layer_sizes_data[i];
std::unordered_set<dgl_id_t> candidate_set; std::unordered_set<dgl_id_t> candidate_set;
...@@ -567,7 +561,8 @@ namespace { ...@@ -567,7 +561,8 @@ namespace {
std::unordered_map<dgl_id_t, size_t> n_occurrences; std::unordered_map<dgl_id_t, size_t> n_occurrences;
auto n_candidates = candidate_vector.size(); auto n_candidates = candidate_vector.size();
for (int64_t j = 0; j != layer_size; ++j) { for (int64_t j = 0; j != layer_size; ++j) {
auto dst = candidate_vector[rand_r(&rand_seed) % n_candidates]; auto dst = candidate_vector[
RandomEngine::ThreadLocal()->RandInt(n_candidates)];
if (!n_occurrences.insert(std::make_pair(dst, 1)).second) { if (!n_occurrences.insert(std::make_pair(dst, 1)).second) {
++n_occurrences[dst]; ++n_occurrences[dst];
} }
......
/*!
* Copyright (c) 2017 by Contributors
* \file random.cc
* \brief Random number generator interfaces
*/
#include <dmlc/omp.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/random.h>
using namespace dgl::runtime;
namespace dgl {
DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
int seed = args[0];
#pragma omp parallel for
for (int i = 0; i < omp_get_max_threads(); ++i)
RandomEngine::ThreadLocal()->SetSeed(seed);
});
}; // namespace dgl
...@@ -155,6 +155,28 @@ def test_layer_sampler(): ...@@ -155,6 +155,28 @@ def test_layer_sampler():
_test_layer_sampler() _test_layer_sampler()
_test_layer_sampler(prefetch=True) _test_layer_sampler(prefetch=True)
def test_setseed():
g = generate_rand_graph(100)
nids = []
dgl.random.seed(42)
for subg in dgl.contrib.sampling.NeighborSampler(
g, 5, 3, num_hops=2, neighbor_type='in', num_workers=1):
nids.append(
tuple(tuple(F.asnumpy(subg.layer_parent_nid(i))) for i in range(3)))
# reinitialize
dgl.random.seed(42)
for i, subg in enumerate(dgl.contrib.sampling.NeighborSampler(
g, 5, 3, num_hops=2, neighbor_type='in', num_workers=1)):
item = tuple(tuple(F.asnumpy(subg.layer_parent_nid(i))) for i in range(3))
assert item == nids[i]
for i, subg in enumerate(dgl.contrib.sampling.NeighborSampler(
g, 5, 3, num_hops=2, neighbor_type='in', num_workers=4)):
pass
if __name__ == '__main__': if __name__ == '__main__':
test_create_full() test_create_full()
test_1neighbor_sampler_all() test_1neighbor_sampler_all()
...@@ -162,3 +184,4 @@ if __name__ == '__main__': ...@@ -162,3 +184,4 @@ if __name__ == '__main__':
test_1neighbor_sampler() test_1neighbor_sampler()
test_10neighbor_sampler() test_10neighbor_sampler()
test_layer_sampler() test_layer_sampler()
test_setseed()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment