Unverified commit 9a651d91 authored by Jeff Daily, committed by GitHub

enable ROCm build; add BF16 for ROCm and CUDA (#325)

* first step, everything compiles

* fix rebuilds; skip cuda version check for rocm

* use macro for __shfl_up_sync __shfl_down_sync

* add BFloat16 support for ROCm and CUDA

* add USE_ROCM definition to setup.py

* flake8 fixes
parent 18d37590
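
The BFloat16 atomics added here reuse the pattern already in place for at::Half: neither CUDA nor HIP exposes native 16-bit atomics for these types, so the code loads the aligned 32-bit word that contains the target, patches the relevant half, and retries with atomicCAS until no other thread has raced the update. A minimal standalone sketch of that pattern (the function name and the plain integer payload are illustrative only; the real code below operates on at::Half/at::BFloat16 and applies the OP macro):

```cpp
// Sketch of the 16-bit-update-via-32-bit-atomicCAS pattern (illustrative).
__device__ void atomic_add_u16(unsigned short *address, unsigned short val) {
  // Locate the aligned 32-bit word that contains the 16-bit target.
  unsigned int *word =
      (unsigned int *)((char *)address - ((size_t)address & 2));
  bool high_half = ((size_t)address & 2) != 0;
  unsigned int old = *word, assumed;
  do {
    assumed = old;
    // Extract the current 16-bit payload, update it, and splice it back in.
    unsigned short cur = high_half ? (old >> 16) : (old & 0xffff);
    unsigned short updated = cur + val;  // the real code applies OP(cur, val)
    old = high_half ? (old & 0x0000ffff) | ((unsigned int)updated << 16)
                    : (old & 0xffff0000) | updated;
    // atomicCAS succeeds only if no other thread touched the word meanwhile.
    old = atomicCAS(word, assumed, old);
  } while (assumed != old);
}
```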
@@ -9,3 +9,6 @@ dist/
*.aux
*.log
*.pdf
*.hip
*_hip.cpp
hip
@@ -68,8 +68,8 @@
\
template <typename scalar, size_t size> struct Atomic##NAME##DecimalImpl; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 2> { \
inline __device__ void operator()(scalar *address, scalar val) { \
template <> struct Atomic##NAME##DecimalImpl<at::Half, 2> { \
inline __device__ void operator()(at::Half *address, at::Half val) { \
unsigned int *address_as_ui = \
(unsigned int *)((char *)address - ((size_t)address & 2)); \
unsigned int old = *address_as_ui; \
@@ -87,6 +87,25 @@
} \
}; \
\
template <> struct Atomic##NAME##DecimalImpl<at::BFloat16, 2> { \
inline __device__ void operator()(at::BFloat16 *address, at::BFloat16 val){\
unsigned int *address_as_ui = \
(unsigned int *)((char *)address - ((size_t)address & 2)); \
unsigned int old = *address_as_ui; \
unsigned int assumed; \
\
do { \
assumed = old; \
at::BFloat16 hsum; \
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); \
hsum = OP(hsum, val); \
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) \
: (old & 0xffff0000) | hsum.x; \
old = atomicCAS(address_as_ui, assumed, old); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 4> { \
inline __device__ void operator()(scalar *address, scalar val) { \
int *address_as_i = (int *)address; \
@@ -135,7 +154,7 @@ static inline __device__ void atomAdd(int32_t *address, int32_t val) {
static inline __device__ void atomAdd(int64_t *address, int64_t val) {
AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000)
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000))
static inline __device__ void atomAdd(at::Half *address, at::Half val) {
AtomicAddDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
}
@@ -156,6 +175,9 @@ static inline __device__ void atomAdd(double *address, double val) {
atomicAdd(address, val);
}
#endif
static inline __device__ void atomAdd(at::BFloat16 *address, at::BFloat16 val) {
AtomicAddDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
}
#define OP(X, Y) Y *X
ATOMIC(Mul)
@@ -184,6 +206,9 @@ static inline __device__ void atomMul(at::Half *address, at::Half val) {
static inline __device__ void atomMul(double *address, double val) {
AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
}
static inline __device__ void atomMul(at::BFloat16 *address, at::BFloat16 val) {
AtomicMulDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
}
#define OP(X, Y) Y / X
ATOMIC(Div)
@@ -212,6 +237,9 @@ static inline __device__ void atomDiv(float *address, float val) {
static inline __device__ void atomDiv(double *address, double val) {
AtomicDivDecimalImpl<double, sizeof(double)>()(address, val);
}
static inline __device__ void atomDiv(at::BFloat16 *address, at::BFloat16 val) {
AtomicDivDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
}
#define OP(X, Y) max(Y, X)
ATOMIC(Max)
@@ -240,6 +268,9 @@ static inline __device__ void atomMax(float *address, float val) {
static inline __device__ void atomMax(double *address, double val) {
AtomicMaxDecimalImpl<double, sizeof(double)>()(address, val);
}
static inline __device__ void atomMax(at::BFloat16 *address, at::BFloat16 val) {
AtomicMaxDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
}
#define OP(X, Y) min(Y, X)
ATOMIC(Min)
@@ -268,3 +299,6 @@ static inline __device__ void atomMin(float *address, float val) {
static inline __device__ void atomMin(double *address, double val) {
AtomicMinDecimalImpl<double, sizeof(double)>()(address, val);
}
static inline __device__ void atomMin(at::BFloat16 *address, at::BFloat16 val) {
AtomicMinDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
}
@@ -111,7 +111,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
......
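
The dispatch change above is what routes BFloat16 tensors into the templated kernels: AT_DISPATCH_ALL_TYPES_AND2 is the stock ATen macro that instantiates the lambda for the default scalar types plus two extra ones, binding scalar_t inside the lambda body. A hedged usage sketch outside this codebase (fill_with_one is an illustrative helper, not part of the commit):

```cpp
#include <ATen/ATen.h>

// Illustrative only: fills a contiguous CPU tensor with ones, dispatching over
// all default scalar types plus Half and BFloat16, as the kernels above do.
void fill_with_one(at::Tensor &t) {
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, t.scalar_type(),
      "fill_with_one", [&] {
        auto *data = t.data_ptr<scalar_t>();
        for (int64_t i = 0; i < t.numel(); ++i)
          data[i] = static_cast<scalar_t>(1);
      });
}
```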
@@ -36,8 +36,8 @@ segment_coo_kernel(const scalar_t *src_data,
#pragma unroll
for (int i = 1; i < 32; i *= 2) {
// Parallel reduction inside a single warp.
tmp = __shfl_up_sync(FULL_MASK, val, i);
next_idx = __shfl_up_sync(FULL_MASK, idx, i);
tmp = SHFL_UP_SYNC(FULL_MASK, val, i);
next_idx = SHFL_UP_SYNC(FULL_MASK, idx, i);
if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
assert(idx >= next_idx);
if (idx == next_idx)
@@ -45,7 +45,7 @@
}
}
next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
next_idx = SHFL_DOWN_SYNC(FULL_MASK, idx, 1);
if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D ||
idx != next_idx)
Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
@@ -214,7 +214,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
@@ -365,7 +365,7 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
......
@@ -46,9 +46,9 @@ segment_csr_kernel(const scalar_t *src_data,
for (int i = TB / 2; i > 0; i /= 2) {
// Parallel reduction inside a single warp.
if (REDUCE == MIN || REDUCE == MAX)
arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
arg_tmp = SHFL_DOWN_SYNC(FULL_MASK, arg, i);
Reducer<scalar_t, REDUCE>::update(
&val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
&val, SHFL_DOWN_SYNC(FULL_MASK, val, i), &arg, arg_tmp);
}
if (lane_idx == 0) {
@@ -147,7 +147,7 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
@@ -264,7 +264,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
......
@@ -17,3 +17,14 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
const unsigned int delta) {
return __shfl_down_sync(mask, var.operator __half(), delta);
}
#ifdef USE_ROCM
__device__ __inline__ at::Half __ldg(const at::Half* ptr) {
return __ldg(reinterpret_cast<const __half*>(ptr));
}
#define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta)
#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
#else
#define SHFL_UP_SYNC __shfl_up_sync
#define SHFL_DOWN_SYNC __shfl_down_sync
#endif
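
The header above keeps a single code path for the warp shuffles: on CUDA, SHFL_UP_SYNC/SHFL_DOWN_SYNC expand to the __shfl_*_sync intrinsics, while on ROCm, where the _sync variants and their 32-bit lane mask are not available, they fall back to the legacy __shfl_up/__shfl_down builtins. A hedged sketch of a warp-level sum reduction written against the macro, assuming 32-thread warps as on CUDA (FULL_MASK and the kernel itself are illustrative, not part of this commit):

```cpp
// Fallback definitions so the sketch compiles standalone with nvcc; the real
// SHFL_DOWN_SYNC comes from the header above, and FULL_MASK is assumed here.
#ifndef SHFL_DOWN_SYNC
#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down_sync(mask, var, delta)
#endif
#define FULL_MASK 0xffffffff

// Tree reduction inside a single warp: after the loop, lane 0 holds the sum of
// its 32 lane values. Launch as one block whose size is a multiple of 32, with
// `in` holding blockDim.x floats and `out` holding blockDim.x / 32 floats.
__global__ void warp_sum_kernel(const float *in, float *out) {
  int lane = threadIdx.x & 31;
  float val = in[threadIdx.x];
#pragma unroll
  for (int i = 16; i > 0; i /= 2)
    val += SHFL_DOWN_SYNC(FULL_MASK, val, i);
  if (lane == 0)
    out[threadIdx.x >> 5] = val;
}
```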
@@ -7,8 +7,12 @@
#include "macros.h"
#ifdef WITH_CUDA
#ifdef USE_ROCM
#include <hip/hip_version.h>
#else
#include <cuda.h>
#endif
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
@@ -23,7 +27,11 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
namespace scatter {
SCATTER_API int64_t cuda_version() noexcept {
#ifdef WITH_CUDA
#ifdef USE_ROCM
return HIP_VERSION;
#else
return CUDA_VERSION;
#endif
#else
return -1;
#endif
......
@@ -14,7 +14,9 @@ from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
__version__ = '2.0.9'
URL = 'https://github.com/rusty1s/pytorch_scatter'
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
WITH_CUDA = False
if torch.cuda.is_available():
WITH_CUDA = CUDA_HOME is not None or torch.version.hip
suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
if os.getenv('FORCE_CUDA', '0') == '1':
suffices = ['cuda', 'cpu']
@@ -32,9 +34,12 @@ def get_extensions():
extensions_dir = osp.join('csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
# remove generated 'hip' files, in case of rebuilds
main_files = [path for path in main_files if 'hip' not in path]
for main, suffix in product(main_files, suffices):
define_macros = [('WITH_PYTHON', None)]
undef_macros = []
if sys.platform == 'win32':
define_macros += [('torchscatter_EXPORTS', None)]
@@ -64,7 +69,14 @@ def get_extensions():
define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
if torch.version.hip:
nvcc_flags += ['-O3']
# USE_ROCM was added in later versions of ROCm PyTorch;
# define it here to support older PyTorch versions
define_macros += [('USE_ROCM', None)]
undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
else:
nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
extra_compile_args['nvcc'] = nvcc_flags
name = main.split(os.sep)[-1][:-4]
@@ -84,6 +96,7 @@ def get_extensions():
sources,
include_dirs=[extensions_dir],
define_macros=define_macros,
undef_macros=undef_macros,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
@@ -99,6 +112,11 @@ test_requires = [
'pytest-cov',
]
# work-around hipify abs paths
include_package_data = True
if torch.cuda.is_available() and torch.version.hip:
include_package_data = False
setup(
name='torch_scatter',
version=__version__,
@@ -119,5 +137,5 @@ setup(
BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
},
packages=find_packages(),
include_package_data=True,
include_package_data=include_package_data,
)
@@ -47,7 +49,9 @@ for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
torch.ops.torch_scatter.gather_coo = gather_coo_placeholder
cuda_version = torch.ops.torch_scatter.cuda_version()
if torch.version.cuda is not None and cuda_version != -1: # pragma: no cover
is_not_hip = torch.version.hip is None
is_cuda = torch.version.cuda is not None
if is_not_hip and is_cuda and cuda_version != -1: # pragma: no cover
if cuda_version < 10000:
major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
else:
......