segment_kernel.cu

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include <cassert> // device-side assert() in the COO kernels
#include <limits>  // std::numeric_limits for the MIN/MAX reduction identities

#include "atomics.cuh"
#include "compat.cuh"

#define THREADS 256
#define BLOCKS(TB, N) (((TB) * (N) + THREADS - 1) / THREADS)
#define FULL_MASK 0xffffffff

#define ADD 0
#define MEAN 1
#define MIN 2
#define MAX 3

// We need our own `IndexToOffset` implementation since we do not want to
// access the last element of `indptr`.
template <typename T, typename I> struct IndexPtrToOffset {
  static __host__ __device__ I
  get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
    I offset = idx % (info.sizes[info.dims - 1] - 1);
    offset *= info.strides[info.dims - 1];
    idx /= info.sizes[info.dims - 1] - 1;
    for (int i = info.dims - 2; i >= 0; --i) {
      offset += (idx % info.sizes[i]) * info.strides[i];
      idx /= info.sizes[i];
    }
    return offset;
  }
};
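
// Illustrative sketch (excluded from compilation): for a contiguous 2-D
// `indptr` of shape [B, M + 1] there are B * M output rows, and flattened row
// index `idx = b * M + m` has to resolve to element (b, m) of `indptr`, i.e.
// we index modulo M rather than modulo M + 1. The helper below is a made-up
// 2-D specialization of the generic struct above, kept only as documentation.
#if 0
inline int indptr_offset_2d_example(int idx, int M, int stride0) {
  int m = idx % M;        // boundary position inside one row of `indptr`
  int b = idx / M;        // leading (batch) index
  return b * stride0 + m; // stride(-1) == 1 for a contiguous tensor
}
#endif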

template <typename scalar_t, int REDUCE, int TB>
__global__ void segment_add_csr_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t E) {

  // Each warp processes exactly `32/TB` rows.

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);

  if (row_idx < N) {
    int offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val, tmp;
    int64_t arg_val, arg_tmp;
    if (REDUCE == ADD) {
      val = (scalar_t)0;
    } else if (REDUCE == MEAN) {
      val = (scalar_t)0;
    } else if (REDUCE == MIN) {
      val = std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      // `lowest()`, not `min()`: for floating point types `min()` is the
      // smallest positive value and would break the max-reduction.
      val = std::numeric_limits<scalar_t>::lowest();
    }

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
    for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
      tmp = src_data[offset + src_idx]; // "Mostly" coalesced read.

      if (REDUCE == ADD) {
        val += tmp;
      } else if (REDUCE == MEAN) {
        val += tmp;
      } else if (REDUCE == MIN && tmp < val) {
        val = tmp;
        arg_val = src_idx;
      } else if (REDUCE == MAX && tmp > val) {
        val = tmp;
        arg_val = src_idx;
      }
    }

#pragma unroll
    for (int i = TB / 2; i > 0; i /= 2) {
      // Parallel reduction inside a single warp.
      tmp = __shfl_down_sync(FULL_MASK, val, i);

      if (REDUCE == ADD) {
        val += tmp;
      } else if (REDUCE == MEAN) {
        val += tmp;
      } else if (REDUCE == MIN) {
        arg_tmp = __shfl_down_sync(FULL_MASK, arg_val, i);
        if (tmp < val) {
          val = tmp;
          arg_val = arg_tmp;
        }
      } else if (REDUCE == MAX) {
        arg_tmp = __shfl_down_sync(FULL_MASK, arg_val, i);
        if (tmp > val) {
          val = tmp;
          arg_val = arg_tmp;
        }
      }
    }

    if (lane_idx == 0) {
      // "Mostly" coalesced write.
      if (REDUCE == ADD) {
        out_data[row_idx] = val;
      } else if (REDUCE == MEAN) {
        out_data[row_idx] = val / (scalar_t)max(row_end - row_start, 1);
      } else if (REDUCE == MIN) {
        if (row_end - row_start > 0) {
          out_data[row_idx] = val;
          arg_out_data[row_idx] = arg_val;
        } else {
          out_data[row_idx] = 0;
        }
      } else if (REDUCE == MAX) {
        if (row_end - row_start > 0) {
          out_data[row_idx] = val;
          arg_out_data[row_idx] = arg_val;
        } else {
          out_data[row_idx] = 0;
        }
      }
    }
  }
}
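
// Minimal sketch (excluded from compilation) of the shuffle-based reduction
// used above, specialized to a full warp (TB == 32) and an "add" reduction;
// after the loop, lane 0 holds the sum of all 32 lane values. The function
// name is made up for illustration only.
#if 0
template <typename scalar_t>
__device__ scalar_t warp_sum_example(scalar_t val) {
  for (int i = 32 / 2; i > 0; i /= 2)
    // Fold the value held by the lane `i` positions above onto this lane.
    val += __shfl_down_sync(FULL_MASK, val, i);
  return val; // Only meaningful in lane 0.
}
#endif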

template <typename scalar_t, int REDUCE>
__global__ void segment_add_csr_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {

  // Each thread processes exactly one row (for a single column). It turned
  // out that this is more efficient than using shared memory because it
  // avoids synchronization barriers.

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int lane_idx = thread_idx % K;

  if (thread_idx < N * K) {
    int offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val, tmp;
    int64_t arg_val;
    if (REDUCE == ADD) {
      val = (scalar_t)0;
    } else if (REDUCE == MEAN) {
      val = (scalar_t)0;
    } else if (REDUCE == MIN) {
      val = std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      // `lowest()`, not `min()`: see the comment in `segment_add_csr_kernel`.
      val = std::numeric_limits<scalar_t>::lowest();
    }

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
    for (int src_idx = row_start; src_idx < row_end; src_idx++) {
      tmp = src_data[offset + K * src_idx + lane_idx]; // Coalesced read.

      if (REDUCE == ADD) {
        val += tmp;
      } else if (REDUCE == MEAN) {
        val += tmp;
      } else if (REDUCE == MIN && tmp < val) {
        val = tmp;
        arg_val = src_idx;
      } else if (REDUCE == MAX && tmp > val) {
        val = tmp;
        arg_val = src_idx;
      }
    }

    // Coalesced write.
    if (REDUCE == ADD) {
      out_data[thread_idx] = val;
    } else if (REDUCE == MEAN) {
      out_data[thread_idx] = val / (scalar_t)max(row_end - row_start, 1);
    } else if (REDUCE == MIN) {
      if (row_end - row_start > 0) {
        out_data[thread_idx] = val;
        arg_out_data[thread_idx] = arg_val;
      } else {
        out_data[thread_idx] = 0;
      }
    } else if (REDUCE == MAX) {
      if (row_end - row_start > 0) {
        out_data[thread_idx] = val;
        arg_out_data[thread_idx] = arg_val;
      } else {
        out_data[thread_idx] = 0;
      }
    }
  }
}
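
// Worked example of the indexing above, assuming K = 4: threads 0..3 reduce
// the four columns of output row 0, threads 4..7 those of row 1, and so on.
// In every iteration of the inner loop the K lanes of one row read the K
// consecutive values `src_data[offset + K * src_idx + 0 .. K - 1]`, so
// neighbouring threads touch neighbouring addresses and the load coalesces.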

std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr_cuda(at::Tensor src, at::Tensor indptr,
                 at::optional<at::Tensor> out_opt, std::string reduce) {

  AT_ASSERTM(src.dim() >= indptr.dim());
  for (int i = 0; i < indptr.dim() - 1; i++)
    AT_ASSERTM(src.size(i) == indptr.size(i));

  src = src.contiguous();
  auto reduce_dim = indptr.dim() - 1;

  at::Tensor out;
  if (out_opt.has_value()) {
    out = out_opt.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != reduce_dim)
        AT_ASSERTM(src.size(i) == out.size(i));
    AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1);
  } else {
    auto sizes = src.sizes().vec();
    sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
    out = at::empty(sizes, src.options());
  }

  at::optional<at::Tensor> arg_out = at::nullopt;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
  }

  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(reduce_dim);
  // auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    // Select the right kernel based on the reduce operation and whether we need
    // broadcasting capabilities (K > 1):

    if (K == 1 && reduce == "add") {
      segment_add_csr_kernel<scalar_t, ADD, 1>
          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                  out_data, nullptr, N, E);
    } else if (K == 1 && reduce == "mean") {
      segment_add_csr_kernel<scalar_t, MEAN, 1>
          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                  out_data, nullptr, N, E);
    } else if (K == 1 && reduce == "min") {
      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
      segment_add_csr_kernel<scalar_t, MIN, 1>
          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                  out_data, arg_out_data, N, E);
    } else if (K == 1 && reduce == "max") {
      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
      segment_add_csr_kernel<scalar_t, MAX, 1>
          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                  out_data, arg_out_data, N, E);
    } else if (reduce == "add") {
      segment_add_csr_broadcast_kernel<scalar_t, ADD>
          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
              src_data, indptr_info, out_data, nullptr, N, K, E);
    } else if (reduce == "mean") {
      segment_add_csr_broadcast_kernel<scalar_t, MEAN>
          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
              src_data, indptr_info, out_data, nullptr, N, K, E);
    } else if (reduce == "min") {
      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
      segment_add_csr_broadcast_kernel<scalar_t, MIN>
          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
              src_data, indptr_info, out_data, arg_out_data, N, K, E);
    } else if (reduce == "max") {
      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
      segment_add_csr_broadcast_kernel<scalar_t, MAX>
          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
              src_data, indptr_info, out_data, arg_out_data, N, K, E);
    }
  });

  return std::make_tuple(out, arg_out);
}
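
// Minimal host-side usage sketch (excluded from compilation). The shapes and
// the wrapper function name are made up; it only assumes `segment_csr_cuda`
// as defined above and a CUDA build of ATen.
#if 0
void segment_csr_example() {
  auto src = at::rand({8, 16}, at::device(at::kCUDA).dtype(at::kFloat));
  // Three CSR rows over the first dimension: [0, 4), [4, 4) (empty), [4, 8).
  auto indptr = at::tensor({0, 4, 4, 8},
                           at::device(at::kCUDA).dtype(at::kLong));
  at::Tensor out;
  at::optional<at::Tensor> arg_out;
  std::tie(out, arg_out) =
      segment_csr_cuda(src, indptr, at::nullopt, "max"); // out: [3, 16]
}
#endif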

template <typename scalar_t, int REDUCE>
__global__ void segment_add_coo_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E) {

  // Each thread processes exactly one entry. Within a warp, we perform a
  // parallel reduction across equal indices, and write the intermediate
  // result via atomics.

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = row_idx & (32 - 1);

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int idx = index_info.data[offset], next_idx;
    scalar_t val = src_data[row_idx], tmp;

#pragma unroll
    for (int i = 1; i < 32; i *= 2) {
      tmp = __shfl_up_sync(FULL_MASK, val, i);
      next_idx = __shfl_up_sync(FULL_MASK, idx, i);
      assert(idx >= next_idx);
      if (lane_idx >= i && idx == next_idx)
        val += tmp;
    }

    next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
    if (lane_idx == 32 - 1 || idx != next_idx) {
      atomAdd(out_data + idx, val);
    }
  }
}
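
// Worked example of the segmented warp reduction above: for lane values
// v0..v4 with indices [7, 7, 7, 9, 9], the `__shfl_up_sync` passes build the
// running sums v0, v0+v1, v0+v1+v2, v3, v3+v4 (a sum never crosses an index
// boundary because of the `idx == next_idx` check). Afterwards only the last
// lane of each run (lanes 2 and 4 here, assuming the run of 9s ends at lane
// 4) sees a different `next_idx` and issues the atomic add, so each segment
// touches global memory once per warp instead of once per entry.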

template <typename scalar_t, int REDUCE, int TB>
__global__ void segment_add_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K) {

  // Each thread processes a single column and `TB` rows. Coalesced reads and
  // writes are performed in column-major order. The intermediate results are
  // written via atomics.

  int row_start = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
  int col_idx = blockIdx.y * blockDim.x + threadIdx.x;

  if (row_start < E && col_idx < K) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_start, index_info);

    int idx1 = __ldg(index_info.data + offset);
    scalar_t val = src_data[K * row_start + col_idx];

#pragma unroll
    for (int i = 1; i < TB; i++) {
      if (row_start + i >= E)
        break;

      int idx2 = __ldg(index_info.data + offset +
                       i * index_info.strides[index_info.dims - 1]);
      assert(idx1 <= idx2);
      if (idx1 == idx2) {
        val += src_data[K * (row_start + i) + col_idx];
      } else {
        atomAdd(out_data + K * idx1 + col_idx, val);
        val = src_data[K * (row_start + i) + col_idx];
      }
      idx1 = idx2;
    }

    atomAdd(out_data + K * idx1 + col_idx, val);
  }
}
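
// Worked launch example for the kernel above, assuming TB = 4, E = 1000 and
// K = 64 (see the launches in `segment_coo_cuda` below): blockDim = (32, 8)
// assigns 32 columns to the x-dimension and 8 groups of TB = 4 rows to the
// y-dimension, so one block covers a 32-row x 32-column tile. The grid then
// becomes dim3((E + 31) / 32, (K + 31) / 32) = dim3(32, 2).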

std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                 std::string reduce) {
  AT_ASSERTM(src.dim() >= index.dim());
  for (int i = 0; i < index.dim(); i++)
    AT_ASSERTM(src.size(i) == index.size(i));

  src = src.contiguous();
  out = out.contiguous();
  auto reduce_dim = index.dim() - 1;

  for (int i = 0; i < out.dim(); i++)
    if (i != reduce_dim)
      AT_ASSERTM(src.size(i) == out.size(i));

  at::optional<at::Tensor> arg_out = at::nullopt;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
  }

  auto E = index.numel();
  auto K = src.numel() / index.numel();
  auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    // Select the right kernel based on average row length (purely heuristic)
    // and whether we need broadcasting capabilities (K > 1):

    if (K == 1)
      segment_add_coo_kernel<scalar_t, ADD>
          <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info, out_data,
                                                 nullptr, E);
    else if (avg_len <= 8)
      segment_add_coo_broadcast_kernel<scalar_t, ADD, 4>
          <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8), 0,
             stream>>>(src_data, index_info, out_data, nullptr, E, K);
    else if (avg_len <= 16)
      segment_add_coo_broadcast_kernel<scalar_t, ADD, 8>
          <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8), 0,
             stream>>>(src_data, index_info, out_data, nullptr, E, K);
    else if (avg_len <= 32)
      segment_add_coo_broadcast_kernel<scalar_t, ADD, 16>
          <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32), dim3(32, 8),
             0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
    else
      segment_add_coo_broadcast_kernel<scalar_t, ADD, 32>
          <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32), dim3(32, 8),
             0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
  });

  return std::make_tuple(out, arg_out);
}
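
// Minimal host-side usage sketch (excluded from compilation). Shapes and the
// wrapper name are made up; unlike `segment_csr_cuda`, the caller provides
// `out`, which the "add" reduction accumulates into via atomics.
#if 0
void segment_coo_example() {
  auto src = at::rand({6, 32}, at::device(at::kCUDA).dtype(at::kFloat));
  // Sorted indices mapping the 6 rows of `src` onto 4 output rows.
  auto index = at::tensor({0, 0, 1, 1, 1, 3},
                          at::device(at::kCUDA).dtype(at::kLong));
  auto out = at::zeros({4, 32}, src.options());
  at::Tensor result;
  at::optional<at::Tensor> arg_out;
  std::tie(result, arg_out) = segment_coo_cuda(src, index, out, "add");
}
#endif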