#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include "atomics.cuh"
#include "index.cuh"

// Threads per block for every launch in this file.
#define THREADS 1024
// Blocks needed to cover N work items with THREADS threads each (ceil-div).
// Fully parenthesized so it is safe inside larger expressions.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Launches NAME<scalar_t, D> over N work items, where D is the compile-time
// rank DIMS for ranks 1..3 and -1 for any other rank.
// - Must be expanded inside an AT_DISPATCH_* lambda, which defines `scalar_t`.
// - The kernel receives __VA_ARGS__ followed by N as its trailing `numel`.
// - Launches on PyTorch's current CUDA stream, so consecutive KERNEL_RUNs in
//   one host function execute in order.
// - N == 0 is skipped: a zero-block grid is an invalid launch configuration.
// - Launch-configuration errors are raised immediately via AT_CUDA_CHECK;
//   faults inside the kernel surface at a later synchronizing call.
#define KERNEL_RUN(NAME, DIMS, N, ...)                                         \
  [&] {                                                                        \
    const int64_t kernel_run_numel = (N);                                      \
    if (kernel_run_numel == 0) {                                               \
      return;                                                                  \
    }                                                                          \
    const cudaStream_t kernel_run_stream = at::cuda::getCurrentCUDAStream();   \
    switch (DIMS) {                                                            \
    case 1:                                                                    \
      NAME<scalar_t, 1><<<BLOCKS(kernel_run_numel), THREADS, 0,                \
                          kernel_run_stream>>>(__VA_ARGS__, kernel_run_numel); \
      break;                                                                   \
    case 2:                                                                    \
      NAME<scalar_t, 2><<<BLOCKS(kernel_run_numel), THREADS, 0,                \
                          kernel_run_stream>>>(__VA_ARGS__, kernel_run_numel); \
      break;                                                                   \
    case 3:                                                                    \
      NAME<scalar_t, 3><<<BLOCKS(kernel_run_numel), THREADS, 0,                \
                          kernel_run_stream>>>(__VA_ARGS__, kernel_run_numel); \
      break;                                                                   \
    default:                                                                   \
      NAME<scalar_t, -1><<<BLOCKS(kernel_run_numel), THREADS, 0,               \
                           kernel_run_stream>>>(__VA_ARGS__,                   \
                                                kernel_run_numel);             \
    }                                                                          \
    AT_CUDA_CHECK(cudaGetLastError());                                         \
  }()

// Scatter-multiply kernel.
// Layout: 1-D grid of 1-D blocks; a grid-stride loop covers work items
// [0, numel), so any grid size is correct. Each work item is one element of
// `index` (host wrappers pass index.numel()).
// IndexToScatterOffsets3 (index.cuh) maps item i to offsets into index, src
// and out. NOTE(review): presumably the out offset takes its position along
// `dim` from the index value -- confirm in index.cuh. Its template arguments
// here are assumed to be (src element type, out element type, Dims); verify
// against the declaration in index.cuh.
// The combine uses atomMul from atomics.cuh, presumably atomic because
// several items may target the same out element.
// Dims: compile-time rank of `index`, or -1 (KERNEL_RUN's fallback for other
// ranks).
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  // Widen before multiplying: the built-in index variables are 32-bit
  // unsigned, so the product would wrap for launches of >= 2^32 threads.
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        static_cast<int64_t>(i), dim, index, &indexOffset, src, &srcOffset,
        out, &outOffset);
    atomMul(&out.data[outOffset], src.data[srcOffset]);
  }
}

// Host entry point for scatter-multiply.
// Makes src's device current via cudaSetDevice (not restored on return) and
// launches scatter_mul_kernel on the current stream.
// Assumes index is an int64 tensor and out is preallocated on the same device
// with src's dtype; getTensorInfo<T> presumably rejects mismatched dtypes.
void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      int64_t dim) {
  AT_CUDA_CHECK(cudaSetDevice(src.get_device()));
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul_kernel", [&] {
    KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
  });
}

// Scatter-divide kernel: same traversal and offset computation as
// scatter_mul_kernel, combining with atomDiv from atomics.cuh.
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        static_cast<int64_t>(i), dim, index, &indexOffset, src, &srcOffset,
        out, &outOffset);
    atomDiv(&out.data[outOffset], src.data[srcOffset]);
  }
}

// Host entry point for scatter-divide; same device/stream handling as
// scatter_mul_cuda.
void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      int64_t dim) {
  AT_CUDA_CHECK(cudaSetDevice(src.get_device()));
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div_kernel", [&] {
    KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
  });
}

// Second pass of scatter_max / scatter_min.
// For each work item whose src value equals the value already reduced into
// out, stores that src element's coordinate along `dim` into arg. It must run
// after the reduction kernel has completed; the host wrappers guarantee this
// by launching both kernels on the same stream.
// Ties: several src elements equal to the result write the same arg slot with
// plain stores, so which coordinate ends up in arg is unspecified.
// NOTE(review): (srcOffset / strides[dim]) % sizes[dim] equals the coordinate
// only for dense layouts (e.g. contiguous or permuted-contiguous); confirm that
// callers never pass sliced/padded src tensors.
// NOTE(review): IndexToScatterOffsets4 template arguments are assumed to be
// the element types of the three tensors in argument order, then Dims.
template <typename scalar_t, int64_t Dims>
__global__ void
arg_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
           at::cuda::detail::TensorInfo<int64_t, int64_t> index,
           at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
           at::cuda::detail::TensorInfo<int64_t, int64_t> arg, int64_t dim,
           size_t numel) {
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0, argOffset = 0;
    IndexToScatterOffsets4<scalar_t, scalar_t, int64_t, Dims>::compute(
        static_cast<int64_t>(i), dim, index, &indexOffset, src, &srcOffset,
        out, &outOffset, arg, &argOffset);
    if (src.data[srcOffset] == out.data[outOffset]) {
      arg.data[argOffset] = (srcOffset / src.strides[dim]) % src.sizes[dim];
    }
  }
}

// Scatter-max kernel: same traversal as scatter_mul_kernel, combining with
// atomMax from atomics.cuh.
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        static_cast<int64_t>(i), dim, index, &indexOffset, src, &srcOffset,
        out, &outOffset);
    atomMax(&out.data[outOffset], src.data[srcOffset]);
  }
}

// Host entry point for scatter-max: reduces src into out with
// scatter_max_kernel, then records the winning coordinate along `dim` in arg
// with arg_kernel. Both launches share the current stream, so they run in
// order.
// NOTE(review): out presumably must be pre-filled with the reduction identity
// and arg with a sentinel by the caller -- neither is initialized here.
void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      at::Tensor arg, int64_t dim) {
  AT_CUDA_CHECK(cudaSetDevice(src.get_device()));
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max_kernel", [&] {
    auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
    KERNEL_RUN(scatter_max_kernel, index.dim(), index.numel(), src_info,
               index_info, out_info, dim);
    KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_info, index_info,
               out_info, at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg),
               dim);
  });
}

// Scatter-min kernel: same traversal as scatter_mul_kernel, combining with
// atomMin from atomics.cuh.
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        static_cast<int64_t>(i), dim, index, &indexOffset, src, &srcOffset,
        out, &outOffset);
    atomMin(&out.data[outOffset], src.data[srcOffset]);
  }
}

// Host entry point for scatter-min: reduce with scatter_min_kernel, then
// record the winning coordinate in arg with arg_kernel (same stream, in
// order). Same caller preconditions as scatter_max_cuda.
void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      at::Tensor arg, int64_t dim) {
  AT_CUDA_CHECK(cudaSetDevice(src.get_device()));
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min_kernel", [&] {
    auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
    KERNEL_RUN(scatter_min_kernel, index.dim(), index.numel(), src_info,
               index_info, out_info, dim);
    KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_info, index_info,
               out_info, at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg),
               dim);
  });
}

// Backward pass of scatter_max / scatter_min.
// For each work item, copies grad (read at the scattered position) into out
// only where arg equals this out element's coordinate along `dim`. Presumably
// this routes the gradient to the element chosen by arg_kernel. Non-matching
// out elements are left untouched, so out presumably arrives zero-filled.
// NOTE(review): the coordinate formula has the same dense-layout assumption
// as in arg_kernel.
template <typename scalar_t, int64_t Dims>
__global__ void
index_backward_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad,
                      at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                      at::cuda::detail::TensorInfo<int64_t, int64_t> arg,
                      at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                      int64_t dim, size_t numel) {
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = idx; i < numel; i += stride) {
    int64_t gradOffset = 0, indexOffset = 0, argOffset = 0, outOffset = 0;
    // Tensor order here is (out, arg, grad), so the element types are
    // (scalar_t, int64_t, scalar_t).
    IndexToScatterOffsets4<scalar_t, int64_t, scalar_t, Dims>::compute(
        static_cast<int64_t>(i), dim, index, &indexOffset, out, &outOffset,
        arg, &argOffset, grad, &gradOffset);
    if (arg.data[argOffset] ==
        (outOffset / out.strides[dim]) % out.sizes[dim]) {
      out.data[outOffset] = grad.data[gradOffset];
    }
  }
}

// Host entry point for the max/min backward pass. Makes grad's device current
// and launches index_backward_kernel on the current stream.
void index_backward_cuda(at::Tensor grad, at::Tensor index, at::Tensor arg,
                         at::Tensor out, int64_t dim) {
  AT_CUDA_CHECK(cudaSetDevice(grad.get_device()));
  AT_DISPATCH_ALL_TYPES(grad.scalar_type(), "index_backward_kernel", [&] {
    KERNEL_RUN(index_backward_kernel, index.dim(), index.numel(),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad),
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
  });
}