OpenDAS / torch-cluster

Commit 7c5a6b70 (unverified)
Enable ROCm build (#149)
Authored Oct 17, 2022 by dkbhaskaran; committed by GitHub on Oct 17, 2022.
Parent: 27387388
Showing 2 changed files with 31 additions and 2 deletions:
  csrc/version.cpp  +8 / -0
  setup.py          +23 / -2
csrc/version.cpp (view file @ 7c5a6b70)

@@ -6,8 +6,12 @@
 #include "macros.h"
 
 #ifdef WITH_CUDA
+#ifdef USE_ROCM
+#include <hip/hip_version.h>
+#else
 #include <cuda.h>
 #endif
+#endif
 
 #ifdef _WIN32
 #ifdef WITH_PYTHON
@@ -23,7 +27,11 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
 namespace cluster {
 CLUSTER_API int64_t cuda_version() noexcept {
 #ifdef WITH_CUDA
+#ifdef USE_ROCM
+  return HIP_VERSION;
+#else
   return CUDA_VERSION;
+#endif
 #else
   return -1;
 #endif
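With this change, cuda_version() reports HIP_VERSION when the sources are hipified for ROCm (USE_ROCM defined), CUDA_VERSION on regular CUDA builds, and -1 on CPU-only builds. For orientation only, the same three-way dispatch can be sketched at the Python level with PyTorch's version attributes (torch.version.hip and torch.version.cuda are version strings on the respective builds and None otherwise); this sketch is not part of the commit:

    import torch

    def backend_version():
        # Rough Python analogue of cuda_version() above: HIP version on ROCm
        # builds, CUDA version on CUDA builds, -1 when no GPU backend is built in.
        if torch.version.hip is not None:
            return torch.version.hip
        if torch.version.cuda is not None:
            return torch.version.cuda
        return -1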
setup.py (view file @ 7c5a6b70)

@@ -14,7 +14,10 @@ from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
 __version__ = '1.6.0'
 URL = 'https://github.com/rusty1s/pytorch_cluster'
 
-WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
+WITH_CUDA = False
+if torch.cuda.is_available():
+    WITH_CUDA = CUDA_HOME is not None or torch.version.hip
+
 suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
 if os.getenv('FORCE_CUDA', '0') == '1':
     suffices = ['cuda', 'cpu']
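The previous check required a local CUDA toolkit (CUDA_HOME), which is typically absent on ROCm machines even though torch.cuda.is_available() returns True there, so the GPU extensions were silently skipped. The new logic also enables the GPU build when PyTorch itself is a HIP build. A minimal standalone sketch of the same detection, assuming only an installed PyTorch (the helper name is invented for the example):

    import torch
    from torch.utils.cpp_extension import CUDA_HOME

    def gpu_build_enabled():
        # Hypothetical helper mirroring the new WITH_CUDA logic above.
        if not torch.cuda.is_available():
            return False
        # Either a CUDA toolchain is present, or PyTorch is a ROCm/HIP build.
        return CUDA_HOME is not None or bool(torch.version.hip)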
@@ -31,9 +34,12 @@ def get_extensions():
     extensions_dir = osp.join('csrc')
     main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
+    # remove generated 'hip' files, in case of rebuilds
+    main_files = [path for path in main_files if 'hip' not in path]
 
     for main, suffix in product(main_files, suffices):
         define_macros = [('WITH_PYTHON', None)]
+        undef_macros = []
 
         if sys.platform == 'win32':
             define_macros += [('torchcluster_EXPORTS', None)]
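On ROCm, PyTorch's extension build hipifies the C++/CUDA sources, and per the in-line comment above this can leave generated files with 'hip' in their names inside csrc/ after a rebuild; the new list comprehension keeps only the hand-written sources as extension entry points, and undef_macros is initialised so the later hunks can append to it. A small illustration of the filter (the file names are invented for the example):

    # Hypothetical glob result after an earlier ROCm build left hipified copies behind:
    main_files = ['csrc/cluster.cpp', 'csrc/version.cpp', 'csrc/cluster_hip.cpp']
    main_files = [path for path in main_files if 'hip' not in path]
    print(main_files)  # -> ['csrc/cluster.cpp', 'csrc/version.cpp']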
@@ -63,9 +69,17 @@ def get_extensions():
             define_macros += [('WITH_CUDA', None)]
             nvcc_flags = os.getenv('NVCC_FLAGS', '')
             nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
-            nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
+            nvcc_flags += ['-O2']
             extra_compile_args['nvcc'] = nvcc_flags
+
+            if torch.version.hip:
+                # USE_ROCM was added to later versions of PyTorch
+                # Define here to support older PyTorch versions as well:
+                define_macros += [('USE_ROCM', None)]
+                undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
+            else:
+                nvcc_flags += ['--expt-relaxed-constexpr']
 
         name = main.split(os.sep)[-1][:-4]
         sources = [main]
@@ -83,6 +97,7 @@ def get_extensions():
             sources,
             include_dirs=[extensions_dir],
             define_macros=define_macros,
+            undef_macros=undef_macros,
             extra_compile_args=extra_compile_args,
             extra_link_args=extra_link_args,
         )
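Two backend-specific adjustments happen on the GPU path: --expt-relaxed-constexpr is an NVCC-only flag that hipcc does not accept, so it is now added only on the non-HIP branch, while ROCm builds define USE_ROCM themselves (older PyTorch releases did not define it for extensions) and list __HIP_NO_HALF_CONVERSIONS__ in undef_macros, which counteracts a define that PyTorch's default hipcc flags typically set. The undef_macros list is then forwarded to the extension alongside define_macros. A hedged sketch of how these options could be wired into a torch.utils.cpp_extension.CUDAExtension; the module name and source paths are placeholders, not the exact ones produced by this setup.py:

    import torch
    from torch.utils.cpp_extension import CUDAExtension

    define_macros = [('WITH_PYTHON', None), ('WITH_CUDA', None)]
    undef_macros = []
    nvcc_flags = ['-O2']

    if torch.version.hip:
        # ROCm/HIP build of PyTorch: define USE_ROCM for older releases and
        # re-enable implicit half-precision conversions in the device code.
        define_macros += [('USE_ROCM', None)]
        undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
    else:
        # Plain CUDA build: this flag is understood by nvcc only.
        nvcc_flags += ['--expt-relaxed-constexpr']

    ext = CUDAExtension(
        'torch_cluster._example_cuda',                       # placeholder module name
        ['csrc/example.cpp', 'csrc/cuda/example_cuda.cu'],   # placeholder sources
        define_macros=define_macros,
        undef_macros=undef_macros,
        extra_compile_args={'cxx': ['-O2'], 'nvcc': nvcc_flags},
    )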
@@ -99,6 +114,11 @@ test_requires = [
     'scipy',
 ]
 
+# work-around hipify abs paths
+include_package_data = True
+if torch.cuda.is_available() and torch.version.hip:
+    include_package_data = False
+
 setup(
     name='torch_cluster',
     version=__version__,
@@ -125,4 +145,5 @@ setup(
         BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
     },
     packages=find_packages(),
+    include_package_data=include_package_data,
 )
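Finally, the packaging switch: per the commit comment, hipify can emit absolute source paths, which clashes with setuptools' include_package_data handling, so package data is included only for non-ROCm builds and the flag is threaded into setup(). A minimal sketch of the conditional flag under that assumption (the package name here is illustrative):

    import torch
    from setuptools import setup, find_packages

    # Per the commit comment, disable include_package_data on ROCm builds to
    # work around hipify writing absolute paths.
    include_package_data = not (torch.cuda.is_available() and torch.version.hip)

    setup(
        name='torch_cluster_example',   # illustrative, not the real package name
        packages=find_packages(),
        include_package_data=include_package_data,
    )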