segment_kernel.cu 9.74 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
rusty1s's avatar
rusty1s committed
3
4
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
#include "atomics.cuh"
rusty1s's avatar
rusty1s committed
7
8
#include "compat.cuh"

rusty1s's avatar
rusty1s committed
9
// Threads per block used by the 1-D kernel launches in this file.
#define THREADS 256
// Number of `THREADS`-sized blocks needed to cover `TB * N` threads
// (ceil-division). Arguments and the whole result are parenthesized so that
// expressions such as `BLOCKS(4, a + b)` or `2 * BLOCKS(1, n)` expand correctly.
#define BLOCKS(TB, N) (((TB) * (N) + THREADS - 1) / THREADS)
// Lane mask naming all 32 lanes of a warp, passed to `__shfl_*_sync`.
#define FULL_MASK 0xffffffff

rusty1s's avatar
rusty1s committed
13
14
// ATen's `IndexToOffset` cannot be reused for `indptr`: along its last
// dimension only the first `size - 1` entries start a row (the final entry is
// merely the end pointer of the last row), so a flat row index has to be
// decomposed using `size - 1` for that dimension.
//
// `get` maps a flat row index to the element offset (in `info.data`) of that
// row's start pointer.
template <typename T, typename I> struct IndexPtrToOffset {
  static __host__ __device__ I
  get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
    const int last = info.dims - 1;
    const I rows_per_batch = info.sizes[last] - 1;

    // Innermost dimension: position of the row inside its batch.
    I offset = (idx % rows_per_batch) * info.strides[last];
    idx /= rows_per_batch;

    // Remaining (batch) dimensions, from innermost to outermost.
    for (int dim = last - 1; dim >= 0; --dim) {
      const I extent = info.sizes[dim];
      offset += (idx % extent) * info.strides[dim];
      idx /= extent;
    }
    return offset;
  }
};

rusty1s's avatar
rusty1s committed
29
template <typename scalar_t, int TB>
__global__ void segment_add_csr_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t E) {

  // Sums each CSR row (segment) of `src` into `out` (used when K == 1).
  // `TB` consecutive threads cooperate on one row, so each warp processes
  // exactly `32 / TB` rows. We usually set `TB = 32` and only use smaller
  // values when the average row length is below 32.
  //
  // The warp reduction uses `FULL_MASK`, so every lane of the warp must execute
  // the shuffles. Threads whose row lies past the end (row_idx >= N, only
  // possible in the last warp) therefore skip the loads but still take part
  // in the reduction with a zero value. Only the write is guarded.
  static_assert(TB >= 1 && TB <= 32 && (TB & (TB - 1)) == 0,
                "TB must be a power of two in [1, 32]");

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);

  scalar_t val = (scalar_t)0;

  if (row_idx < N) {
    int offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    // The row spans [indptr[row], indptr[row + 1]) along the last dimension.
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    // Start of this row's batch inside the contiguous `src` (E entries each).
    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
    for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
      val += src_data[offset + src_idx]; // "Mostly" coalesced read.
    }
  }

#pragma unroll
  for (int i = TB / 2; i > 0; i /= 2) {
    // Parallel reduction inside a single warp. Lane 0 of each `TB`-group only
    // ever pulls values from lanes of its own group.
    val += __shfl_down_sync(FULL_MASK, val, i);
  }

  if (row_idx < N && lane_idx == 0) {
    out_data[row_idx] = val; // "Mostly" coalesced write.
  }
}

template <typename scalar_t>
__global__ void segment_add_csr_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t K, size_t E) {

  // Broadcast variant (K > 1): one thread per output element. Thread `gid`
  // owns row `gid / K` and column `gid % K` and sums that column over the
  // row's range, so neighbouring threads touch neighbouring columns. This
  // turned out faster than a shared-memory scheme, since it avoids barriers.
  int gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid >= N * K)
    return;

  int row = gid / K;
  int col = gid % K;

  const int last_dim = indptr_info.dims - 1;
  int ptr_offset = IndexPtrToOffset<int64_t, int>::get(row, indptr_info);
  int begin = __ldg(indptr_info.data + ptr_offset);
  int end =
      __ldg(indptr_info.data + ptr_offset + indptr_info.strides[last_dim]);

  // First element of this row's batch in the contiguous `src` (E * K each).
  int base = (row / (indptr_info.sizes[last_dim] - 1)) * E * K;

  scalar_t acc = (scalar_t)0;
  for (int e = begin; e < end; ++e)
    acc += src_data[base + K * e + col]; // Coalesced read.

  out_data[gid] = acc; // Coalesced write.
}

95
96
97
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr,
                                at::optional<at::Tensor> out_opt) {
  // Sums `src` over the CSR segments described by `indptr` along dimension
  // `indptr.dim() - 1`. Returns `out_opt` (filled in place) when given,
  // otherwise a newly allocated tensor. The output has `src`'s shape, with
  // that dimension replaced by `indptr.size(-1) - 1`.

  AT_ASSERTM(src.dim() >= indptr.dim());
  for (int i = 0; i < indptr.dim() - 1; i++)
    AT_ASSERTM(src.size(i) == indptr.size(i));

  src = src.contiguous();
  auto reduce_dim = indptr.dim() - 1;

  at::Tensor out;
  if (out_opt.has_value()) {
    out = out_opt.value();
    for (int i = 0; i < out.dim(); i++)
      if (i != reduce_dim)
        AT_ASSERTM(src.size(i) == out.size(i));
    AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1);
    // The kernels write `out_data` using flat row-major indices.
    AT_ASSERTM(out.is_contiguous());
  } else {
    auto sizes = src.sizes().vec();
    sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
    out = at::empty(sizes, src.options());
  }

  // Empty output (no segments, or a zero-sized dimension): nothing to reduce.
  // This also guards `out.numel() / N` below against N == 0 and avoids
  // launching kernels with a zero-sized grid.
  if (out.numel() == 0)
    return out;

  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(reduce_dim);
  auto avg_length = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    // Select the right kernel based on average row length and whether we need
    // broadcasting capabilities (K > 1):
    if (K == 1 && avg_length <= 4) {
      segment_add_csr_kernel<scalar_t, 4><<<BLOCKS(4, N), THREADS, 0, stream>>>(
          src_data, indptr_info, out_data, N, E);
    } else if (K == 1 && avg_length <= 8) {
      segment_add_csr_kernel<scalar_t, 8><<<BLOCKS(8, N), THREADS, 0, stream>>>(
          src_data, indptr_info, out_data, N, E);
    } else if (K == 1 && avg_length <= 16) {
      segment_add_csr_kernel<scalar_t, 16>
          <<<BLOCKS(16, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                  out_data, N, E);
    } else if (K == 1) {
      segment_add_csr_kernel<scalar_t, 32>
          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                  out_data, N, E);
    } else {
      segment_add_csr_broadcast_kernel<scalar_t>
          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
                                                     out_data, N, K, E);
    }
  });

  return out;
}

rusty1s's avatar
rusty1s committed
155
156
157
158
159
template <typename scalar_t>
__global__ void segment_add_coo_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, size_t E, size_t N) {

  // One thread per entry of `index` (used when K == 1). Within each warp, an
  // inclusive segmented scan over equal output positions is performed with
  // shuffles. Only the last lane of each run writes its partial sum to `out`,
  // via an atomic.
  //
  // `index` may be batched: entry `row_idx` belongs to batch `row_idx / D`
  // (`D` = size of the last `index` dimension) and targets the flat output
  // position `(row_idx / D) * N + index[row_idx]`. Comparing these flat
  // positions, instead of the raw index values, keeps runs from merging
  // across batch boundaries.
  //
  // Every lane must execute the FULL_MASK shuffles. Threads past the end
  // (row_idx >= E) therefore take part with a zero value and a sentinel
  // position of -1 instead of skipping them. This also makes the last
  // in-range lane compare against the sentinel rather than an undefined value.

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = row_idx & (32 - 1);
  bool active = row_idx < E;
  int D = index_info.sizes[index_info.dims - 1];

  int out_idx = -1, next_out_idx;
  scalar_t val = (scalar_t)0, tmp;
  if (active) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int idx = index_info.data[offset];
    out_idx = (row_idx / D) * N + idx;
    val = src_data[row_idx];
  }

#pragma unroll
  for (int i = 1; i < 32; i *= 2) {
    tmp = __shfl_up_sync(FULL_MASK, val, i);
    next_out_idx = __shfl_up_sync(FULL_MASK, out_idx, i);
    // `index` must be sorted within each batch. Lanes `< i` receive their own
    // value, so the check holds trivially there.
    assert(!active || out_idx >= next_out_idx);
    if (lane_idx >= i && out_idx == next_out_idx)
      val += tmp;
  }

  next_out_idx = __shfl_down_sync(FULL_MASK, out_idx, 1);
  if (active && (lane_idx == 32 - 1 || out_idx != next_out_idx)) {
    atomAdd(out_data + out_idx, val);
  }
}

template <typename scalar_t, int TB>
__global__ void segment_add_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, size_t E, size_t K, size_t N) {

  // Broadcast variant (K > 1). Each thread processes a single column
  // (threadIdx.x / blockIdx.y) and `TB` consecutive entries (threadIdx.y /
  // blockIdx.x). Reads and writes are coalesced across columns. Partial sums
  // are flushed via atomics whenever the flat output position changes.
  //
  // Output positions are `(row / D) * N + index[row]`, where `D` is the size of
  // the last `index` dimension. This handles batched `index` tensors
  // correctly. Within a batch, consecutive entries are one last-dim stride
  // apart. At a batch boundary the offset is recomputed, because the outer
  // strides need not line up with it.

  int D = index_info.sizes[index_info.dims - 1];
  int stride = index_info.strides[index_info.dims - 1];
  int row_start = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
  int col_idx = blockIdx.y * blockDim.x + threadIdx.x;

  if (row_start < E && col_idx < K) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_start, index_info);

    int idx = __ldg(index_info.data + offset);
    int out_idx1 = (row_start / D) * N + idx;
    scalar_t val = src_data[K * row_start + col_idx];

#pragma unroll
    for (int i = 1; i < TB; i++) {
      int row_idx = row_start + i;
      if (row_idx >= E)
        break;

      if (row_idx % D == 0)
        offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
            row_idx, index_info);
      else
        offset += stride;

      idx = __ldg(index_info.data + offset);
      int out_idx2 = (row_idx / D) * N + idx;
      // `index` must be sorted within each batch.
      assert(out_idx1 <= out_idx2);
      if (out_idx1 == out_idx2) {
        val += src_data[K * row_idx + col_idx];
      } else {
        atomAdd(out_data + K * out_idx1 + col_idx, val);
        val = src_data[K * row_idx + col_idx];
      }
      out_idx1 = out_idx2;
    }

    atomAdd(out_data + K * out_idx1 + col_idx, val);
  }
}

at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index,
                                at::Tensor out) {
  // Adds each entry of `src` into `out` at the position given by `index`
  // along dimension `index.dim() - 1`. `index` is expected to be sorted within
  // each batch (checked with device-side asserts). Accumulation happens into
  // the existing contents of `out` via atomics. `out` is not initialized here.

  AT_ASSERTM(src.dim() >= index.dim());
  for (int i = 0; i < index.dim(); i++)
    AT_ASSERTM(src.size(i) == index.size(i));

  src = src.contiguous();
  auto reduce_dim = index.dim() - 1;

  for (int i = 0; i < out.dim(); i++)
    if (i != reduce_dim)
      AT_ASSERTM(src.size(i) == out.size(i));
  // The kernels address `out_data` with flat row-major offsets.
  AT_ASSERTM(out.is_contiguous());

  // Nothing to scatter. This also avoids `src.numel() / 0` and zero-sized
  // grid launches.
  if (src.numel() == 0)
    return out;

  auto E = index.numel();
  auto K = src.numel() / index.numel();
  auto N = out.size(reduce_dim);
  auto avg_length = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_coo_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    if (K == 1)
      segment_add_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
          src_data, index_info, out_data, E, N);
    else if (avg_length <= 8)
      segment_add_coo_broadcast_kernel<scalar_t, 4>
          <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8), 0,
             stream>>>(src_data, index_info, out_data, E, K, N);
    else if (avg_length <= 16)
      segment_add_coo_broadcast_kernel<scalar_t, 8>
          <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8), 0,
             stream>>>(src_data, index_info, out_data, E, K, N);
    else if (avg_length <= 32)
      segment_add_coo_broadcast_kernel<scalar_t, 16>
          <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32), dim3(32, 8),
             0, stream>>>(src_data, index_info, out_data, E, K, N);
    else
      segment_add_coo_broadcast_kernel<scalar_t, 32>
          <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32), dim3(32, 8),
             0, stream>>>(src_data, index_info, out_data, E, K, N);
  });

  return out;
}