"src/vscode:/vscode.git/clone" did not exist on "378c264561d232dfe00d934409b64ab626a4915a"
Unverified Commit 1feec870 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Bugfix] Add CUDA context availability check before setting curand seed (#4223)

parent 3e26c3d1
......@@ -44,6 +44,12 @@ class DeviceAPI {
public:
/*! \brief virtual destructor */
virtual ~DeviceAPI() {}
/*!
* \brief Check whether the device is available.
*/
virtual bool IsAvailable() {
return true;
}
/*!
* \brief Set the environment device id to ctx
* \param ctx The context to be set.
......
......@@ -29,6 +29,7 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
}
});
#ifdef DGL_USE_CUDA
if (DeviceAPI::Get(kDLGPU)->IsAvailable()) {
auto* thr_entry = CUDAThreadEntry::ThreadLocal();
if (!thr_entry->curand_gen) {
CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT));
......@@ -36,6 +37,7 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
thr_entry->curand_gen,
static_cast<uint64_t>(seed)));
}
#endif // DGL_USE_CUDA
});
......
......@@ -15,6 +15,23 @@ namespace runtime {
class CUDADeviceAPI final : public DeviceAPI {
public:
CUDADeviceAPI() {
int count;
auto err = cudaGetDeviceCount(&count);
switch (err) {
case cudaSuccess:
break;
default:
count = 0;
cudaGetLastError();
}
is_available_ = count > 0;
}
bool IsAvailable() final {
return is_available_;
}
void SetDevice(DGLContext ctx) final {
CUDA_CALL(cudaSetDevice(ctx.device_id));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment