#include "segment_csr_cuda.h" #include #include #include #include "index_info.cuh" #include "reducer.cuh" #include "utils.cuh" #define THREADS 256 #define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS #define FULL_MASK 0xffffffff template __global__ void segment_csr_kernel(const scalar_t *src_data, const at::cuda::detail::TensorInfo indptr_info, scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t E) { // Each warp processes exactly `32/TB` rows and aggregates all row values // via a parallel reduction. int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; int row_idx = thread_idx / TB; int lane_idx = thread_idx & (TB - 1); if (row_idx < N) { int offset = IndexPtrToOffset::get(row_idx, indptr_info); int64_t row_start = __ldg(indptr_info.data + offset); int64_t row_end = __ldg(indptr_info.data + offset + indptr_info.strides[indptr_info.dims - 1]); scalar_t val = Reducer::init(); int64_t arg, arg_tmp; offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E; for (int64_t src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) { Reducer::update(&val, src_data[offset + src_idx], &arg, src_idx); } #pragma unroll for (int i = TB / 2; i > 0; i /= 2) { // Parallel reduction inside a single warp. if (REDUCE == MIN || REDUCE == MAX) arg_tmp = __shfl_down_sync(FULL_MASK, arg, i); Reducer::update( &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp); } if (lane_idx == 0) { Reducer::write(out_data + row_idx, val, arg_out_data + row_idx, arg, row_end - row_start); } } } template __global__ void segment_csr_broadcast_kernel( const scalar_t *src_data, const at::cuda::detail::TensorInfo indptr_info, scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) { // Each thread processes exactly one row. It turned out that is more // efficient than using shared memory due to avoiding synchronization // barriers. 
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
                 torch::optional<torch::Tensor> optional_out,
                 std::string reduce) {
  CHECK_CUDA(src);
  CHECK_CUDA(indptr);
  if (optional_out.has_value())
    CHECK_CUDA(optional_out.value());
  cudaSetDevice(src.get_device());

  CHECK_INPUT(src.dim() >= indptr.dim());

  auto sizes = indptr.sizes().vec();
  for (auto i = 0; i < indptr.dim() - 1; i++)
    sizes[i] = src.size(i);
  indptr = indptr.expand(sizes);

  auto dim = indptr.dim() - 1;

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
    CHECK_INPUT(src.numel() == 0 || out.size(dim) == indptr.size(dim) - 1);
  } else {
    sizes = src.sizes().vec();
    sizes[dim] = std::max<int64_t>(indptr.size(dim) - 1, 0);
    out = torch::empty(sizes, src.options());
  }

  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
    arg_out_data = arg_out.value().data_ptr<int64_t>();
  }

  if (src.numel() == 0) {
    if (!optional_out.has_value())
      out.fill_(0);
    return std::make_tuple(out, arg_out);
  }

  // N: number of rows to reduce, K: inner (broadcast) dimension,
  // E: number of source elements along the indexed dimension.
  auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (K == 1) {
        segment_csr_kernel<scalar_t, REDUCE, 1>
            <<<BLOCKS(1, N), THREADS, 0, stream>>>(
                src_data, indptr_info, out_data, arg_out_data, N, E);
      } else {
        segment_csr_broadcast_kernel<scalar_t, REDUCE>
            <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
                src_data, indptr_info, out_data, arg_out_data, N, K, E);
      }
    });
  });

  return std::make_tuple(out, arg_out);
}
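
// Sketch of the launch-size arithmetic above for hypothetical shapes: a `src`
// of shape [4, 6, 8] reduced over dim = 1 with an `indptr` of shape [4, 3]
// yields an `out` of shape [4, 2, 8], so
//   N = out.size(dim) * (indptr.numel() / indptr.size(-1)) = 2 * 4 = 8,
//   K = out.numel() / N = 64 / 8 = 8,
//   E = src.size(dim) = 6.
// Since K != 1, the broadcast kernel runs with one thread per (row, feature)
// pair, i.e. N * K = 64 threads in total.
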
template <typename scalar_t, int TB>
__global__ void
gather_csr_kernel(const scalar_t *src_data,
                  const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
                  scalar_t *out_data, size_t N, size_t E) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx % TB;

  if (row_idx < N) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);
    scalar_t val = __ldg(src_data + row_idx);

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
    for (int out_idx = row_start + lane_idx; out_idx < row_end; out_idx += TB) {
      out_data[offset + out_idx] = val; // "Mostly" coalesced.
    }
  }
}

template <typename scalar_t>
__global__ void gather_csr_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t K, size_t E) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int lane_idx = thread_idx % K;

  if (thread_idx < N * K) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = src_data[thread_idx]; // Coalesced.

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
    for (int out_idx = row_start; out_idx < row_end; out_idx++) {
      out_data[offset + K * out_idx + lane_idx] = val; // "Mostly" coalesced.
    }
  }
}

torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
                              torch::optional<torch::Tensor> optional_out) {
  CHECK_CUDA(src);
  CHECK_CUDA(indptr);
  if (optional_out.has_value())
    CHECK_CUDA(optional_out.value());
  cudaSetDevice(src.get_device());

  CHECK_INPUT(src.dim() >= indptr.dim());

  auto sizes = indptr.sizes().vec();
  for (auto i = 0; i < indptr.dim() - 1; i++)
    sizes[i] = src.size(i);
  indptr = indptr.expand(sizes);

  auto dim = indptr.dim() - 1;
  CHECK_INPUT(src.size(dim) == 0 || src.size(dim) == indptr.size(dim) - 1);

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (auto i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
  } else {
    auto sizes = src.sizes().vec();
    if (src.numel() > 0) {
      // The output length along `dim` is given by the last `indptr` entry,
      // which lives on the GPU and has to be copied back to the host.
      auto d_size = indptr.flatten()[-1].data_ptr<int64_t>();
      auto h_size = (int64_t *)malloc(sizeof(int64_t));
      cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
      sizes[dim] = *h_size;
      free(h_size);
    } else
      sizes[dim] = 0;
    out = torch::empty(sizes, src.options());
  }

  if (src.numel() == 0) {
    if (!optional_out.has_value())
      out.fill_(0);
    return out;
  }

  auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
  auto K = src.numel() / N;
  auto E = out.size(dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr_kernel", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    if (K == 1)
      gather_csr_kernel<scalar_t, 4><<<BLOCKS(1, 4 * N), THREADS, 0, stream>>>(
          src_data, indptr_info, out_data, N, E);
    else
      gather_csr_broadcast_kernel<scalar_t>
          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
                                                     out_data, N, K, E);
  });

  return out;
}
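
// Illustrative sketch of the expansion gather_csr_cuda performs, using
// hypothetical values: each `src` element is repeated
// `indptr[i + 1] - indptr[i]` times, i.e. the inverse layout of segment_csr:
//
//   src    = [10, 20, 30]
//   indptr = [0, 2, 5, 6]
//   out    = [10, 10, 20, 20, 20, 30]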