segment_kernel.cu 4.64 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>

#include <thrust/execution_policy.h>

rusty1s's avatar
rusty1s committed
9
#include "atomics.cuh"
rusty1s's avatar
rusty1s committed
10
11
#include "compat.cuh"

rusty1s's avatar
rusty1s committed
12
// Number of threads per block used by every kernel launch in this file.
#define THREADS 256
// Number of blocks needed to run `TB` threads for each of `N` rows
// (ceil-division by THREADS). Arguments and the whole expansion are
// parenthesized so that compound expressions such as `BLOCKS(4, a + b)` or
// `x / BLOCKS(...)` evaluate as expected.
#define BLOCKS(TB, N) (((TB) * (N) + THREADS - 1) / THREADS)
// Participation mask naming all 32 lanes of a warp for `*_sync` intrinsics.
#define FULL_MASK 0xffffffff

// Sums each CSR row of `src_data` into `out_data`.
//
// Row `r` covers `src_data[indptr_data[r] .. indptr_data[r + 1])`. A group of
// `TB` consecutive threads cooperates on one row: every thread accumulates a
// strided partial sum, the partials are combined with warp shuffles, and the
// first thread of the group writes `out_data[r]`.
//
// Template parameters:
//   scalar_t - element type of `src_data` / `out_data`.
//   TB       - threads per row; must be a power of two (the lane index is
//              computed with `& (TB - 1)`) and at most 32.
// Parameters:
//   src_data    - values to reduce.
//   indptr_data - row offsets; `numel + 1` entries are read.
//   out_data    - one output value per row.
//   numel       - number of rows.
//
// Launch layout: 1D grid with at least `TB * numel` threads. The shuffles use
// FULL_MASK, so every warp must be fully populated, i.e. blockDim.x has to be
// a multiple of 32 (the launcher in this file uses THREADS).
template <typename scalar_t, int TB>
__global__ void segment_add_csr_kernel(const scalar_t *src_data,
                                       const int64_t *indptr_data,
                                       scalar_t *out_data, size_t numel) {

  // 64-bit global index: `blockIdx.x * blockDim.x` can exceed the int range.
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t row_idx = thread_idx / TB;
  const int lane_idx = static_cast<int>(thread_idx & (TB - 1));
  const bool active = static_cast<size_t>(row_idx) < numel;

  scalar_t val = (scalar_t)0;

  if (active) {
    // Offsets are kept as int64_t so they are not truncated for large inputs.
    const int64_t row_start = __ldg(indptr_data + row_idx);
    const int64_t row_end = __ldg(indptr_data + row_idx + 1);

    for (int64_t src_idx = row_start + lane_idx; src_idx < row_end;
         src_idx += TB) {
      val += __ldg(src_data + src_idx);
    }
  }

  // The reduction is executed by ALL lanes of the warp, including lanes whose
  // row is out of range (they contribute 0). Every lane named in FULL_MASK
  // must reach the shuffle; keeping it inside the bounds check would leave
  // the tail warp with missing participants.
#pragma unroll
  for (int offset = TB / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(FULL_MASK, val, offset); // Parallel reduction.

  if (active && lane_idx == 0) {
    out_data[row_idx] = val;
  }
}

rusty1s's avatar
rusty1s committed
44
// Host entry point: sums the rows of `src` described by the CSR offsets in
// `indptr` and returns a new 1D tensor with one entry per row.
//
// Parameters:
//   src    - values to reduce; its scalar type selects the kernel instance.
//   indptr - row offsets, read as int64_t; `indptr.numel() - 1` rows.
// Returns:
//   Tensor of `indptr.numel() - 1` elements created with `src.options()`.
//
// The number of threads cooperating on one row (4, 8, 16 or 32) is chosen
// from the average row length. Kernels are launched on the current CUDA
// stream.
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
  auto numel = indptr.numel() - 1; // TODO

  auto out = at::empty({numel}, src.options());

  // No rows: nothing to compute. Returning here avoids dividing by zero in
  // the average-length heuristic and launching a grid with zero blocks,
  // which is an invalid launch configuration.
  if (numel == 0)
    return out;

  auto avg_length = (float)src.numel() / (float)numel;

  auto indptr_data = indptr.DATA_PTR<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_csr_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    // Use fewer threads per row for short rows so that lanes do not idle.
    if (avg_length <= 4)
      segment_add_csr_kernel<scalar_t, 4>
          <<<BLOCKS(4, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                     out_data, numel);
    else if (avg_length <= 8)
      segment_add_csr_kernel<scalar_t, 8>
          <<<BLOCKS(8, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                     out_data, numel);
    else if (avg_length <= 16)
      segment_add_csr_kernel<scalar_t, 16>
          <<<BLOCKS(16, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                      out_data, numel);
    else
      segment_add_csr_kernel<scalar_t, 32>
          <<<BLOCKS(32, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                      out_data, numel);
  });

  return out;
}

