Unverified commit 31388573 authored by dkbhaskaran, committed by GitHub

Enable ROCm builds (#282)

parent d1aee184
@@ -5,7 +5,7 @@ static inline __device__ void atomAdd(float *address, float val) {
 }
 static inline __device__ void atomAdd(double *address, double val) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000))
   unsigned long long int *address_as_ull = (unsigned long long int *)address;
   unsigned long long int old = *address_as_ull;
   unsigned long long int assumed;
......
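The hunk above is cut off after the locals; the rest of this overload is presumably the canonical compare-and-swap fallback from the CUDA programming guide, which hipcc also compiles because HIP provides `atomicCAS` and the `__double_as_longlong`/`__longlong_as_double` intrinsics. A sketch under that assumption (not copied from the commit):

```cuda
// Sketch only: assumes the truncated body is the standard CAS retry loop.
static inline __device__ void atomAdd(double *address, double val) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000))
  // Reinterpret the double as a 64-bit integer so atomicCAS can operate on it.
  unsigned long long int *address_as_ull = (unsigned long long int *)address;
  unsigned long long int old = *address_as_ull;
  unsigned long long int assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);  // retry if another thread updated *address
#else
  atomicAdd(address, val);  // native double atomicAdd (sm_60+ with CUDA >= 8.0)
#endif
}
```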
@@ -16,3 +16,14 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
                                                 const unsigned int delta) {
   return __shfl_down_sync(mask, var.operator __half(), delta);
 }
+#ifdef USE_ROCM
+__device__ __inline__ at::Half __ldg(const at::Half* ptr) {
+  return __ldg(reinterpret_cast<const __half*>(ptr));
+}
+#define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta)
+#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
+#else
+#define SHFL_UP_SYNC __shfl_up_sync
+#define SHFL_DOWN_SYNC __shfl_down_sync
+#endif
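A hypothetical usage sketch (not part of the commit) of what the new portability macros buy: on CUDA, `SHFL_DOWN_SYNC` forwards the mask to `__shfl_down_sync`, while under ROCm the mask argument is dropped and the legacy `__shfl_down` operates across the whole wavefront. Note that `warpSize` is 64 on AMD GPUs, so loops written against it, as below, stay correct on both backends.

```cuda
// Hypothetical helper, for illustration only.
__device__ __inline__ float warp_reduce_sum(float val) {
  // Halving-offset reduction across the warp (CUDA) or wavefront (ROCm).
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += SHFL_DOWN_SYNC(0xffffffff, val, offset);
  return val;  // lane 0 ends up with the full sum
}
```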
@@ -4,8 +4,12 @@
 #include <torch/script.h>
 #ifdef WITH_CUDA
+#ifdef USE_ROCM
+#include <hip/hip_version.h>
+#else
 #include <cuda.h>
+#endif
 #endif
 #include "macros.h"
@@ -22,7 +26,11 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
 namespace sparse {
 SPARSE_API int64_t cuda_version() noexcept {
 #ifdef WITH_CUDA
+#ifdef USE_ROCM
+  return HIP_VERSION;
+#else
   return CUDA_VERSION;
+#endif
 #else
   return -1;
 #endif
......
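A side note (an observation, not something in the diff): `HIP_VERSION` and `CUDA_VERSION` use different encodings, so whoever consumes `cuda_version()` has to know which backend the extension was built against. A decoding sketch, assuming the usual `major*10000000 + minor*100000 + patch` layout for `HIP_VERSION` and `major*1000 + minor*10` for `CUDA_VERSION`; the `built_with_rocm` flag is hypothetical:

```cuda
// Sketch: decoding the value returned by sparse::cuda_version().
#include <cstdint>
#include <cstdio>

void print_backend_version(int64_t v, bool built_with_rocm) {
  if (v < 0) {
    std::printf("CPU-only build\n");
  } else if (built_with_rocm) {
    std::printf("HIP %lld.%lld\n", (long long)(v / 10000000),
                (long long)(v / 100000 % 100));
  } else {
    std::printf("CUDA %lld.%lld\n", (long long)(v / 1000),
                (long long)(v % 1000 / 10));
  }
}
```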
@@ -18,7 +18,9 @@ from torch.utils.cpp_extension import (
 __version__ = '0.6.15'
 URL = 'https://github.com/rusty1s/pytorch_sparse'
-WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
+WITH_CUDA = False
+if torch.cuda.is_available():
+    WITH_CUDA = CUDA_HOME is not None or torch.version.hip
 suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
 if os.getenv('FORCE_CUDA', '0') == '1':
     suffices = ['cuda', 'cpu']
@@ -40,9 +42,12 @@ def get_extensions():
     extensions_dir = osp.join('csrc')
     main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
+    # remove generated 'hip' files, in case of rebuilds
+    main_files = [path for path in main_files if 'hip' not in path]
     for main, suffix in product(main_files, suffices):
         define_macros = [('WITH_PYTHON', None)]
+        undef_macros = []
         if sys.platform == 'win32':
             define_macros += [('torchsparse_EXPORTS', None)]
@@ -84,13 +89,26 @@ def get_extensions():
             define_macros += [('WITH_CUDA', None)]
             nvcc_flags = os.getenv('NVCC_FLAGS', '')
             nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
-            nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
+            nvcc_flags += ['-O2']
             extra_compile_args['nvcc'] = nvcc_flags
+            if torch.version.hip:
+                # USE_ROCM was added to later versions of PyTorch
+                # Define here to support older PyTorch versions as well:
+                define_macros += [('USE_ROCM', None)]
+                undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
+            else:
+                nvcc_flags += ['--expt-relaxed-constexpr']
-            if sys.platform == 'win32':
-                extra_link_args += ['cusparse.lib']
+            if torch.version.hip:
+                if sys.platform == 'win32':
+                    extra_link_args += ['hipsparse.lib']
+                else:
+                    extra_link_args += ['-lhipsparse', '-l', 'hipsparse']
             else:
-                extra_link_args += ['-lcusparse', '-l', 'cusparse']
+                if sys.platform == 'win32':
+                    extra_link_args += ['cusparse.lib']
+                else:
+                    extra_link_args += ['-lcusparse', '-l', 'cusparse']
         name = main.split(os.sep)[-1][:-4]
         sources = [main]
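Two of the compile-flag changes above concern device code rather than linking: `--expt-relaxed-constexpr` is an nvcc-only flag that hipcc does not accept, which is why it moves into the non-HIP branch, and `__HIP_NO_HALF_CONVERSIONS__`, which PyTorch's extension build typically defines for ROCm, disables the implicit `__half` conversions that the half-precision device code relies on, hence the `undef_macros` entry. A minimal sketch (the helper name is made up) of the kind of code that stops compiling when that guard stays defined:

```cuda
// Illustration only: implicit __half -> float conversion.
#include <cuda_fp16.h>  // hipify rewrites this to the HIP fp16 header

__device__ float half_to_float_implicit(__half h) {
  // Relies on __half's implicit conversion operator to float. With
  // __HIP_NO_HALF_CONVERSIONS__ (or __CUDA_NO_HALF_CONVERSIONS__) defined,
  // this line fails to compile; __half2float(h) is the explicit alternative.
  return h;
}
```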
@@ -111,6 +129,7 @@ def get_extensions():
             sources,
             include_dirs=[extensions_dir, phmap_dir],
             define_macros=define_macros,
+            undef_macros=undef_macros,
             extra_compile_args=extra_compile_args,
             extra_link_args=extra_link_args,
             libraries=libraries,
@@ -129,6 +148,11 @@ test_requires = [
     'pytest-cov',
 ]
+# work-around hipify abs paths
+include_package_data = True
+if torch.cuda.is_available() and torch.version.hip:
+    include_package_data = False
 setup(
     name='torch_sparse',
     version=__version__,
@@ -155,5 +179,5 @@ setup(
         BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
     },
     packages=find_packages(),
-    include_package_data=True,
+    include_package_data=include_package_data,
 )