Unverified Commit 4e74dc86 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix curand (#3077)


Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent a0390dde
...@@ -24,16 +24,16 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed") ...@@ -24,16 +24,16 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
#pragma omp parallel for #pragma omp parallel for
for (int i = 0; i < omp_get_max_threads(); ++i) { for (int i = 0; i < omp_get_max_threads(); ++i) {
RandomEngine::ThreadLocal()->SetSeed(seed); RandomEngine::ThreadLocal()->SetSeed(seed);
}
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto* thr_entry = CUDAThreadEntry::ThreadLocal(); auto* thr_entry = CUDAThreadEntry::ThreadLocal();
if (!thr_entry->curand_gen) { if (!thr_entry->curand_gen) {
CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT)); CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT));
}
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
thr_entry->curand_gen,
static_cast<uint64_t>(seed + GetThreadId())));
#endif // DGL_USE_CUDA
} }
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
thr_entry->curand_gen,
static_cast<uint64_t>(seed)));
#endif // DGL_USE_CUDA
}); });
DGL_REGISTER_GLOBAL("rng._CAPI_Choice") DGL_REGISTER_GLOBAL("rng._CAPI_Choice")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment