"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c4d66200b7a747a3657e81c188a5d833f23a7d47"
Unverified Commit 4e74dc86 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix curand (#3077)


Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent a0390dde
...@@ -24,6 +24,7 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed") ...@@ -24,6 +24,7 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
#pragma omp parallel for #pragma omp parallel for
for (int i = 0; i < omp_get_max_threads(); ++i) { for (int i = 0; i < omp_get_max_threads(); ++i) {
RandomEngine::ThreadLocal()->SetSeed(seed); RandomEngine::ThreadLocal()->SetSeed(seed);
}
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto* thr_entry = CUDAThreadEntry::ThreadLocal(); auto* thr_entry = CUDAThreadEntry::ThreadLocal();
if (!thr_entry->curand_gen) { if (!thr_entry->curand_gen) {
...@@ -31,9 +32,8 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed") ...@@ -31,9 +32,8 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
} }
CURAND_CALL(curandSetPseudoRandomGeneratorSeed( CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
thr_entry->curand_gen, thr_entry->curand_gen,
static_cast<uint64_t>(seed + GetThreadId()))); static_cast<uint64_t>(seed)));
#endif // DGL_USE_CUDA #endif // DGL_USE_CUDA
}
}); });
DGL_REGISTER_GLOBAL("rng._CAPI_Choice") DGL_REGISTER_GLOBAL("rng._CAPI_Choice")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment