// segment_kernel.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>

#include <thrust/execution_policy.h>

#include "compat.cuh"

// Sums the entries of `src` over each run of consecutive equal values in
// `index`, using thrust::reduce_by_key on the current CUDA stream of `src`'s
// device.
//
// Parameters:
//   src   - values to reduce; the first index.size(0) elements (in contiguous
//           order) are consumed.
//   index - int64 keys, one per consumed element of `src`. Only *consecutive*
//           equal keys are merged (reduce_by_key semantics); presumably callers
//           pass a sorted index — TODO confirm.
//   out   - contiguous output buffer with the same dtype as `src`. The sum of
//           the i-th run is written to out[i]; entries past the last run are
//           left untouched.
//           NOTE(review): `out` must have room for one element per run in
//           `index`. This is not checked here, because checking would need an
//           extra device pass plus a host sync.
//
// Returns:
//   (out, key), where `key` has `out`'s shape and dtype int64. key[i] holds the
//   index value of the i-th run; entries with no corresponding run stay -1.
std::tuple<at::Tensor, at::Tensor>
segment_add_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  // The raw device pointers below are walked linearly, so the inputs must be
  // dense. contiguous() returns the tensor itself when it already is.
  src = src.contiguous();
  index = index.contiguous();
  AT_ASSERTM(out.is_contiguous(), "out must be contiguous");

  const auto num_keys = index.size(0);
  // reduce_by_key reads one value of `src` per key; guard against OOB reads.
  AT_ASSERTM(src.numel() >= num_keys,
             "src must contain at least index.size(0) elements");

  const cudaError_t set_device_status = cudaSetDevice(src.get_device());
  AT_ASSERTM(set_device_status == cudaSuccess, "cudaSetDevice failed: ",
             cudaGetErrorString(set_device_status));

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Route Thrust's temporary allocations through PyTorch's allocator and run
  // the work on PyTorch's current stream.
  auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(allocator).on(stream);

  // -1 marks output slots that receive no run.
  auto key = at::full_like(out, -1, out.options().dtype(at::kLong));

  auto index_data = thrust::device_ptr<int64_t>(index.DATA_PTR<int64_t>());
  auto key_data = thrust::device_ptr<int64_t>(key.DATA_PTR<int64_t>());

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_kernel", [&] {
    auto src_data = thrust::device_ptr<scalar_t>(src.DATA_PTR<scalar_t>());
    auto out_data = thrust::device_ptr<scalar_t>(out.DATA_PTR<scalar_t>());

    // Writes the distinct run keys to key_data and their sums to out_data,
    // packed from position 0.
    thrust::reduce_by_key(policy, index_data, index_data + num_keys, src_data,
                          key_data, out_data);
  });

  return std::make_tuple(out, key);
}