rusty1s's avatar
rusty1s committed
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// Accumulates `src_data[i]` into `out_data[index_data[i]]`.
//
// One thread handles one element. Within a warp, a shuffle-based scan folds
// runs of elements that share the same index, and only the last element of
// each run (or the last lane of the warp) issues an atomic add, which reduces
// the number of atomics.
// NOTE(review): the run detection compares only the elements at distance
// `offset`, so it presumably relies on equal indices being contiguous (e.g. a
// sorted `index`) - confirm against callers.
//
// Parameters:
//   src_data   - `numel` values to accumulate.
//   index_data - `numel` target positions in `out_data`.
//   out_data   - accumulation buffer, updated with `atomAdd`.
//   numel      - number of elements.
//
// Launch layout: 1D grid with at least `numel` threads. The shuffles use
// FULL_MASK, so blockDim.x has to be a multiple of 32 (the launcher in this
// file uses THREADS).
template <typename scalar_t>
__global__ void segment_add_coo_kernel(const scalar_t *src_data,
                                       const int64_t *index_data,
                                       scalar_t *out_data, size_t numel) {

  // 64-bit global index: `blockIdx.x * blockDim.x` can exceed the int range.
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int lane_idx = static_cast<int>(thread_idx & (32 - 1));
  const bool active = static_cast<size_t>(thread_idx) < numel;

  // Out-of-range lanes carry neutral values; they still take part in the
  // shuffles below because every lane named in FULL_MASK must execute them.
  int64_t idx = -1;
  scalar_t val = (scalar_t)0;
  if (active) {
    idx = __ldg(index_data + thread_idx);
    val = src_data[thread_idx];
  }

#pragma unroll
  for (int offset = 1; offset < 32; offset *= 2) {
    scalar_t tmp = __shfl_up_sync(FULL_MASK, val, offset);
    if (active && lane_idx >= offset &&
        idx == __ldg(index_data + thread_idx - offset)) {
      val += tmp;
    }
  }

  if (!active)
    return;

  // The final element has no successor: do not read index_data[numel].
  const bool is_last =
      lane_idx == 31 || static_cast<size_t>(thread_idx) + 1 >= numel;
  if (is_last || idx != __ldg(index_data + thread_idx + 1)) {
    atomAdd(out_data + idx, val);
  }
}

// Host entry point: accumulates every element of `src` into `out` at the
// position given by the matching element of `index`.
//
// Parameters:
//   src   - values to accumulate; its scalar type selects the kernel instance.
//   index - target positions, read as int64_t.
//   out   - accumulation buffer, read with the same scalar type as `src` and
//           updated in place.
//
// One thread is launched per element of `src` on the current CUDA stream.
void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  auto numel = src.numel();

  // Nothing to accumulate; a grid with zero blocks would be an invalid
  // launch configuration.
  if (numel == 0)
    return;

  auto index_data = index.DATA_PTR<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_coo_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    segment_add_coo_kernel<scalar_t><<<BLOCKS(1, numel), THREADS, 0, stream>>>(
        src_data, index_data, out_data, numel);
  });
}

rusty1s's avatar
rusty1s committed
118
119
// Thrust-based variant: runs `thrust::reduce_by_key` with the entries of
// `index` as keys and the entries of `src` as values.
//
// Parameters:
//   src   - values to reduce; its scalar type selects the instantiation.
//   index - keys, read as int64_t.
//   out   - receives the reduced values, read with the same scalar type as
//           `src`.
//
// The reduced keys are written to a temporary int64 tensor shaped like `out`
// (pre-filled with -1) and are not returned. Work is issued on the current
// CUDA stream through a THCThrustAllocator-backed execution policy.
void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  auto cuda_stream = at::cuda::getCurrentCUDAStream();
  auto thrust_alloc = THCThrustAllocator(at::globalContext().lazyInitCUDA());
  auto exec_policy = thrust::cuda::par(thrust_alloc).on(cuda_stream);

  // Scratch destination for the keys emitted by the reduction.
  auto key = at::full_like(out, -1, out.options().dtype(at::kLong));

  thrust::device_ptr<int64_t> keys_first(index.DATA_PTR<int64_t>());
  thrust::device_ptr<int64_t> keys_out(key.DATA_PTR<int64_t>());
  auto keys_last = keys_first + index.numel();

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_thrust_kernel", [&] {
    thrust::device_ptr<scalar_t> values_first(src.DATA_PTR<scalar_t>());
    thrust::device_ptr<scalar_t> values_out(out.DATA_PTR<scalar_t>());

    thrust::reduce_by_key(exec_policy, keys_first, keys_last, values_first,
                          keys_out, values_out);
  });
}