segment_kernel.cu 3.39 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>

#include <thrust/copy.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>

#include "compat.cuh"

rusty1s's avatar
rusty1s committed
11
// Threads per block for the segment kernels. It must be a multiple of the
// warp size (32) so every launched warp is complete, which the FULL_MASK
// shuffles in the kernels rely on.
#define THREADS 256

// Number of THREADS-sized blocks needed to give TB threads to each of N
// segments, i.e. ceil(TB * N / THREADS). Arguments and result are fully
// parenthesized so expression arguments (e.g. `n + 1`) and surrounding
// operators expand correctly.
#define BLOCKS(TB, N) (((TB) * (N) + THREADS - 1) / THREADS)

// Participation mask naming all 32 lanes of a warp.
#define FULL_MASK 0xffffffff

// Sums CSR segments of `src_data`:
//   out_data[i] = sum(src_data[indptr_data[i] .. indptr_data[i + 1]))
// for i in [0, numel).
//
// Launch layout: 1-D grid. Each group of TB consecutive threads cooperates
// on one segment. The TB lanes stride over the segment, then a warp-shuffle
// tree reduction folds their partial sums into the group's lane 0, which
// writes the result.
//
// Preconditions:
//   * TB is a power of two no larger than 32, so a group never straddles a
//     warp (enforced by static_assert).
//   * blockDim.x is a multiple of 32, so every lane named by FULL_MASK
//     exists and reaches the __shfl_down_sync below.
//   * indptr_data holds numel + 1 offsets into src_data.
template <typename scalar_t, int TB>
__global__ void segment_add_csr_kernel(const scalar_t *src_data,
                                       const int64_t *indptr_data,
                                       scalar_t *out_data, size_t numel) {
  static_assert(TB > 0 && TB <= 32 && (TB & (TB - 1)) == 0,
                "TB must be a power of two no larger than the warp size");

  // 64-bit indexing: TB * numel threads may exceed the range of `int`.
  const int64_t thread_idx =
      (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t warp_idx = thread_idx / TB;
  const int lane_idx = (int)(thread_idx & (TB - 1));
  const bool active = warp_idx < (int64_t)numel;

  scalar_t val = (scalar_t)0;
  if (active) {
    // Keep offsets 64-bit; truncating them to `int` breaks once src holds
    // more than 2^31 - 1 elements.
    const int64_t row_start = __ldg(indptr_data + warp_idx);
    const int64_t row_end = __ldg(indptr_data + warp_idx + 1);
    for (int64_t src_idx = row_start + lane_idx; src_idx < row_end;
         src_idx += TB) {
      val += __ldg(src_data + src_idx);
    }
  }

  // Every thread runs the reduction (inactive groups just contribute 0).
  // __shfl_down_sync with FULL_MASK requires all 32 named lanes to execute
  // it. Doing it only inside `if (active)` is undefined behavior in the
  // tail warp when numel is not a multiple of 32 / TB.
#pragma unroll
  for (int offset = TB / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(FULL_MASK, val, offset);

  if (active && lane_idx == 0) {
    out_data[warp_idx] = val;
  }
}

rusty1s's avatar
rusty1s committed
43
// Segment sum over CSR offsets:
//   out[i] = sum(src[indptr[i] : indptr[i + 1]]) for i in [0, indptr.numel() - 1)
//
// src    - values, read as a flat buffer of src.numel() elements.
// indptr - int64 offsets with one more entry than there are segments.
// Returns a 1-D tensor with indptr.numel() - 1 entries and src's options.
//
// The number of threads cooperating on each segment (4/8/16/32) is picked
// from the average segment length, so short segments do not leave most lanes
// of a warp idle. The kernel is enqueued on the current CUDA stream.
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
  // The kernel walks raw pointers linearly, so the inputs must be dense.
  src = src.contiguous();
  indptr = indptr.contiguous();

  auto numel = indptr.numel() - 1;
  auto out = at::empty({numel}, src.options());

  // No segments: nothing to compute. Launching would also be wrong here,
  // since avg_length would divide by zero and BLOCKS(TB, 0) is a 0-block
  // grid, which is an invalid launch configuration.
  if (numel == 0)
    return out;

  auto avg_length = (float)src.numel() / (float)numel;
  auto indptr_data = indptr.DATA_PTR<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    if (avg_length <= 4)
      segment_add_csr_kernel<scalar_t, 4>
          <<<BLOCKS(4, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                     out_data, numel);
    else if (avg_length <= 8)
      segment_add_csr_kernel<scalar_t, 8>
          <<<BLOCKS(8, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                     out_data, numel);
    else if (avg_length <= 16)
      segment_add_csr_kernel<scalar_t, 16>
          <<<BLOCKS(16, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                      out_data, numel);
    else
      segment_add_csr_kernel<scalar_t, 32>
          <<<BLOCKS(32, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                      out_data, numel);
  });

  // Kernel launches do not return errors. Report launch-configuration
  // failures here instead of at some later, unrelated CUDA call.
  THCudaCheck(cudaGetLastError());

  return out;
}

rusty1s's avatar
rusty1s committed
76
77
78
79
// Placeholder entry point for a COO-indexed segment sum.
// NOTE(review): no reduction is performed yet. `index` is ignored and the
// very same `src` tensor handle is handed back to the caller.
at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index) {
  at::Tensor result = src;
  return result;
}

rusty1s's avatar
rusty1s committed
80
81
// Segment sum via thrust::reduce_by_key, written into `out` in place.
//
// reduce_by_key collapses each run of consecutive equal values in `index`
// into one output slot: out[k] receives the sum of the src values belonging
// to the k-th run. Slots of `out` past the number of runs are left untouched.
// NOTE(review): the run keys are not returned to the caller, so out[k] only
// corresponds to segment id k when `index` is sorted and covers 0..K-1
// without gaps. Confirm callers guarantee this.
//
// src   - values; the first index.numel() elements are reduced.
// index - int64 segment ids, one per reduced src element.
// out   - contiguous tensor with src's dtype; must have at least as many
//         elements as there are runs in `index` (checked, raises otherwise).
void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  // `out` is written through a raw pointer in place, so it cannot simply be
  // made contiguous here (that would write into a copy).
  AT_ASSERTM(out.is_contiguous(), "out must be contiguous");
  // reduce_by_key reads one src value per index entry.
  AT_ASSERTM(src.numel() >= index.numel(),
             "src must have at least as many elements as index");
  src = src.contiguous();
  index = index.contiguous();

  auto stream = at::cuda::getCurrentCUDAStream();
  auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(allocator).on(stream);

  // There can be as many runs as index entries (e.g. an unsorted index), so
  // reduce into scratch buffers of that size. Reducing directly into `out`
  // could write past its end.
  auto num_items = index.numel();
  auto key = at::empty({num_items}, index.options());
  auto sum = at::empty({num_items}, src.options());

  auto index_data = thrust::device_ptr<int64_t>(index.DATA_PTR<int64_t>());
  auto key_data = thrust::device_ptr<int64_t>(key.DATA_PTR<int64_t>());

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_thrust_kernel", [&] {
    auto src_data = thrust::device_ptr<scalar_t>(src.DATA_PTR<scalar_t>());
    auto sum_data = thrust::device_ptr<scalar_t>(sum.DATA_PTR<scalar_t>());
    auto out_data = thrust::device_ptr<scalar_t>(out.DATA_PTR<scalar_t>());

    // Returns the ends of the written key/value ranges. The value-range
    // length is the number of runs.
    auto ends = thrust::reduce_by_key(policy, index_data,
                                      index_data + num_items, src_data,
                                      key_data, sum_data);
    auto num_runs = ends.second - sum_data;

    AT_ASSERTM(num_runs <= out.numel(),
               "index has more runs of equal values than out has elements");
    thrust::copy(policy, sum_data, sum_data + num_runs, out_data);
  });
}