Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
31388573
Unverified
Commit
31388573
authored
Oct 17, 2022
by
dkbhaskaran
Committed by
GitHub
Oct 17, 2022
Browse files
Enable ROCm builds (#282)
parent
d1aee184
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
7 deletions
+50
-7
csrc/cuda/atomics.cuh
csrc/cuda/atomics.cuh
+1
-1
csrc/cuda/utils.cuh
csrc/cuda/utils.cuh
+11
-0
csrc/version.cpp
csrc/version.cpp
+8
-0
setup.py
setup.py
+30
-6
No files found.
csrc/cuda/atomics.cuh
View file @
31388573
...
...
@@ -5,7 +5,7 @@ static inline __device__ void atomAdd(float *address, float val) {
}
static
inline
__device__
void
atomAdd
(
double
*
address
,
double
val
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
#if
defined(USE_ROCM) || (
defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
)
unsigned
long
long
int
*
address_as_ull
=
(
unsigned
long
long
int
*
)
address
;
unsigned
long
long
int
old
=
*
address_as_ull
;
unsigned
long
long
int
assumed
;
...
...
csrc/cuda/utils.cuh
View file @
31388573
...
...
@@ -16,3 +16,14 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
const
unsigned
int
delta
)
{
return
__shfl_down_sync
(
mask
,
var
.
operator
__half
(),
delta
);
}
#ifdef USE_ROCM
__device__
__inline__
at
::
Half
__ldg
(
const
at
::
Half
*
ptr
)
{
return
__ldg
(
reinterpret_cast
<
const
__half
*>
(
ptr
));
}
#define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta)
#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
#else
#define SHFL_UP_SYNC __shfl_up_sync
#define SHFL_DOWN_SYNC __shfl_down_sync
#endif
csrc/version.cpp
View file @
31388573
...
...
@@ -4,8 +4,12 @@
#include <torch/script.h>
#ifdef WITH_CUDA
#ifdef USE_ROCM
#include <hip/hip_version.h>
#else
#include <cuda.h>
#endif
#endif
#include "macros.h"
...
...
@@ -22,7 +26,11 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
namespace
sparse
{
SPARSE_API
int64_t
cuda_version
()
noexcept
{
#ifdef WITH_CUDA
#ifdef USE_ROCM
return
HIP_VERSION
;
#else
return
CUDA_VERSION
;
#endif
#else
return
-
1
;
#endif
...
...
setup.py
View file @
31388573
...
...
@@ -18,7 +18,9 @@ from torch.utils.cpp_extension import (
__version__
=
'0.6.15'
URL
=
'https://github.com/rusty1s/pytorch_sparse'
WITH_CUDA
=
torch
.
cuda
.
is_available
()
and
CUDA_HOME
is
not
None
WITH_CUDA
=
False
if
torch
.
cuda
.
is_available
():
WITH_CUDA
=
CUDA_HOME
is
not
None
or
torch
.
version
.
hip
suffices
=
[
'cpu'
,
'cuda'
]
if
WITH_CUDA
else
[
'cpu'
]
if
os
.
getenv
(
'FORCE_CUDA'
,
'0'
)
==
'1'
:
suffices
=
[
'cuda'
,
'cpu'
]
...
...
@@ -40,9 +42,12 @@ def get_extensions():
extensions_dir
=
osp
.
join
(
'csrc'
)
main_files
=
glob
.
glob
(
osp
.
join
(
extensions_dir
,
'*.cpp'
))
# remove generated 'hip' files, in case of rebuilds
main_files
=
[
path
for
path
in
main_files
if
'hip'
not
in
path
]
for
main
,
suffix
in
product
(
main_files
,
suffices
):
define_macros
=
[(
'WITH_PYTHON'
,
None
)]
undef_macros
=
[]
if
sys
.
platform
==
'win32'
:
define_macros
+=
[(
'torchsparse_EXPORTS'
,
None
)]
...
...
@@ -84,13 +89,26 @@ def get_extensions():
define_macros
+=
[(
'WITH_CUDA'
,
None
)]
nvcc_flags
=
os
.
getenv
(
'NVCC_FLAGS'
,
''
)
nvcc_flags
=
[]
if
nvcc_flags
==
''
else
nvcc_flags
.
split
(
' '
)
nvcc_flags
+=
[
'--expt-relaxed-constexpr'
,
'-O2'
]
nvcc_flags
+=
[
'-O2'
]
extra_compile_args
[
'nvcc'
]
=
nvcc_flags
if
torch
.
version
.
hip
:
# USE_ROCM was added to later versions of PyTorch
# Define here to support older PyTorch versions as well:
define_macros
+=
[(
'USE_ROCM'
,
None
)]
undef_macros
+=
[
'__HIP_NO_HALF_CONVERSIONS__'
]
else
:
nvcc_flags
+=
[
'--expt-relaxed-constexpr'
]
if
sys
.
platform
==
'win32'
:
extra_link_args
+=
[
'cusparse.lib'
]
if
torch
.
version
.
hip
:
if
sys
.
platform
==
'win32'
:
extra_link_args
+=
[
'hipsparse.lib'
]
else
:
extra_link_args
+=
[
'-lhipsparse'
,
'-l'
,
'hipsparse'
]
else
:
extra_link_args
+=
[
'-lcusparse'
,
'-l'
,
'cusparse'
]
if
sys
.
platform
==
'win32'
:
extra_link_args
+=
[
'cusparse.lib'
]
else
:
extra_link_args
+=
[
'-lcusparse'
,
'-l'
,
'cusparse'
]
name
=
main
.
split
(
os
.
sep
)[
-
1
][:
-
4
]
sources
=
[
main
]
...
...
@@ -111,6 +129,7 @@ def get_extensions():
sources
,
include_dirs
=
[
extensions_dir
,
phmap_dir
],
define_macros
=
define_macros
,
undef_macros
=
undef_macros
,
extra_compile_args
=
extra_compile_args
,
extra_link_args
=
extra_link_args
,
libraries
=
libraries
,
...
...
@@ -129,6 +148,11 @@ test_requires = [
'pytest-cov'
,
]
# work-around hipify abs paths
include_package_data
=
True
if
torch
.
cuda
.
is_available
()
and
torch
.
version
.
hip
:
include_package_data
=
False
setup
(
name
=
'torch_sparse'
,
version
=
__version__
,
...
...
@@ -155,5 +179,5 @@ setup(
BuildExtension
.
with_options
(
no_python_abi_suffix
=
True
,
use_ninja
=
False
)
},
packages
=
find_packages
(),
include_package_data
=
True
,
include_package_data
=
include_package_data
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment