// NOTE(review): the header names of the angle-bracket includes and all
// template arguments were missing from this file (stripped `<...>` spans).
// They are reconstructed below from the symbols the code actually uses --
// confirm against the project's build before merging.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <THC/THCThrustAllocator.cuh>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>

#include "compat.cuh"

/**
 * Host-side entry point for the CUDA segment-add operation.
 *
 * What the visible code does:
 *   1. makes the device reported by `src.get_device()` the active CUDA device,
 *   2. fetches ATen's current CUDA stream,
 *   3. builds a Thrust execution policy that uses a `THCThrustAllocator` and
 *      is bound to that stream,
 *   4. wraps the raw data pointers of `index`, `src` and `out` in
 *      `thrust::device_ptr`s (the latter two per dispatched scalar type).
 *
 * NOTE(review): no reduction is launched here -- `policy`, `index_data`,
 * `src_data` and `out_data` are prepared but never consumed, and `dim` is
 * unused. This looks like scaffolding for a kernel that has not been written
 * yet; `out` is left untouched by this function as it stands.
 *
 * @param src    Source tensor; its scalar type drives the type dispatch and
 *               its device is made current.
 * @param index  Index tensor, read through an `int64_t` data pointer
 *               (assumes the tensor holds int64 values -- TODO confirm at the
 *               call site).
 * @param out    Output tensor, accessed with the same scalar type as `src`
 *               (assumes matching dtypes -- TODO confirm at the call site).
 * @param dim    Currently unused.
 *
 * Throws (via AT_CUDA_CHECK) if selecting the device fails, instead of
 * silently continuing on whichever device happened to be current.
 */
void segment_add_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      int64_t dim) {
  // Surface a failed device switch immediately; an ignored error here would
  // otherwise show up later as an unrelated-looking failure.
  AT_CUDA_CHECK(cudaSetDevice(src.get_device()));

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Route Thrust's temporary allocations through THC and run on ATen's
  // current stream rather than the default stream.
  auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(allocator).on(stream);

  // DATA_PTR comes from compat.cuh; presumably it expands to the tensor's
  // typed data-pointer accessor for the PyTorch version in use -- verify.
  auto index_data = thrust::device_ptr<int64_t>(index.DATA_PTR<int64_t>());

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_kernel", [&] {
    auto src_data = thrust::device_ptr<scalar_t>(src.DATA_PTR<scalar_t>());
    auto out_data = thrust::device_ptr<scalar_t>(out.DATA_PTR<scalar_t>());
    // TODO: the actual segment reduction over (index_data, src_data) into
    // out_data using `policy` is not implemented in the visible source.
  });
}