Unverified commit 9a651d91 authored by Jeff Daily, committed by GitHub

enable ROCm build; add BF16 for ROCm and CUDA (#325)

* first step, everything compiles

* fix rebuilds; skip cuda version check for rocm

* use macros for __shfl_up_sync / __shfl_down_sync

* add BFloat16 support for ROCm and CUDA

* add USE_ROCM definition to setup.py

* flake8 fixes
parent 18d37590
@@ -9,3 +9,6 @@ dist/
 *.aux
 *.log
 *.pdf
+*.hip
+*_hip.cpp
+hip
@@ -68,8 +68,8 @@
   \
   template <typename scalar, size_t size> struct Atomic##NAME##DecimalImpl; \
   \
-  template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 2> { \
-    inline __device__ void operator()(scalar *address, scalar val) { \
+  template <> struct Atomic##NAME##DecimalImpl<at::Half, 2> { \
+    inline __device__ void operator()(at::Half *address, at::Half val) { \
       unsigned int *address_as_ui = \
           (unsigned int *)((char *)address - ((size_t)address & 2)); \
       unsigned int old = *address_as_ui; \
@@ -87,6 +87,25 @@
     } \
   }; \
   \
+  template <> struct Atomic##NAME##DecimalImpl<at::BFloat16, 2> { \
+    inline __device__ void operator()(at::BFloat16 *address, at::BFloat16 val) { \
+      unsigned int *address_as_ui = \
+          (unsigned int *)((char *)address - ((size_t)address & 2)); \
+      unsigned int old = *address_as_ui; \
+      unsigned int assumed; \
+      \
+      do { \
+        assumed = old; \
+        at::BFloat16 hsum; \
+        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); \
+        hsum = OP(hsum, val); \
+        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) \
+                                  : (old & 0xffff0000) | hsum.x; \
+        old = atomicCAS(address_as_ui, assumed, old); \
+      } while (assumed != old); \
+    } \
+  }; \
+  \
   template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 4> { \
     inline __device__ void operator()(scalar *address, scalar val) { \
       int *address_as_i = (int *)address; \
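Note: the new at::BFloat16 specialization reuses the same trick as the at::Half one above. A 16-bit value cannot be compare-and-swapped directly, so the update is performed on the aligned 32-bit word that contains it, and only the relevant half-word is spliced back in. A minimal, self-contained sketch of that technique outside the macro (the function name and the integer "max" combine step are illustrative, not part of this patch):

#include <cstdint>

// Sketch of the 16-bit CAS technique used by Atomic##NAME##DecimalImpl<..., 2>:
// CAS on the aligned 32-bit word containing the value, splice the updated
// half-word back in, and retry until no other thread has raced us.
__device__ void atomic_max_u16(uint16_t *address, uint16_t val) {
  // Round the address down to its containing 32-bit word.
  unsigned int *word =
      (unsigned int *)((char *)address - ((size_t)address & 2));
  const bool upper = ((size_t)address & 2) != 0; // value sits in the high half?
  unsigned int old = *word, assumed;
  do {
    assumed = old;
    uint16_t cur = upper ? (uint16_t)(old >> 16) : (uint16_t)(old & 0xffff);
    uint16_t updated = cur > val ? cur : val; // stands in for the OP(...) step
    unsigned int merged =
        upper ? (old & 0x0000ffffu) | ((unsigned int)updated << 16)
              : (old & 0xffff0000u) | updated;
    old = atomicCAS(word, assumed, merged); // publishes only if nobody interfered
  } while (assumed != old);
}

The at::BFloat16 case differs only in how the 16 payload bits are interpreted (hsum.x carries the raw bits), which is why the specialization above can reuse the same loop shape as the at::Half one.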
@@ -135,7 +154,7 @@ static inline __device__ void atomAdd(int32_t *address, int32_t val) {
 static inline __device__ void atomAdd(int64_t *address, int64_t val) {
   AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
 }
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000)
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000))
 static inline __device__ void atomAdd(at::Half *address, at::Half val) {
   AtomicAddDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
 }
@@ -156,6 +175,9 @@ static inline __device__ void atomAdd(double *address, double val) {
   atomicAdd(address, val);
 }
 #endif
+static inline __device__ void atomAdd(at::BFloat16 *address, at::BFloat16 val) {
+  AtomicAddDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
+}
 
 #define OP(X, Y) Y *X
 ATOMIC(Mul)
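The widened #if above routes at::Half additions through this CAS emulation on ROCm as well as on pre-sm_70 / pre-CUDA-10 builds, while the new at::BFloat16 overload always uses it. For comparison, on sm_70+ with CUDA 10+ the other branch of that #if can lean on the hardware half-precision atomic; a minimal sketch of that native path (standalone, function name illustrative):

#include <cuda_fp16.h>

// On compute capability 7.0+ with CUDA 10+, FP16 atomic add is a single
// hardware instruction; the CAS loop above is only the fallback path used on
// older GPUs and on ROCm.
__device__ void atom_add_half_native(__half *address, __half val) {
  atomicAdd(address, val);
}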
@@ -184,6 +206,9 @@ static inline __device__ void atomMul(at::Half *address, at::Half val) {
 static inline __device__ void atomMul(double *address, double val) {
   AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
 }
+static inline __device__ void atomMul(at::BFloat16 *address, at::BFloat16 val) {
+  AtomicMulDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
+}
 
 #define OP(X, Y) Y / X
 ATOMIC(Div)
@@ -212,6 +237,9 @@ static inline __device__ void atomDiv(float *address, float val) {
 static inline __device__ void atomDiv(double *address, double val) {
   AtomicDivDecimalImpl<double, sizeof(double)>()(address, val);
 }
+static inline __device__ void atomDiv(at::BFloat16 *address, at::BFloat16 val) {
+  AtomicDivDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
+}
 
 #define OP(X, Y) max(Y, X)
 ATOMIC(Max)
@@ -240,6 +268,9 @@ static inline __device__ void atomMax(float *address, float val) {
 static inline __device__ void atomMax(double *address, double val) {
   AtomicMaxDecimalImpl<double, sizeof(double)>()(address, val);
 }
+static inline __device__ void atomMax(at::BFloat16 *address, at::BFloat16 val) {
+  AtomicMaxDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
+}
 
 #define OP(X, Y) min(Y, X)
 ATOMIC(Min)
@@ -268,3 +299,6 @@ static inline __device__ void atomMin(float *address, float val) {
 static inline __device__ void atomMin(double *address, double val) {
   AtomicMinDecimalImpl<double, sizeof(double)>()(address, val);
 }
+static inline __device__ void atomMin(at::BFloat16 *address, at::BFloat16 val) {
+  AtomicMinDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
+}
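All five reductions (Add, Mul, Div, Max, Min) follow the same recipe: define OP, stamp out Atomic<Name>DecimalImpl via the ATOMIC macro, then provide one thin atom<Name> overload per supported type, which is why the BFloat16 additions above look identical across the groups. A stripped-down sketch of what one such expansion boils down to for 32-bit floats (simplified and renamed for illustration; not the macro's literal output):

// Roughly what ATOMIC(Min) plus "#define OP(X, Y) min(Y, X)" produces for the
// 4-byte floating-point case: a CAS loop that reinterprets the float bits as
// an int so atomicCAS can be used.
__device__ void atomic_min_float(float *address, float val) {
  int *address_as_i = (int *)address;
  int old = *address_as_i, assumed;
  do {
    assumed = old;
    float combined = fminf(val, __int_as_float(assumed)); // the OP(...) step
    old = atomicCAS(address_as_i, assumed, __float_as_int(combined));
  } while (assumed != old);
}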
@@ -111,7 +111,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
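AT_DISPATCH_ALL_TYPES_AND2 is the ATen dispatch macro that accepts two extra ScalarTypes on top of the standard numeric set, so listing at::ScalarType::BFloat16 here (and in the segment/gather kernels below) is what lets bfloat16 tensors reach these kernels. A small host-side sketch of how the macro binds scalar_t (the function is illustrative, not part of this patch):

#include <ATen/Dispatch.h>
#include <torch/extension.h>
#include <cstdio>

// The macro instantiates the lambda for the dtype of `t`, with scalar_t bound
// to the matching C++ type (float, double, ..., at::Half, at::BFloat16).
void report_element_size(const torch::Tensor &t) {
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, t.scalar_type(),
      "report_element_size", [&] {
        std::printf("element size: %zu bytes\n", sizeof(scalar_t));
      });
}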
@@ -36,8 +36,8 @@ segment_coo_kernel(const scalar_t *src_data,
 #pragma unroll
     for (int i = 1; i < 32; i *= 2) {
       // Parallel reduction inside a single warp.
-      tmp = __shfl_up_sync(FULL_MASK, val, i);
-      next_idx = __shfl_up_sync(FULL_MASK, idx, i);
+      tmp = SHFL_UP_SYNC(FULL_MASK, val, i);
+      next_idx = SHFL_UP_SYNC(FULL_MASK, idx, i);
       if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
         assert(idx >= next_idx);
         if (idx == next_idx)
@@ -45,7 +45,7 @@ segment_coo_kernel(const scalar_t *src_data,
       }
     }
 
-    next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
+    next_idx = SHFL_DOWN_SYNC(FULL_MASK, idx, 1);
     if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D ||
         idx != next_idx)
       Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
@@ -214,7 +214,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
@@ -365,7 +365,7 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
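The segment_coo kernel above is a segmented warp scan: each lane pulls the value i lanes below it with SHFL_UP_SYNC and folds it in only when both lanes map to the same output row. Stripped of the segmentation test, the underlying pattern is a plain warp-inclusive prefix sum; a sketch, assuming the 32-lane loop used above and falling back to the CUDA spelling when the patch's macro header is not in scope:

#ifndef SHFL_UP_SYNC
#define SHFL_UP_SYNC __shfl_up_sync // CUDA spelling; the patch maps this to __shfl_up on ROCm
#endif
#define FULL_MASK 0xffffffff

// Inclusive prefix sum over the 32 lanes of a warp: lane k ends up holding the
// sum of lanes 0..k. segment_coo_kernel adds an index comparison on top so the
// scan never crosses a segment (output row) boundary.
__device__ float warp_inclusive_sum(float val, int lane_idx) {
#pragma unroll
  for (int i = 1; i < 32; i *= 2) {
    float tmp = SHFL_UP_SYNC(FULL_MASK, val, i);
    if (lane_idx >= i) // lanes below i have no lane i steps to their left
      val += tmp;
  }
  return val;
}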
@@ -46,9 +46,9 @@ segment_csr_kernel(const scalar_t *src_data,
     for (int i = TB / 2; i > 0; i /= 2) {
       // Parallel reduction inside a single warp.
       if (REDUCE == MIN || REDUCE == MAX)
-        arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
+        arg_tmp = SHFL_DOWN_SYNC(FULL_MASK, arg, i);
       Reducer<scalar_t, REDUCE>::update(
-          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
+          &val, SHFL_DOWN_SYNC(FULL_MASK, val, i), &arg, arg_tmp);
     }
 
     if (lane_idx == 0) {
@@ -147,7 +147,7 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
   auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
@@ -264,7 +264,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
   auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
@@ -17,3 +17,14 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
                                                 const unsigned int delta) {
   return __shfl_down_sync(mask, var.operator __half(), delta);
 }
+
+#ifdef USE_ROCM
+__device__ __inline__ at::Half __ldg(const at::Half* ptr) {
+  return __ldg(reinterpret_cast<const __half*>(ptr));
+}
+#define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta)
+#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
+#else
+#define SHFL_UP_SYNC __shfl_up_sync
+#define SHFL_DOWN_SYNC __shfl_down_sync
+#endif
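On ROCm these macros drop the mask argument and fall back to the wavefront-wide __shfl_up / __shfl_down intrinsics, so the same kernel source builds unchanged for both back ends. As a usage sketch, a warp-wide tree reduction written against the macro (standalone and illustrative; FULL_MASK as used by the kernels above):

#ifndef SHFL_DOWN_SYNC
#define SHFL_DOWN_SYNC __shfl_down_sync // CUDA spelling; __shfl_down on ROCm
#endif
#define FULL_MASK 0xffffffff

// Tree reduction across a 32-lane warp: after the loop, lane 0 holds the sum
// of the values contributed by all 32 lanes.
__device__ float warp_sum(float val) {
  for (int offset = 16; offset > 0; offset /= 2)
    val += SHFL_DOWN_SYNC(FULL_MASK, val, offset);
  return val;
}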
@@ -7,8 +7,12 @@
 #include "macros.h"
 
 #ifdef WITH_CUDA
+#ifdef USE_ROCM
+#include <hip/hip_version.h>
+#else
 #include <cuda.h>
 #endif
+#endif
 
 #ifdef _WIN32
 #ifdef WITH_PYTHON
@@ -23,7 +27,11 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
 namespace scatter {
 SCATTER_API int64_t cuda_version() noexcept {
 #ifdef WITH_CUDA
+#ifdef USE_ROCM
+  return HIP_VERSION;
+#else
   return CUDA_VERSION;
+#endif
 #else
   return -1;
 #endif
@@ -14,7 +14,9 @@ from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
 __version__ = '2.0.9'
 URL = 'https://github.com/rusty1s/pytorch_scatter'
 
-WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
+WITH_CUDA = False
+if torch.cuda.is_available():
+    WITH_CUDA = CUDA_HOME is not None or torch.version.hip
 suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
 if os.getenv('FORCE_CUDA', '0') == '1':
     suffices = ['cuda', 'cpu']
@@ -32,9 +34,12 @@ def get_extensions():
     extensions_dir = osp.join('csrc')
     main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
+    # remove generated 'hip' files, in case of rebuilds
+    main_files = [path for path in main_files if 'hip' not in path]
 
     for main, suffix in product(main_files, suffices):
         define_macros = [('WITH_PYTHON', None)]
+        undef_macros = []
 
         if sys.platform == 'win32':
             define_macros += [('torchscatter_EXPORTS', None)]
@@ -64,7 +69,14 @@ def get_extensions():
             define_macros += [('WITH_CUDA', None)]
             nvcc_flags = os.getenv('NVCC_FLAGS', '')
             nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
-            nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
+            if torch.version.hip:
+                nvcc_flags += ['-O3']
+                # USE_ROCM was added to later versions of rocm pytorch;
+                # define here to support older pytorch versions
+                define_macros += [('USE_ROCM', None)]
+                undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
+            else:
+                nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
             extra_compile_args['nvcc'] = nvcc_flags
@@ -84,6 +96,7 @@ def get_extensions():
             sources,
             include_dirs=[extensions_dir],
             define_macros=define_macros,
+            undef_macros=undef_macros,
             extra_compile_args=extra_compile_args,
             extra_link_args=extra_link_args,
         )
@@ -99,6 +112,11 @@ test_requires = [
     'pytest-cov',
 ]
 
+# work-around hipify abs paths
+include_package_data = True
+if torch.cuda.is_available() and torch.version.hip:
+    include_package_data = False
+
 setup(
     name='torch_scatter',
     version=__version__,
@@ -119,5 +137,5 @@ setup(
         BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
     },
     packages=find_packages(),
-    include_package_data=True,
+    include_package_data=include_package_data,
 )
@@ -47,7 +47,9 @@ for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
         torch.ops.torch_scatter.gather_coo = gather_coo_placeholder
 
 cuda_version = torch.ops.torch_scatter.cuda_version()
-if torch.version.cuda is not None and cuda_version != -1:  # pragma: no cover
+is_not_hip = torch.version.hip is None
+is_cuda = torch.version.cuda is not None
+if is_not_hip and is_cuda and cuda_version != -1:  # pragma: no cover
     if cuda_version < 10000:
         major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
     else: