Commit 6fca568d authored by rusty1s

override shfl methods for torch.half

parent 66bcc36e
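
CUDA ships `__half` overloads of the `__shfl_up_sync` and `__shfl_down_sync` intrinsics, but none for `at::Half`. Until now the segment kernels worked around this by deriving a `cuda_scalar_t` alias via `std::conditional` and casting `val` at every shuffle call. This commit drops the alias and the per-call casts, and instead adds `at::Half` overloads of both intrinsics to `utils.cuh` that cast to `__half` and forward to the built-in intrinsic.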
@@ -3,7 +3,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
-#include <type_traits>
 
 #include "reducer.cuh"
 #include "utils.cuh"
@@ -26,10 +25,6 @@ segment_coo_kernel(const scalar_t *src_data,
   int lane_idx = row_idx & (32 - 1);
   int D = index_info.sizes[index_info.dims - 1];
-
-  using cuda_scalar_t =
-      typename std::conditional<std::is_same<scalar_t, at::Half>::value, __half,
-                                scalar_t>::type;
 
   if (row_idx < E) {
     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
         row_idx, index_info);
@@ -41,7 +36,7 @@ segment_coo_kernel(const scalar_t *src_data,
 #pragma unroll
     for (int i = 1; i < 32; i *= 2) {
       // Parallel reduction inside a single warp.
-      tmp = __shfl_up_sync(FULL_MASK, (cuda_scalar_t)val, i);
+      tmp = __shfl_up_sync(FULL_MASK, val, i);
       next_idx = __shfl_up_sync(FULL_MASK, idx, i);
       if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
         assert(idx >= next_idx);
...
@@ -26,10 +26,6 @@ segment_csr_kernel(const scalar_t *src_data,
   int row_idx = thread_idx / TB;
   int lane_idx = thread_idx & (TB - 1);
-
-  using cuda_scalar_t =
-      typename std::conditional<std::is_same<scalar_t, at::Half>::value, __half,
-                                scalar_t>::type;
 
   if (row_idx < N) {
     int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
     int64_t row_start = __ldg(indptr_info.data + offset);
@@ -52,8 +48,7 @@ segment_csr_kernel(const scalar_t *src_data,
       if (REDUCE == MIN || REDUCE == MAX)
         arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
-      Reducer<scalar_t, REDUCE>::update(
-          &val, __shfl_down_sync(FULL_MASK, (cuda_scalar_t)val, i), &arg,
-          arg_tmp);
+      Reducer<scalar_t, REDUCE>::update(
+          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
     }
 
     if (lane_idx == 0) {
...
@@ -5,3 +5,15 @@
 #define CHECK_CUDA(x) \
   AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
 #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
+
+__device__ __inline__ at::Half __shfl_up_sync(const unsigned mask,
+                                              const at::Half var,
+                                              const unsigned int delta) {
+  return __shfl_up_sync(mask, (__half)var, delta);
+}
+
+__device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
+                                                const at::Half var,
+                                                const unsigned int delta) {
+  return __shfl_down_sync(mask, (__half)var, delta);
+}