Unverified Commit 1feec870 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Bugfix] Add CUDA context availability check before setting curand seed (#4223)

parent 3e26c3d1
...@@ -44,6 +44,12 @@ class DeviceAPI { ...@@ -44,6 +44,12 @@ class DeviceAPI {
public: public:
/*! \brief virtual destructor */ /*! \brief virtual destructor */
virtual ~DeviceAPI() {} virtual ~DeviceAPI() {}
/*!
* \brief Check whether the device is available.
*/
virtual bool IsAvailable() {
return true;
}
/*! /*!
* \brief Set the environment device id to ctx * \brief Set the environment device id to ctx
* \param ctx The context to be set. * \param ctx The context to be set.
......
...@@ -29,13 +29,15 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed") ...@@ -29,13 +29,15 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
} }
}); });
#ifdef DGL_USE_CUDA #ifdef DGL_USE_CUDA
auto* thr_entry = CUDAThreadEntry::ThreadLocal(); if (DeviceAPI::Get(kDLGPU)->IsAvailable()) {
if (!thr_entry->curand_gen) { auto* thr_entry = CUDAThreadEntry::ThreadLocal();
CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT)); if (!thr_entry->curand_gen) {
CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT));
}
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
thr_entry->curand_gen,
static_cast<uint64_t>(seed)));
} }
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
thr_entry->curand_gen,
static_cast<uint64_t>(seed)));
#endif // DGL_USE_CUDA #endif // DGL_USE_CUDA
}); });
......
...@@ -15,6 +15,23 @@ namespace runtime { ...@@ -15,6 +15,23 @@ namespace runtime {
class CUDADeviceAPI final : public DeviceAPI { class CUDADeviceAPI final : public DeviceAPI {
public: public:
CUDADeviceAPI() {
int count;
auto err = cudaGetDeviceCount(&count);
switch (err) {
case cudaSuccess:
break;
default:
count = 0;
cudaGetLastError();
}
is_available_ = count > 0;
}
bool IsAvailable() final {
return is_available_;
}
void SetDevice(DGLContext ctx) final { void SetDevice(DGLContext ctx) final {
CUDA_CALL(cudaSetDevice(ctx.device_id)); CUDA_CALL(cudaSetDevice(ctx.device_id));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment