#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <torch/extension.h>

#include "atomics.cuh"
#include "index.cuh"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS

// Launches NAME on the current CUDA stream, specialized on the scalar type of
// the surrounding dispatch and on the dimensionality of the index tensor
// (-1 is the generic fallback); N is appended as the last kernel argument.
#define KERNEL_RUN(NAME, DIMS, N, ...)                                        \
  [&] {                                                                       \
    auto stream = at::cuda::getCurrentCUDAStream();                           \
    switch (DIMS) {                                                           \
    case 1:                                                                   \
      NAME<scalar_t, 1><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N);   \
      break;                                                                  \
    case 2:                                                                   \
      NAME<scalar_t, 2><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N);   \
      break;                                                                  \
    case 3:                                                                   \
      NAME<scalar_t, 3><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N);   \
      break;                                                                  \
    default:                                                                  \
      NAME<scalar_t, -1><<<BLOCKS(N), THREADS, 0, stream>>>(__VA_ARGS__, N);  \
    }                                                                         \
  }()

// NOTE: The template argument lists passed to IndexToScatterOffsets3/4 below
// assume the helper signatures declared in index.cuh
// (<scalar type, index type[, arg type], Dims>).

template <typename scalar_t, int64_t Dims>
__global__ void
scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, int64_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomMul(&out.data[outOffset], src.data[srcOffset]);
  }
}

void scatter_mul_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                      int64_t dim) {
  cudaSetDevice(src.get_device());
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul_kernel", [&] {
    KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
  });
}

template <typename scalar_t, int64_t Dims>
__global__ void
scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, int64_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomDiv(&out.data[outOffset], src.data[srcOffset]);
  }
}

void scatter_div_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                      int64_t dim) {
  cudaSetDevice(src.get_device());
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div_kernel", [&] {
    KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
  });
}

// Second pass for max/min: records, for every output slot, the position along
// `dim` of a source element that attained the reduced value stored in `out`.
template <typename scalar_t, int64_t Dims>
__global__ void
arg_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
           at::cuda::detail::TensorInfo<int64_t, int64_t> index,
           at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
           at::cuda::detail::TensorInfo<int64_t, int64_t> arg, int64_t dim,
           size_t numel) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0, argOffset = 0;
    IndexToScatterOffsets4<scalar_t, int64_t, int64_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset, arg,
        &argOffset);
    if (src.data[srcOffset] == out.data[outOffset]) {
      arg.data[argOffset] = (srcOffset / src.strides[dim]) % src.sizes[dim];
    }
  }
}

template <typename scalar_t, int64_t Dims>
__global__ void
scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, int64_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomMax(&out.data[outOffset], src.data[srcOffset]);
  }
}

void scatter_max_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                      torch::Tensor arg, int64_t dim) {
  cudaSetDevice(src.get_device());
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max_kernel", [&] {
    auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
    // Reduce into `out` first, then recover the winning source positions.
    KERNEL_RUN(scatter_max_kernel, index.dim(), index.numel(), src_info,
               index_info, out_info, dim);
    KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_info, index_info,
               out_info,
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg), dim);
  });
}

template <typename scalar_t, int64_t Dims>
__global__ void
scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, int64_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomMin(&out.data[outOffset], src.data[srcOffset]);
  }
}

void scatter_min_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                      torch::Tensor arg, int64_t dim) {
  cudaSetDevice(src.get_device());
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min_kernel", [&] {
    auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
    KERNEL_RUN(scatter_min_kernel, index.dim(), index.numel(), src_info,
               index_info, out_info, dim);
    KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_info, index_info,
               out_info,
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg), dim);
  });
}
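
// Usage note (illustrative sketch only): this file defines plain C++ entry
// points and does not itself expose them to Python. Assuming they are bound
// through a standard pybind11 extension module (module and function names
// below are hypothetical, not taken from this file), the binding could look
// like:
//
//   #include <torch/extension.h>
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("scatter_mul", &scatter_mul_cuda, "scatter mul (CUDA)");
//     m.def("scatter_div", &scatter_div_cuda, "scatter div (CUDA)");
//     m.def("scatter_max", &scatter_max_cuda, "scatter max with argmax (CUDA)");
//     m.def("scatter_min", &scatter_min_cuda, "scatter min with argmin (CUDA)");
//   }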