Commit 2657fd9c authored by Dineshkumar Bhaskaran's avatar Dineshkumar Bhaskaran
Browse files

Enable ROCm builds

parent 18f48b73
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
static inline __device__ void atomAdd(float *address, float val) { static inline __device__ void atomAdd(float *address, float val) {
atomicAdd(address, val); atomicAdd(address, val);
} }
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000) #if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000))
static inline __device__ void atomAdd(double *address, double val) { static inline __device__ void atomAdd(double *address, double val) {
unsigned long long int *address_as_ull = (unsigned long long int *)address; unsigned long long int *address_as_ull = (unsigned long long int *)address;
unsigned long long int old = *address_as_ull; unsigned long long int old = *address_as_ull;
......
...@@ -2,8 +2,12 @@ ...@@ -2,8 +2,12 @@
#include <torch/script.h> #include <torch/script.h>
#ifdef WITH_CUDA #ifdef WITH_CUDA
#ifdef USE_ROCM
#include <hip/hip_version.h>
#else
#include <cuda.h> #include <cuda.h>
#endif #endif
#endif
#ifdef _WIN32 #ifdef _WIN32
#ifdef WITH_CUDA #ifdef WITH_CUDA
...@@ -15,7 +19,11 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; } ...@@ -15,7 +19,11 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
int64_t cuda_version() { int64_t cuda_version() {
#ifdef WITH_CUDA #ifdef WITH_CUDA
#ifdef USE_ROCM
return HIP_VERSION;
#else
return CUDA_VERSION; return CUDA_VERSION;
#endif
#else #else
return -1; return -1;
#endif #endif
......
...@@ -14,7 +14,10 @@ from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension, ...@@ -14,7 +14,10 @@ from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
__version__ = '1.2.1' __version__ = '1.2.1'
URL = 'https://github.com/rusty1s/pytorch_spline_conv' URL = 'https://github.com/rusty1s/pytorch_spline_conv'
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None WITH_CUDA = False
if torch.cuda.is_available():
WITH_CUDA = CUDA_HOME is not None or torch.version.hip
suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu'] suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
if os.getenv('FORCE_CUDA', '0') == '1': if os.getenv('FORCE_CUDA', '0') == '1':
suffices = ['cuda', 'cpu'] suffices = ['cuda', 'cpu']
...@@ -31,9 +34,12 @@ def get_extensions(): ...@@ -31,9 +34,12 @@ def get_extensions():
extensions_dir = osp.join('csrc') extensions_dir = osp.join('csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp')) main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
# remove generated 'hip' files, in case of rebuilds
main_files = [path for path in main_files if 'hip' not in path]
for main, suffix in product(main_files, suffices): for main, suffix in product(main_files, suffices):
define_macros = [] define_macros = []
undef_macros = []
extra_compile_args = {'cxx': ['-O2']} extra_compile_args = {'cxx': ['-O2']}
if not os.name == 'nt': # Not on Windows: if not os.name == 'nt': # Not on Windows:
extra_compile_args['cxx'] += ['-Wno-sign-compare'] extra_compile_args['cxx'] += ['-Wno-sign-compare']
...@@ -59,8 +65,15 @@ def get_extensions(): ...@@ -59,8 +65,15 @@ def get_extensions():
define_macros += [('WITH_CUDA', None)] define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '') nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ') nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['--expt-relaxed-constexpr', '-O2'] nvcc_flags += ['-O2']
extra_compile_args['nvcc'] = nvcc_flags extra_compile_args['nvcc'] = nvcc_flags
if torch.version.hip:
# USE_ROCM was added to later versions of PyTorch
# Define here to support older PyTorch versions as well:
define_macros += [('USE_ROCM', None)]
undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
else:
nvcc_flags += ['--expt-relaxed-constexpr']
name = main.split(os.sep)[-1][:-4] name = main.split(os.sep)[-1][:-4]
sources = [main] sources = [main]
...@@ -79,6 +92,7 @@ def get_extensions(): ...@@ -79,6 +92,7 @@ def get_extensions():
sources, sources,
include_dirs=[extensions_dir], include_dirs=[extensions_dir],
define_macros=define_macros, define_macros=define_macros,
undef_macros=undef_macros,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args, extra_link_args=extra_link_args,
) )
...@@ -94,6 +108,11 @@ test_requires = [ ...@@ -94,6 +108,11 @@ test_requires = [
'pytest-cov', 'pytest-cov',
] ]
# work-around hipify abs paths
include_package_data = True
if torch.cuda.is_available() and torch.version.hip:
include_package_data = False
setup( setup(
name='torch_spline_conv', name='torch_spline_conv',
version=__version__, version=__version__,
...@@ -120,5 +139,5 @@ setup( ...@@ -120,5 +139,5 @@ setup(
BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False) BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
}, },
packages=find_packages(), packages=find_packages(),
include_package_data=True, include_package_data=include_package_data,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment