#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>

#include <thrust/execution_policy.h>

#include "atomics.cuh"
#include "compat.cuh"

#define THREADS 256
// Number of thread blocks needed to give each of the N segments TB threads.
#define BLOCKS(TB, N) (((TB) * (N) + THREADS - 1) / THREADS)
#define FULL_MASK 0xffffffff

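// Sum-reduces `src_data` over the segments delimited by the CSR offsets in
// `indptr_data` (segment i spans [indptr[i], indptr[i + 1])). Each segment is
// assigned to a group of TB consecutive threads, which combine their partial
// sums with a warp-shuffle tree reduction.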
template <typename scalar_t, int TB>
__global__ void segment_add_csr_kernel(const scalar_t *src_data,
                                       const int64_t *indptr_data,
                                       scalar_t *out_data, size_t numel) {

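  // Each group of TB consecutive threads cooperates on one segment:
  // `warp_idx` selects the segment, `lane_idx` the position in the group.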
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int warp_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);

  if (warp_idx < numel) {
    int row_start = __ldg(indptr_data + warp_idx);
    int row_end = __ldg(indptr_data + warp_idx + 1);
    scalar_t val = (scalar_t)0;

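    // Each lane accumulates a TB-strided slice of [row_start, row_end).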
    for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
      val += __ldg(src_data + src_idx);
    }

#pragma unroll
    for (int offset = TB / 2; offset > 0; offset /= 2)
      val += __shfl_down_sync(FULL_MASK, val, offset); // Parallel reduction.

    if (lane_idx == 0) {
      out_data[warp_idx] = val;
    }
  }
}

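// Chooses the threads-per-segment template parameter TB from the average
// segment length so that short segments do not leave most of a warp idle.
//
// A minimal usage sketch (host side; the tensor values are illustrative
// assumptions, not taken from the original source):
//
//   auto src = at::rand({6}, at::device(at::kCUDA));
//   auto indptr = at::tensor({0, 2, 5, 6},
//                            at::device(at::kCUDA).dtype(at::kLong));
//   auto out = segment_add_csr_cuda(src, indptr);
//   // out == {src[0] + src[1], src[2] + src[3] + src[4], src[5]}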
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
  auto numel = indptr.numel() - 1; // TODO
  auto avg_length = (float)src.numel() / (float)numel;

  auto out = at::empty({numel}, src.options());

  auto indptr_data = indptr.DATA_PTR<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_csr_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

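    // Dispatch on the average segment length: a narrower TB for short
    // segments, up to a full warp (TB = 32) for long ones.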
    if (avg_length <= 4)
      segment_add_csr_kernel<scalar_t, 4>
          <<<BLOCKS(4, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                     out_data, numel);
    else if (avg_length <= 8)
      segment_add_csr_kernel<scalar_t, 8>
          <<<BLOCKS(8, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                     out_data, numel);
    else if (avg_length <= 16)
      segment_add_csr_kernel<scalar_t, 16>
          <<<BLOCKS(16, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                      out_data, numel);
    else
      segment_add_csr_kernel<scalar_t, 32>
          <<<BLOCKS(32, numel), THREADS, 0, stream>>>(src_data, indptr_data,
                                                      out_data, numel);
  });

  return out;
}

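// Sum-reduces `src_data` by the COO indices in `index_data`. Each thread
// loads one element; a warp then performs a segmented inclusive scan with
// __shfl_up_sync, combining values across runs of equal consecutive indices,
// and the last thread of each run atomically adds the run total to
// `out_data`. Sorted indices therefore cost roughly one atomic per segment
// per warp instead of one per element.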
template <typename scalar_t, int TB>
__global__ void segment_add_coo_kernel(const scalar_t *src_data,
                                       const int64_t *index_data,
                                       scalar_t *out_data, size_t numel) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = thread_idx & (TB - 1);

  if (thread_idx < numel) {
    auto idx = __ldg(index_data + thread_idx);
    scalar_t val = src_data[thread_idx], tmp;

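    // Segmented inclusive scan: at each step, add the value held `offset`
    // lanes back, but only if that lane belongs to the same segment.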
#pragma unroll
    for (int offset = 1; offset < TB; offset *= 2) {
      tmp = __shfl_up_sync(FULL_MASK, val, offset);
      if (lane_idx >= offset &&
          idx == __ldg(index_data + thread_idx - offset)) {
        val += tmp;
      }
    }

    // Guard `thread_idx + 1` so the boundary check cannot read one element
    // past the end of `index_data`.
    if (lane_idx == TB - 1 || thread_idx == numel - 1 ||
        idx != __ldg(index_data + thread_idx + 1)) {
      atomAdd(out_data + idx, val);
    }
  }
}

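// Accumulates into `out` with atomics, so the caller is expected to pass a
// pre-initialized (e.g. zero-filled) tensor of the target size.
//
// A minimal usage sketch (the tensor values are illustrative assumptions, not
// taken from the original source):
//
//   auto src = at::rand({4}, at::device(at::kCUDA));
//   auto index = at::tensor({0, 0, 1, 2},
//                           at::device(at::kCUDA).dtype(at::kLong));
//   auto out = at::zeros({3}, src.options());
//   segment_add_coo_cuda(src, index, out);
//   // out == {src[0] + src[1], src[2], src[3]}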
void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  auto numel = src.numel();
  auto avg_length = (float)numel / (float)out.numel();

  auto index_data = index.DATA_PTR<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_coo_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

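    // One thread per element; TB = 32 matches the warp-level scan above.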
    segment_add_coo_kernel<scalar_t, 32>
        <<<BLOCKS(1, numel), THREADS, 0, stream>>>(src_data, index_data,
                                                   out_data, numel);
  });
}

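// Reference implementation on top of thrust::reduce_by_key. It sums
// consecutive runs of equal indices in `index` and writes the run totals to
// `out` in compacted order; the unique keys land in a scratch tensor. As with
// reduce_by_key in general, `index` must already be sorted so that equal keys
// are contiguous.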
void segment_add_thrust_cuda(at::Tensor src, at::Tensor index,
                             at::Tensor out) {
  auto stream = at::cuda::getCurrentCUDAStream();
  auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(allocator).on(stream);

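  // Scratch output for the unique keys produced by reduce_by_key; the values
  // are not used afterwards.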
  auto key = at::full_like(out, -1, out.options().dtype(at::kLong));

  auto index_data = thrust::device_ptr<int64_t>(index.DATA_PTR<int64_t>());
  auto key_data = thrust::device_ptr<int64_t>(key.DATA_PTR<int64_t>());

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_thrust_kernel", [&] {
    auto src_data = thrust::device_ptr<scalar_t>(src.DATA_PTR<scalar_t>());
    auto out_data = thrust::device_ptr<scalar_t>(out.DATA_PTR<scalar_t>());

    thrust::reduce_by_key(policy, index_data, index_data + index.numel(),
                          src_data, key_data, out_data);
  });
}