Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
1feec870
Unverified
Commit
1feec870
authored
Jul 09, 2022
by
Xin Yao
Committed by
GitHub
Jul 09, 2022
Browse files
[Bugfix] Add CUDA context availability check before setting curand seed (#4223)
parent
3e26c3d1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
31 additions
and
6 deletions
+31
-6
include/dgl/runtime/device_api.h
include/dgl/runtime/device_api.h
+6
-0
src/random/random.cc
src/random/random.cc
+8
-6
src/runtime/cuda/cuda_device_api.cc
src/runtime/cuda/cuda_device_api.cc
+17
-0
No files found.
include/dgl/runtime/device_api.h
View file @
1feec870
...
@@ -44,6 +44,12 @@ class DeviceAPI {
...
@@ -44,6 +44,12 @@ class DeviceAPI {
public:
public:
/*! \brief virtual destructor */
/*! \brief virtual destructor */
virtual
~
DeviceAPI
()
{}
virtual
~
DeviceAPI
()
{}
/*!
* \brief Check whether the device is available.
*/
virtual
bool
IsAvailable
()
{
return
true
;
}
/*!
/*!
* \brief Set the environment device id to ctx
* \brief Set the environment device id to ctx
* \param ctx The context to be set.
* \param ctx The context to be set.
...
...
src/random/random.cc
View file @
1feec870
...
@@ -29,13 +29,15 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
...
@@ -29,13 +29,15 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
}
}
});
});
#ifdef DGL_USE_CUDA
#ifdef DGL_USE_CUDA
auto
*
thr_entry
=
CUDAThreadEntry
::
ThreadLocal
();
if
(
DeviceAPI
::
Get
(
kDLGPU
)
->
IsAvailable
())
{
if
(
!
thr_entry
->
curand_gen
)
{
auto
*
thr_entry
=
CUDAThreadEntry
::
ThreadLocal
();
CURAND_CALL
(
curandCreateGenerator
(
&
thr_entry
->
curand_gen
,
CURAND_RNG_PSEUDO_DEFAULT
));
if
(
!
thr_entry
->
curand_gen
)
{
CURAND_CALL
(
curandCreateGenerator
(
&
thr_entry
->
curand_gen
,
CURAND_RNG_PSEUDO_DEFAULT
));
}
CURAND_CALL
(
curandSetPseudoRandomGeneratorSeed
(
thr_entry
->
curand_gen
,
static_cast
<
uint64_t
>
(
seed
)));
}
}
CURAND_CALL
(
curandSetPseudoRandomGeneratorSeed
(
thr_entry
->
curand_gen
,
static_cast
<
uint64_t
>
(
seed
)));
#endif // DGL_USE_CUDA
#endif // DGL_USE_CUDA
});
});
...
...
src/runtime/cuda/cuda_device_api.cc
View file @
1feec870
...
@@ -15,6 +15,23 @@ namespace runtime {
...
@@ -15,6 +15,23 @@ namespace runtime {
class
CUDADeviceAPI
final
:
public
DeviceAPI
{
class
CUDADeviceAPI
final
:
public
DeviceAPI
{
public:
public:
CUDADeviceAPI
()
{
int
count
;
auto
err
=
cudaGetDeviceCount
(
&
count
);
switch
(
err
)
{
case
cudaSuccess
:
break
;
default:
count
=
0
;
cudaGetLastError
();
}
is_available_
=
count
>
0
;
}
bool
IsAvailable
()
final
{
return
is_available_
;
}
void
SetDevice
(
DGLContext
ctx
)
final
{
void
SetDevice
(
DGLContext
ctx
)
final
{
CUDA_CALL
(
cudaSetDevice
(
ctx
.
device_id
));
CUDA_CALL
(
cudaSetDevice
(
ctx
.
device_id
));
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment