#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>

#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>

#include <tuple>

#include "compat.cuh"

// Sums runs of consecutive equal keys in `index` over the paired elements of
// `src`, using thrust::reduce_by_key on the current CUDA stream.
//
// Only the first index.size(0) elements of `index` and `src` (row-major
// order) take part. Each maximal run of consecutive equal index values
// produces one output slot k = 0, 1, ...:
//   * key.flatten()[k] receives the run's index value;
//   * out.flatten()[k] receives the sum of the run's `src` values.
// Slots past the last run keep their previous contents: -1 in `key`, and
// whatever `out` held on entry.
//
// Parameters:
//   src   - values to reduce. Must hold at least index.size(0) elements and
//           share `out`'s dtype.
//   index - keys, read as int64. Runs are formed from *consecutive* equal
//           values, so the caller presumably passes a sorted index.
//           NOTE(review): confirm at call sites.
//   out   - destination. Written in place and returned.
//
// Returns (out, key). `key` is a new int64 tensor shaped like `out`.
//
// NOTE(review): nothing here checks that out.numel() is at least the number
// of runs in `index`. If it is smaller, reduce_by_key writes past the end of
// `out` and `key`, so callers must guarantee it.
std::tuple<at::Tensor, at::Tensor>
segment_add_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  cudaSetDevice(src.get_device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(allocator).on(stream);

  // The Thrust call below walks raw pointers, which is only valid for dense
  // row-major storage. contiguous() is a no-op for already-contiguous
  // tensors and otherwise yields a packed copy, so strided inputs are no
  // longer misread.
  src = src.contiguous();
  index = index.contiguous();

  // `out` must be updated in place. Reduce into a contiguous buffer (this is
  // `out` itself when it is already contiguous) and copy back below if not.
  auto out_contig = out.contiguous();

  // -1 marks key slots that no run was written to. Deriving the tensor from
  // the contiguous buffer keeps `key` densely packed for the raw pointer use.
  auto key =
      at::full_like(out_contig, -1, out_contig.options().dtype(at::kLong));

  auto index_data = thrust::device_ptr<int64_t>(index.DATA_PTR<int64_t>());
  auto key_data = thrust::device_ptr<int64_t>(key.DATA_PTR<int64_t>());

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_kernel", [&] {
    auto src_data = thrust::device_ptr<scalar_t>(src.DATA_PTR<scalar_t>());
    auto out_data =
        thrust::device_ptr<scalar_t>(out_contig.DATA_PTR<scalar_t>());

    // Default predicate/op: equal_to on keys, plus on values.
    thrust::reduce_by_key(policy, index_data, index_data + index.size(0),
                          src_data, key_data, out_data);
  });

  // When `out` is strided, out_contig is a separate copy, so write the
  // results back. The copy is enqueued on the same current stream after
  // the reduction.
  if (!out.is_contiguous()) {
    out.copy_(out_contig);
  }

  return std::make_tuple(out, key);
}