#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include "atomics.cuh"
#include "compat.cuh"
#include "indptr.cuh"

#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff

enum ReductionType { ADD, MEAN, MIN, MAX };

#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    if (reduce == "add") {                                                     \
      const ReductionType REDUCE = ADD;                                        \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "mean") {                                             \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "min") {                                              \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "max") {                                              \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
  }()

template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  static inline __host__ __device__ scalar_t init() {
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      return std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

  static inline __host__ __device__ void update(scalar_t *val,
                                                scalar_t new_val) {
    if (REDUCE == ADD || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
    }
  }

  static inline __host__ __device__ void update(scalar_t *val,
                                                scalar_t new_val, int64_t *arg,
                                                int64_t new_arg) {
    if (REDUCE == ADD || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

  static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
                                               int64_t *arg_address,
                                               int64_t arg, int count) {
    if (REDUCE == ADD) {
      *address = val;
    } else if (REDUCE == MEAN) {
      *address = val / (scalar_t)max(count, 1);
    } else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
    }
  }

  static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
    if (REDUCE == ADD) {
      atomAdd(address, val);
    } else if (REDUCE == MEAN) {
      atomAdd(address, val);
    } else if (REDUCE == MIN && val < *address) {
      atomMin(address, val);
    } else if (REDUCE == MAX && val > *address) {
      atomMax(address, val);
    }
  }
};

template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void segment_csr_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t E) {

  // Each warp processes exactly `32/TB` rows and aggregates all row values
  // via a parallel reduction.

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);

  if (row_idx < N) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = Reducer<scalar_t, REDUCE>::init();
    int64_t arg, arg_tmp;

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
    for (int src_idx = row_start + lane_idx; src_idx < row_end;
         src_idx += TB) {
      Reducer<scalar_t, REDUCE>::update(&val, src_data[offset + src_idx], &arg,
                                        src_idx);
    }

#pragma unroll
    for (int i = TB / 2; i > 0; i /= 2) {
      // Parallel reduction inside a single warp.
      if (REDUCE == MIN || REDUCE == MAX)
        arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
      Reducer<scalar_t, REDUCE>::update(
          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
    }

    if (lane_idx == 0) {
      Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
                                       arg_out_data + row_idx, arg,
                                       row_end - row_start);
    }
  }
}
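
// A small worked example of the reduction above (illustrative, assuming a 1-D
// indptr): with indptr = [0, 2, 5, 5] and src = [1, 2, 3, 4, 5], output row i
// reduces src[indptr[i]:indptr[i + 1]], so "add" produces out = [3, 12, 0] and
// "min" produces out = [1, 3, 0] with arg_out = [0, 2, 5]; the empty third
// segment falls back to 0 for the value and keeps the fill value E = 5 for the
// argument.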

template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_csr_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {

  // Each thread processes exactly one row. This turned out to be more
  // efficient than using shared memory because it avoids synchronization
  // barriers.

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int lane_idx = thread_idx % K;

  if (thread_idx < N * K) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = Reducer<scalar_t, REDUCE>::init();
    int64_t arg;

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
    for (int src_idx = row_start; src_idx < row_end; src_idx++) {
      Reducer<scalar_t, REDUCE>::update(
          &val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx);
    }

    Reducer<scalar_t, REDUCE>::write(out_data + thread_idx, val,
                                     arg_out_data + thread_idx, arg,
                                     row_end - row_start);
  }
}

std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr_cuda(at::Tensor src, at::Tensor indptr,
                 at::optional<at::Tensor> out_opt, std::string reduce) {
  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
  for (int i = 0; i < indptr.dim() - 1; i++)
    AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");

  src = src.contiguous();
  auto reduce_dim = indptr.dim() - 1;

  at::Tensor out;
  if (out_opt.has_value()) {
    out = out_opt.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != reduce_dim)
        AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
    AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
               "Input mismatch");
  } else {
    auto sizes = src.sizes().vec();
    sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
    out = at::empty(sizes, src.options());
  }

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  if (reduce == "any") {
    auto index = indptr.narrow(reduce_dim, 0, indptr.size(reduce_dim) - 1);
    auto index2 = indptr.narrow(reduce_dim, 1, indptr.size(reduce_dim) - 1);
    auto mask = (index2 - index) == 0;

    for (int i = reduce_dim + 1; i < src.dim(); i++) {
      index = index.unsqueeze(-1);
      mask = mask.unsqueeze(-1);
    }

    at::gather_out(out, src, reduce_dim, index.expand(out.sizes()));
    out.masked_fill_(mask.expand(out.sizes()), 0);

    return std::make_tuple(out, arg_out);
  }

  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(reduce_dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (K == 1) {
        // One full warp (TB = 32) cooperates on each of the N output rows.
        segment_csr_kernel<scalar_t, REDUCE, 32>
            <<<BLOCKS(32, N), THREADS, 0, stream>>>(
                src_data, indptr_info, out_data, arg_out_data, N, E);
      } else {
        // One thread per (row, column) pair of the broadcasted output.
        segment_csr_broadcast_kernel<scalar_t, REDUCE>
            <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
                src_data, indptr_info, out_data, arg_out_data, N, K, E);
      }
    });
  });

  return std::make_tuple(out, arg_out);
}
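
// Shape sketch for the launch parameters above (illustrative values only):
// for src of shape [4, 100, 32] and indptr of shape [4, 11], reduce_dim = 1,
// out is allocated with shape [4, 10, 32], and the kernels see N = 10 * 4 = 40
// reduced rows, K = 32 broadcasted columns per row, and E = 100 source entries
// along the reduced dimension, so the broadcast kernel is selected (K != 1).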

template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void segment_coo_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, size_t E, size_t N) {

  // Each thread processes exactly one entry. Within a warp, we perform a
  // parallel reduction across equal indices, and write the intermediate
  // result via atomics.

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = row_idx & (32 - 1);

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int idx = index_info.data[offset], next_idx;
    int out_idx = (row_idx / index_info.sizes[index_info.dims - 1]) * N + idx;

    scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;

#pragma unroll
    for (int i = 1; i < 32; i *= 2) {
      // Parallel reduction inside a single warp.
      tmp = __shfl_up_sync(FULL_MASK, val, i);
      next_idx = __shfl_up_sync(FULL_MASK, idx, i);
      assert(idx >= next_idx);
      if (lane_idx >= i && idx == next_idx)
        Reducer<scalar_t, REDUCE>::update(&val, tmp);
    }

    next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
    if (lane_idx == 32 - 1 || idx != next_idx) {
      Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
    }
  }
}

template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_coo_arg_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) {

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int idx = index_info.data[offset];
    int out_idx = (row_idx / index_info.sizes[index_info.dims - 1]) * N + idx;

    scalar_t val = __ldg(out_data + out_idx);
    if (src_data[row_idx] == val)
      arg_out_data[out_idx] = row_idx % index_info.sizes[index_info.dims - 1];
  }
}
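
// Worked example of the warp-level scheme above (illustrative): if one warp
// holds the sorted indices [0, 0, 0, 1, 1, ...], the shuffle loop in
// segment_coo_kernel folds every run of equal indices into the last lane of
// that run, and only that lane issues the atomic update, so a segment that is
// contiguous inside a warp costs a single atomic instead of one per entry.
// segment_coo_arg_kernel then recovers argmin/argmax in a second pass by
// re-reading the reduced value and recording a position whose source value
// matches it.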

template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void segment_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, size_t E, size_t K, size_t N) {

  // Each thread processes a single column and `TB` index entries. Coalesced
  // reads and writes are performed in column-major order. The intermediate
  // results are written via atomics.

  int row_start = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
  int col_idx = blockIdx.y * blockDim.x + threadIdx.x;

  if (row_start < E && col_idx < K) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_start, index_info);
    int out_idx = (row_start / index_info.sizes[index_info.dims - 1]) * N;
    int idx1 = __ldg(index_info.data + offset);

    scalar_t val = src_data[K * row_start + col_idx];

#pragma unroll
    for (int i = 1; i < TB; i++) {
      if (row_start + i >= E)
        break;

      int idx2 = __ldg(index_info.data + offset +
                       i * index_info.strides[index_info.dims - 1]);
      assert(idx1 <= idx2);
      if (idx1 == idx2) {
        Reducer<scalar_t, REDUCE>::update(
            &val, src_data[K * (row_start + i) + col_idx]);
      } else {
        Reducer<scalar_t, REDUCE>::atomic_write(
            out_data + (out_idx + idx1) * K + col_idx, val);
        val = src_data[K * (row_start + i) + col_idx];
      }

      idx1 = idx2;
    }

    Reducer<scalar_t, REDUCE>::atomic_write(
        out_data + (out_idx + idx1) * K + col_idx, val);
  }
}

template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_coo_arg_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int col_idx = thread_idx % K;

  if (row_idx < E && col_idx < K) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int idx = __ldg(index_info.data + offset);
    int out_idx =
        ((row_idx / index_info.sizes[index_info.dims - 1]) * N + idx) * K +
        col_idx;

    scalar_t val = __ldg(out_data + out_idx);
    if (src_data[thread_idx] == val)
      arg_out_data[out_idx] = row_idx % index_info.sizes[index_info.dims - 1];
  }
}
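
// Illustrative trace of the broadcast kernel above: with TB = 4 and the index
// entries [2, 2, 3, 3] assigned to one thread, the thread accumulates the
// first two values locally, flushes them with a single atomic when the index
// switches from 2 to 3, accumulates the remaining two values, and flushes once
// more after the loop. The host code below picks TB from the average segment
// length so that longer segments are mostly collapsed in registers before
// touching global memory.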

std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                 std::string reduce) {
  AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
  for (int i = 0; i < index.dim(); i++)
    AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch");

  src = src.contiguous();
  out = out.contiguous();
  auto reduce_dim = index.dim() - 1;

  for (int i = 0; i < out.dim(); i++)
    if (i != reduce_dim)
      AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  if (reduce == "any") {
    for (int i = reduce_dim + 1; i < src.dim(); i++) {
      index = index.unsqueeze(-1);
    }
    out.scatter_(reduce_dim, index.expand(src.sizes()), src);
    return std::make_tuple(out, arg_out);
  }

  auto E = index.numel();
  auto K = src.numel() / E;
  auto N = out.size(reduce_dim);
  auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (K == 1) {
        segment_coo_kernel<scalar_t, REDUCE, true>
            <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
                                                   out_data, E, N);
      } else if (avg_len <= 8) {
        // blockDim = (32, 8): 32 columns per block in x, and 8 groups of
        // TB consecutive index entries per block in y.
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
            <<<dim3((E + (8 * 4) - 1) / (8 * 4), (K + 31) / 32), dim3(32, 8),
               0, stream>>>(src_data, index_info, out_data, E, K, N);
      } else if (avg_len <= 16) {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
            <<<dim3((E + (8 * 8) - 1) / (8 * 8), (K + 31) / 32), dim3(32, 8),
               0, stream>>>(src_data, index_info, out_data, E, K, N);
      } else if (avg_len <= 32) {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
            <<<dim3((E + (8 * 16) - 1) / (8 * 16), (K + 31) / 32), dim3(32, 8),
               0, stream>>>(src_data, index_info, out_data, E, K, N);
      } else {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
            <<<dim3((E + (8 * 32) - 1) / (8 * 32), (K + 31) / 32), dim3(32, 8),
               0, stream>>>(src_data, index_info, out_data, E, K, N);
      }

      if (REDUCE == MIN || REDUCE == MAX) {
        if (K == 1) {
          segment_coo_arg_kernel<scalar_t, REDUCE>
              <<<BLOCKS(1, E), THREADS, 0, stream>>>(
                  src_data, index_info, out_data, arg_out_data, E, N);
        } else {
          segment_coo_arg_broadcast_kernel<scalar_t, REDUCE>
              <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(
                  src_data, index_info, out_data, arg_out_data, E, K, N);
        }
      }
    });
  });

  if (reduce == "mean") {
    auto sizes = index.sizes().vec();
    sizes[reduce_dim] = out.size(reduce_dim);
    auto count = at::zeros(sizes, out.options());

    AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] {
      auto count_data = count.DATA_PTR<scalar_t>();
      segment_coo_kernel<scalar_t, ADD, false>
          <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
                                                 count_data, E, N);
    });
    count.clamp_(1);
    arg_out = count;

    for (int i = reduce_dim + 1; i < out.dim(); i++) {
      count = count.unsqueeze(-1);
    }
    out.div_(count);
  }

  return std::make_tuple(out, arg_out);
}
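
// Usage sketch (illustrative only; the pybind11/TorchScript bindings that
// expose these functions are assumed to live in a separate file). With CUDA
// tensors describing four segments of two entries each:
//
//   auto src = at::randn({8, 16}, at::kCUDA);                  // E = 8, K = 16
//   auto index = at::arange(4, at::dtype(at::kLong).device(at::kCUDA))
//                    .repeat_interleave(2);                    // [0,0,1,1,2,2,3,3]
//   auto out = at::zeros({4, 16}, src.options());              // N = 4
//
//   at::Tensor result;
//   at::optional<at::Tensor> arg;
//   std::tie(result, arg) = segment_coo_cuda(src, index, out, "add");
//
// The same segments in CSR form use indptr = [0, 2, 4, 6, 8], e.g.
// segment_csr_cuda(src, indptr, at::nullopt, "add"), which allocates its own
// output when no tensor is passed in.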