Unverified Commit 1f49a3a5 authored by Jacob Zhong, committed by GitHub

Add support for half (#137)



* compile with half

* Fix

* fix

* rename

* update

* update

* update

* update

* typo
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
parent eda8dc62
......@@ -72,7 +72,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
- *address = val / (count > 0 ? count : (scalar_t)1);
+ *address = val / (scalar_t)(count > 0 ? count : 1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
......
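On the MEAN path above, the cast moves from the fallback literal onto the whole count expression, presumably so the conditional is resolved purely in integer arithmetic and only its result is converted, which also sidesteps having to find a common type between `int64_t` and `at::Half`. A minimal sketch of the pattern, assuming a hypothetical `write_mean` helper and an `int64_t` element count (not the library's actual interface):

```cpp
#include <ATen/ATen.h>  // at::Half and its arithmetic operators

// Hypothetical helper mirroring the Reducer's MEAN write path.
// Works for at::Half as well as float/double and the integer types.
template <typename scalar_t>
void write_mean(scalar_t *address, scalar_t val, int64_t count) {
  // Resolve the conditional in int64_t, then convert once to scalar_t.
  *address = val / (scalar_t)(count > 0 ? count : 1);
}
```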
......@@ -57,7 +57,7 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
auto N = out.size(dim);
auto index_info = getTensorInfo<int64_t>(index);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
......
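This dispatch change is what actually routes half tensors into the kernels: `AT_DISPATCH_ALL_TYPES_AND` behaves like `AT_DISPATCH_ALL_TYPES` but additionally instantiates the lambda for the extra `ScalarType` passed as its first argument. A self-contained sketch of the same macro on a made-up `fill_one` function (not part of this extension), assuming a contiguous CPU tensor:

```cpp
#include <ATen/Dispatch.h>
#include <torch/torch.h>

// Hypothetical example: fill a tensor in place, with at::Half included
// in the dispatch exactly as the extension now does.
void fill_one(torch::Tensor out) {
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, out.scalar_type(), "fill_one", [&] {
    auto *data = out.data_ptr<scalar_t>();  // scalar_t is at::Half for half tensors
    for (int64_t i = 0; i < out.numel(); i++)
      data[i] = (scalar_t)1;
  });
}
```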
......@@ -69,7 +69,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
std::vector<int64_t> args(K);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
scalar_t *count_data = nullptr;
......@@ -130,7 +130,8 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
if (REDUCE == MEAN)
- arg_out.value().clamp_(1);
+ arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1,
+                              (scalar_t)1);
});
});
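Replacing `clamp_(1)` with `masked_fill_(... < (scalar_t)1, (scalar_t)1)` recurs throughout this commit; presumably `clamp_` was not usable for the half-typed count/arg tensors at the time, and the comparison-plus-`masked_fill_` form is an equivalent way to enforce the lower bound of one before the tensor is used as a divisor. A standalone libtorch sketch of the equivalence (illustrative, not the extension's code):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Segment counts may contain zeros for segments that received no elements.
  auto count = torch::tensor({0, 2, 0, 5}, torch::kFloat).to(torch::kHalf);

  // Same effect as count.clamp_(1): every entry below 1 becomes 1.
  count.masked_fill_(count < 1, 1);

  std::cout << count << std::endl;  // 1, 2, 1, 5
  return 0;
}
```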
......@@ -177,7 +178,7 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
......
......@@ -57,7 +57,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
std::vector<int64_t> args(K);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
......@@ -135,7 +135,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
......
......@@ -68,6 +68,25 @@
\
template <typename scalar, size_t size> struct Atomic##NAME##DecimalImpl; \
\
+ template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 2> { \
+ inline __device__ void operator()(scalar *address, scalar val) { \
+ unsigned int *address_as_ui = \
+ (unsigned int *)((char *)address - ((size_t)address & 2)); \
+ unsigned int old = *address_as_ui; \
+ unsigned int assumed; \
+ \
+ do { \
+ assumed = old; \
+ at::Half hsum; \
+ hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); \
+ hsum = OP(hsum, val); \
+ old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) \
+ : (old & 0xffff0000) | hsum.x; \
+ old = atomicCAS(address_as_ui, assumed, old); \
+ } while (assumed != old); \
+ } \
+ }; \
+ \
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 4> { \
inline __device__ void operator()(scalar *address, scalar val) { \
int *address_as_i = (int *)address; \
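The new two-byte specialization handles a type for which no native atomic exists by performing the compare-and-swap on the aligned 32-bit word that contains the value: round the address down to 4-byte alignment, extract the 16 bits belonging to this element, apply the reduction, splice the bits back, and retry until `atomicCAS` reports that no concurrent writer interfered. Written out as a plain function rather than through the `ATOMIC` macro, the add case looks roughly like this (a sketch for illustration only; `atomic_add_half` is not part of the extension):

```cuda
#include <ATen/ATen.h>  // at::Half (16-bit storage in its .x field)

// CAS-based atomic add for at::Half, operating on the containing 32-bit word.
__device__ void atomic_add_half(at::Half *address, at::Half val) {
  // Address of the aligned unsigned int that holds our 16 bits.
  unsigned int *word =
      (unsigned int *)((char *)address - ((size_t)address & 2));
  unsigned int old = *word, assumed;
  do {
    assumed = old;
    at::Half h;
    // Pick the high or low half-word, depending on which one we own.
    h.x = ((size_t)address & 2) ? (old >> 16) : (old & 0xffff);
    h = h + val;  // the reduction; min/max/mul/div would slot in here instead
    // Splice the updated 16 bits back into the word and try to publish it.
    old = ((size_t)address & 2) ? (old & 0xffff) | (h.x << 16)
                                : (old & 0xffff0000) | h.x;
    old = atomicCAS(word, assumed, old);
  } while (assumed != old);
}
```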
......@@ -116,6 +135,15 @@ static inline __device__ void atomAdd(int32_t *address, int32_t val) {
static inline __device__ void atomAdd(int64_t *address, int64_t val) {
AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000)
+ static inline __device__ void atomAdd(at::Half *address, at::Half val) {
+ AtomicAddDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+ }
+ #else
+ static inline __device__ void atomAdd(at::Half *address, at::Half val) {
+ atomicAdd(reinterpret_cast<__half *>(address), val);
+ }
+ #endif
static inline __device__ void atomAdd(float *address, float val) {
atomicAdd(address, val);
}
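The guard above uses the CAS fallback only where the hardware overload is unavailable; on compute capability 7.0 and newer with CUDA 10+, the toolkit provides `atomicAdd(__half *, __half)` and the extension simply reinterprets the `at::Half *` to call it. A small standalone CUDA program exercising the native overload (build with `nvcc -arch=sm_70` or newer; kernel and names are illustrative):

```cuda
#include <cuda_fp16.h>
#include <cstdio>

// Accumulate n half values into *out with the native half atomicAdd.
__global__ void sum_half(const __half *in, __half *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    atomicAdd(out, in[i]);  // __half overload, sm_70+ / CUDA 10+
}

int main() {
  const int n = 1024;
  __half *in, *out;
  cudaMallocManaged(&in, n * sizeof(__half));
  cudaMallocManaged(&out, sizeof(__half));
  for (int i = 0; i < n; i++) in[i] = __float2half(0.5f);
  *out = __float2half(0.0f);
  sum_half<<<(n + 255) / 256, 256>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("%f\n", __half2float(*out));  // 512, exactly representable in fp16
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```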
......@@ -150,6 +178,9 @@ static inline __device__ void atomMul(int64_t *address, int64_t val) {
static inline __device__ void atomMul(float *address, float val) {
AtomicMulDecimalImpl<float, sizeof(float)>()(address, val);
}
+ static inline __device__ void atomMul(at::Half *address, at::Half val) {
+ AtomicMulDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+ }
static inline __device__ void atomMul(double *address, double val) {
AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
}
......@@ -172,6 +203,9 @@ static inline __device__ void atomDiv(int32_t *address, int32_t val) {
static inline __device__ void atomDiv(int64_t *address, int64_t val) {
AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
+ static inline __device__ void atomDiv(at::Half *address, at::Half val) {
+ AtomicDivDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+ }
static inline __device__ void atomDiv(float *address, float val) {
AtomicDivDecimalImpl<float, sizeof(float)>()(address, val);
}
......@@ -197,6 +231,9 @@ static inline __device__ void atomMax(int32_t *address, int32_t val) {
static inline __device__ void atomMax(int64_t *address, int64_t val) {
AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
+ static inline __device__ void atomMax(at::Half *address, at::Half val) {
+ AtomicMaxDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+ }
static inline __device__ void atomMax(float *address, float val) {
AtomicMaxDecimalImpl<float, sizeof(float)>()(address, val);
}
......@@ -222,6 +259,9 @@ static inline __device__ void atomMin(int32_t *address, int32_t val) {
static inline __device__ void atomMin(int64_t *address, int64_t val) {
AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
+ static inline __device__ void atomMin(at::Half *address, at::Half val) {
+ AtomicMinDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+ }
static inline __device__ void atomMin(float *address, float val) {
AtomicMinDecimalImpl<float, sizeof(float)>()(address, val);
}
......
......@@ -89,7 +89,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
- *address = val / (count > 0 ? count : (scalar_t)1);
+ *address = val / (scalar_t)(count > 0 ? count : 1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
......
......@@ -111,7 +111,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
......
......@@ -3,6 +3,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
+ #include <type_traits>
#include "reducer.cuh"
#include "utils.cuh"
......@@ -25,6 +26,10 @@ segment_coo_kernel(const scalar_t *src_data,
int lane_idx = row_idx & (32 - 1);
int D = index_info.sizes[index_info.dims - 1];
+ using cuda_scalar_t =
+     typename std::conditional<std::is_same<scalar_t, at::Half>::value, __half,
+                               scalar_t>::type;
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
......@@ -36,7 +41,7 @@ segment_coo_kernel(const scalar_t *src_data,
#pragma unroll
for (int i = 1; i < 32; i *= 2) {
// Parallel reduction inside a single warp.
- tmp = __shfl_up_sync(FULL_MASK, val, i);
+ tmp = __shfl_up_sync(FULL_MASK, (cuda_scalar_t)val, i);
next_idx = __shfl_up_sync(FULL_MASK, idx, i);
if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
assert(idx >= next_idx);
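`__shfl_up_sync` (and `__shfl_down_sync` further below) have overloads for `__half` but not for `at::Half`, so the kernels introduce `cuda_scalar_t` via `std::conditional` and cast at the shuffle call; for every other scalar type the alias is just `scalar_t` and the cast is a no-op. A reduced sketch of the same trick in an inclusive warp prefix sum (the `warp_prefix_sum` function is illustrative, not the extension's kernel):

```cuda
#include <ATen/ATen.h>
#include <cuda_fp16.h>
#include <type_traits>

#define FULL_MASK 0xffffffff

// Inclusive prefix sum across one warp; instantiable with scalar_t = at::Half.
template <typename scalar_t>
__device__ scalar_t warp_prefix_sum(scalar_t val) {
  // Shuffle at::Half as __half; every other type passes through unchanged.
  using cuda_scalar_t =
      typename std::conditional<std::is_same<scalar_t, at::Half>::value,
                                __half, scalar_t>::type;
  int lane = threadIdx.x & 31;
#pragma unroll
  for (int i = 1; i < 32; i *= 2) {
    scalar_t tmp = __shfl_up_sync(FULL_MASK, (cuda_scalar_t)val, i);
    if (lane >= i)
      val = val + tmp;
  }
  return val;
}
```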
......@@ -214,7 +219,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
......@@ -266,14 +271,15 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
segment_coo_kernel<scalar_t, SUM, false>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
count_data, E, N);
- arg_out.value().clamp_(1);
+ arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1,
+                              (scalar_t)1);
auto count = arg_out.value();
for (int i = dim + 1; i < out.dim(); i++)
count = count.unsqueeze(-1);
if (out.is_floating_point())
- out.true_divide_(count);
+ out.div_(count);
else
- out.floor_divide_(count);
+ out.div_(count, "floor");
}
});
});
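The `true_divide_` / `floor_divide_` pair is folded into `div_` here: plain `div_(count)` for floating-point outputs and `div_(count, "floor")` to keep the old truncating behavior for integral outputs (the Python side below makes the same change with `rounding_mode='floor'`). A small libtorch sketch of the two paths, assuming a recent libtorch with the rounding-mode overload (standalone, not the extension's code):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto count = torch::tensor({1, 2, 4}, torch::kLong);

  // Floating-point output: true division.
  auto out_f = torch::tensor({1.0, 3.0, 6.0}, torch::kFloat);
  out_f.div_(count);           // {1.0, 1.5, 1.5}

  // Integral output: floor division via the rounding-mode overload.
  auto out_i = torch::tensor({1, 3, 6}, torch::kLong);
  out_i.div_(count, "floor");  // {1, 1, 1}

  std::cout << out_f << "\n" << out_i << std::endl;
  return 0;
}
```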
......@@ -364,7 +370,7 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo_kernel", [&] {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
......
......@@ -26,6 +26,10 @@ segment_csr_kernel(const scalar_t *src_data,
int row_idx = thread_idx / TB;
int lane_idx = thread_idx & (TB - 1);
+ using cuda_scalar_t =
+     typename std::conditional<std::is_same<scalar_t, at::Half>::value, __half,
+                               scalar_t>::type;
if (row_idx < N) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int64_t row_start = __ldg(indptr_info.data + offset);
......@@ -48,7 +52,8 @@ segment_csr_kernel(const scalar_t *src_data,
if (REDUCE == MIN || REDUCE == MAX)
arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
Reducer<scalar_t, REDUCE>::update(
&val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
&val, __shfl_down_sync(FULL_MASK, (cuda_scalar_t)val, i), &arg,
arg_tmp);
}
if (lane_idx == 0) {
......@@ -147,7 +152,7 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
......@@ -264,7 +269,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr_kernel", [&] {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
......
......@@ -127,13 +127,12 @@ public:
old_index.dim() <= dim ? old_index.dim() - 1 : dim,
torch::nullopt, out.size(dim), "sum");
auto count = std::get<0>(result);
- count.clamp_(1);
+ count.masked_fill_(count < 1, 1);
count = broadcast(count, out, dim);
if (out.is_floating_point())
- out.true_divide_(count);
+ out.div_(count);
else
- out.floor_divide_(count);
+ out.div_(count, "floor");
ctx->save_for_backward({index, count});
if (optional_out.has_value())
......
......@@ -2,7 +2,7 @@ import torch
reductions = ['sum', 'add', 'mean', 'min', 'max']
- dtypes = [torch.float, torch.double, torch.int, torch.long]
+ dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
grad_dtypes = [torch.float, torch.double]
devices = [torch.device('cpu')]
......
......@@ -50,12 +50,10 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
count = scatter_sum(ones, index, index_dim, None, dim_size)
- count.clamp_(1)
+ count[count < 1] = 1
count = broadcast(count, out, dim)
- if torch.is_floating_point(out):
-     out.true_divide_(count)
- else:
-     out.floor_divide_(count)
+ rounding_mode = None if torch.is_floating_point(out) else 'floor'
+ out.div_(count, rounding_mode=rounding_mode)
return out
......