"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d8854b8d5474676b94ae583113e9e67d670b11c5"
Unverified Commit 1feec870 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Bugfix] Add CUDA context availability check before setting curand seed (#4223)

parent 3e26c3d1
......@@ -44,6 +44,12 @@ class DeviceAPI {
public:
/*! \brief virtual destructor */
virtual ~DeviceAPI() {}
/*!
* \brief Check whether the device is available.
*/
virtual bool IsAvailable() {
return true;
}
/*!
* \brief Set the environment device id to ctx
* \param ctx The context to be set.
......
......@@ -29,13 +29,15 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
}
});
#ifdef DGL_USE_CUDA
auto* thr_entry = CUDAThreadEntry::ThreadLocal();
if (!thr_entry->curand_gen) {
CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT));
if (DeviceAPI::Get(kDLGPU)->IsAvailable()) {
auto* thr_entry = CUDAThreadEntry::ThreadLocal();
if (!thr_entry->curand_gen) {
CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT));
}
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
thr_entry->curand_gen,
static_cast<uint64_t>(seed)));
}
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
thr_entry->curand_gen,
static_cast<uint64_t>(seed)));
#endif // DGL_USE_CUDA
});
......
......@@ -15,6 +15,23 @@ namespace runtime {
class CUDADeviceAPI final : public DeviceAPI {
public:
CUDADeviceAPI() {
int count;
auto err = cudaGetDeviceCount(&count);
switch (err) {
case cudaSuccess:
break;
default:
count = 0;
cudaGetLastError();
}
is_available_ = count > 0;
}
bool IsAvailable() final {
return is_available_;
}
void SetDevice(DGLContext ctx) final {
CUDA_CALL(cudaSetDevice(ctx.device_id));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment