OpenDAS / torch-spline-conv · Commits

Commit 2657fd9c
Authored Oct 17, 2022 by Dineshkumar Bhaskaran
Parent: 18f48b73

Enable ROCm builds
Showing 3 changed files with 31 additions and 4 deletions.
csrc/cuda/atomics.cuh   +1  -1
csrc/version.cpp        +8  -0
setup.py                +22 -3
csrc/cuda/atomics.cuh

@@ -3,7 +3,7 @@
 static inline __device__ void atomAdd(float *address, float val) {
   atomicAdd(address, val);
 }
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000))
 static inline __device__ void atomAdd(double *address, double val) {
   unsigned long long int *address_as_ull = (unsigned long long int *)address;
   unsigned long long int old = *address_as_ull;
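The hunk above ends mid-function: the double-precision atomAdd continues with the well-known compare-and-swap emulation from the CUDA programming guide, and the only change here is that the guard now also selects that path when USE_ROCM is defined. For reference, a sketch of the CAS pattern (illustrative only, with a hypothetical name; not a verbatim copy of the remaining lines of this file):

// Sketch of the standard CAS-based double atomicAdd emulation, adapted from
// the CUDA programming guide; the guarded overload above continues like this.
static inline __device__ void atomAddDoubleSketch(double *address, double val) {
  unsigned long long int *address_as_ull = (unsigned long long int *)address;
  unsigned long long int old = *address_as_ull, assumed;
  do {
    assumed = old;
    // Reinterpret the bits, add in double precision, and publish the result
    // only if the location still holds the value we read (atomicCAS).
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
    // Integer comparison avoids looping forever when the stored value is NaN.
  } while (assumed != old);
}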
csrc/version.cpp

@@ -2,8 +2,12 @@
 #include <torch/script.h>
 
 #ifdef WITH_CUDA
+#ifdef USE_ROCM
+#include <hip/hip_version.h>
+#else
 #include <cuda.h>
 #endif
+#endif
 
 #ifdef _WIN32
 #ifdef WITH_CUDA
@@ -15,7 +19,11 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
 
 int64_t cuda_version() {
 #ifdef WITH_CUDA
+#ifdef USE_ROCM
+  return HIP_VERSION;
+#else
   return CUDA_VERSION;
+#endif
 #else
   return -1;
 #endif
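With this change the compiled extension reports HIP_VERSION on ROCm builds, CUDA_VERSION on CUDA builds, and still -1 for CPU-only builds. A hedged sketch of how that value can be read back from Python once the package is installed; the operator name below follows the torch_spline_conv convention of exposing cuda_version() as a custom op and is an assumption, not something this diff shows:

# Assumption: the library registers cuda_version() as the custom op
# "torch_spline_conv::cuda_version" and importing the package loads the
# compiled _version_* extension.
import torch
import torch_spline_conv  # noqa: F401

# CUDA builds report CUDA_VERSION, ROCm builds now report HIP_VERSION,
# CPU-only builds report -1.
print(torch.ops.torch_spline_conv.cuda_version())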
setup.py

@@ -14,7 +14,10 @@ from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
 __version__ = '1.2.1'
 URL = 'https://github.com/rusty1s/pytorch_spline_conv'
 
-WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
+WITH_CUDA = False
+if torch.cuda.is_available():
+    WITH_CUDA = CUDA_HOME is not None or torch.version.hip
 suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
 if os.getenv('FORCE_CUDA', '0') == '1':
     suffices = ['cuda', 'cpu']
@@ -31,9 +34,12 @@ def get_extensions():
     extensions_dir = osp.join('csrc')
     main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
+    # remove generated 'hip' files, in case of rebuilds
+    main_files = [path for path in main_files if 'hip' not in path]
 
     for main, suffix in product(main_files, suffices):
         define_macros = []
+        undef_macros = []
 
         extra_compile_args = {'cxx': ['-O2']}
         if not os.name == 'nt':  # Not on Windows:
             extra_compile_args['cxx'] += ['-Wno-sign-compare']
@@ -59,8 +65,15 @@ def get_extensions():
             define_macros += [('WITH_CUDA', None)]
             nvcc_flags = os.getenv('NVCC_FLAGS', '')
             nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
-            nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
+            nvcc_flags += ['-O2']
             extra_compile_args['nvcc'] = nvcc_flags
+            if torch.version.hip:
+                # USE_ROCM was added to later versions of PyTorch
+                # Define here to support older PyTorch versions as well:
+                define_macros += [('USE_ROCM', None)]
+                undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
+            else:
+                nvcc_flags += ['--expt-relaxed-constexpr']
 
         name = main.split(os.sep)[-1][:-4]
         sources = [main]
@@ -79,6 +92,7 @@ def get_extensions():
             sources,
             include_dirs=[extensions_dir],
             define_macros=define_macros,
+            undef_macros=undef_macros,
             extra_compile_args=extra_compile_args,
             extra_link_args=extra_link_args,
         )
@@ -94,6 +108,11 @@ test_requires = [
     'pytest-cov',
 ]
 
+# work-around hipify abs paths
+include_package_data = True
+if torch.cuda.is_available() and torch.version.hip:
+    include_package_data = False
+
 setup(
     name='torch_spline_conv',
     version=__version__,
@@ -120,5 +139,5 @@ setup(
         BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
     },
     packages=find_packages(),
-    include_package_data=True,
+    include_package_data=include_package_data,
 )
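Taken together, the setup.py changes detect a ROCm build of PyTorch via torch.version.hip, define USE_ROCM so that older PyTorch versions see it as well, pass the nvcc-only --expt-relaxed-constexpr flag only on CUDA builds, filter out previously generated hipified sources on rebuilds, and, per the "work-around hipify abs paths" comment, disable include_package_data for ROCm builds. A minimal sketch of the backend check the new WITH_CUDA logic relies on (illustrative only, not part of this commit):

# Illustrative check mirroring setup.py's detection: torch.version.hip is None
# on CUDA/CPU builds of PyTorch and a version string on ROCm builds.
import torch
from torch.utils.cpp_extension import CUDA_HOME

if torch.cuda.is_available() and torch.version.hip:
    print('ROCm/HIP backend:', torch.version.hip)   # hipify/hipcc build path
elif torch.cuda.is_available() and CUDA_HOME is not None:
    print('CUDA backend, toolkit at:', CUDA_HOME)   # nvcc build path
else:
    print('CPU-only build; GPU extensions are skipped')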