segment_kernel.cu 13.7 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
rusty1s's avatar
rusty1s committed
3
4
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
#include "atomics.cuh"
rusty1s's avatar
rusty1s committed
7
8
#include "compat.cuh"

rusty1s's avatar
rusty1s committed
9
// Threads per block for the 1-D kernel launches in this file.
#define THREADS 256
// Number of THREADS-sized blocks needed so that each of `N` work items gets
// `TB` threads, i.e. ceil(TB * N / THREADS). Fully parenthesized (result and
// arguments) so it expands correctly inside larger expressions and when
// called with compound arguments such as `BLOCKS(1, N * K)`.
#define BLOCKS(TB, N) ((((TB) * (N)) + THREADS - 1) / THREADS)
// Shuffle mask selecting all 32 lanes of a warp.
#define FULL_MASK 0xffffffff

rusty1s's avatar
rusty1s committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
enum ReductionType { ADD, MEAN, MIN, MAX };

// Turns the runtime reduction name into the compile-time constant `REDUCE`
// that the wrapped lambda can use as a template argument. Strings other than
// "add", "mean", "min" and "max" match no branch, so the lambda is silently
// not run -- callers must validate `reduce` beforehand.
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    if (reduce == "add") {                                                     \
      const ReductionType REDUCE = ADD;                                        \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "mean") {                                             \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "min") {                                              \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "max") {                                              \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
  }()

// Per-element reduction policy shared by the segment kernels.
//   init():   starting accumulator value -- 0 for add/mean, and a sentinel
//             that any real value can replace for min/max.
//   update(): folds `new_val` (found at position `new_arg`) into `*val`; for
//             min/max it also records the winning position in `*arg`.
//   write():  stores the final result of a segment holding `count` elements.
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  static inline __host__ __device__ scalar_t init() {
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      // `lowest()`, not `min()`: for floating-point types `min()` is the
      // smallest *positive* normal value, which would beat every negative
      // input and make the max of all-negative segments wrong.
      return std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

  static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
                                                int64_t *arg, int64_t new_arg) {
    if (REDUCE == ADD || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      // Strict comparison: a later value equal to the current best does not
      // replace it. A losing value is simply ignored -- it must never be
      // accumulated into `*val`.
      *val = new_val;
      *arg = new_arg;
    }
  }

  // add: stores the sum. mean: divides by max(count, 1). min/max: stores the
  // value and its position for non-empty segments; empty segments get 0 and
  // `*arg_address` is left untouched.
  static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
                                               int64_t *arg_address,
                                               int64_t arg, int count) {
    if (REDUCE == ADD) {
      *address = val;
    } else if (REDUCE == MEAN) {
      *address = val / (scalar_t)max(count, 1);
    } else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
    }
  }
};
rusty1s's avatar
rusty1s committed
71

rusty1s's avatar
rusty1s committed
72
73
// Maps a linear segment ("row") index to an element offset inside `indptr`.
// A dedicated variant of `IndexToOffset` is needed because the last dimension
// of `indptr` has `size - 1` addressable row starts: the trailing entry of
// each slice is only ever read as the end boundary of the previous row.
template <typename scalar_t> struct IndexPtrToOffset {
  static inline __host__ __device__ int
  get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
    const int last_dim = info.dims - 1;
    const int rows_per_slice = info.sizes[last_dim] - 1;

    // Innermost dimension: skip the trailing boundary entry.
    int result = (idx % rows_per_slice) * info.strides[last_dim];
    idx /= rows_per_slice;

    // Remaining (outer) dimensions use their full sizes.
    for (int dim = last_dim - 1; dim >= 0; dim--) {
      const int size = info.sizes[dim];
      result += (idx % size) * info.strides[dim];
      idx /= size;
    }
    return result;
  }
};

rusty1s's avatar
rusty1s committed
88
89
90
91
92
93
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void
segment_csr_kernel(const scalar_t *src_data,
                   const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
                   scalar_t *out_data, int64_t *arg_out_data, size_t N,
                   size_t E) {
  // Segment reduction for the non-broadcast case (one scalar per src
  // position). `TB` consecutive threads cooperate on one segment ("row"):
  // each lane strides over the segment by `TB`, then the lanes combine their
  // partial results with a shuffle-based tree reduction and lane 0 writes the
  // result. Each warp therefore processes exactly `32 / TB` rows.
  //
  // Launch: 1-D grid with at least `N * TB` threads. The shuffles use
  // FULL_MASK, so blockDim.x must be a multiple of 32 (every warp complete).
  static_assert(TB > 0 && TB <= 32 && (TB & (TB - 1)) == 0,
                "TB must be a power of two in [1, 32]");

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);
  bool active = row_idx < N;

  scalar_t val = Reducer<scalar_t, REDUCE>::init();
  int64_t row_start = 0, row_end = 0;
  int64_t arg = 0, arg_tmp = 0;

  if (active) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    row_start = __ldg(indptr_info.data + offset);
    row_end = __ldg(indptr_info.data + offset +
                    indptr_info.strides[indptr_info.dims - 1]);

    // Start from a valid position of this segment, so a non-empty segment
    // never reports an uninitialized argument (e.g. when no value beats the
    // reducer's initial sentinel).
    arg = row_start;

    // Start of this row's slice of `src`; 64-bit to avoid `int` overflow.
    int64_t src_offset =
        static_cast<int64_t>(row_idx /
                             (indptr_info.sizes[indptr_info.dims - 1] - 1)) *
        static_cast<int64_t>(E);
    for (int64_t src_idx = row_start + lane_idx; src_idx < row_end;
         src_idx += TB) {
      Reducer<scalar_t, REDUCE>::update(&val, src_data[src_offset + src_idx],
                                        &arg, src_idx);
    }
  }

  // Every lane of the warp must reach the shuffles (FULL_MASK), so this loop
  // lives outside the `active` guard. Inactive lanes carry the reducer's
  // initial value. Rows occupy aligned groups of TB lanes, so lane 0 of an
  // active row only combines values from its own group.
#pragma unroll
  for (int i = TB / 2; i > 0; i /= 2) {
    if (REDUCE == MIN || REDUCE == MAX) {
      arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
    }
    Reducer<scalar_t, REDUCE>::update(
        &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
  }

  if (active && lane_idx == 0) {
    Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
                                     arg_out_data + row_idx, arg,
                                     static_cast<int>(row_end - row_start));
  }
}

rusty1s's avatar
rusty1s committed
135
136
template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_csr_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {
  // Segment reduction for the broadcast case (K values per src position).
  // Each thread produces one output element: segment `row_idx`, column
  // `lane_idx = thread_idx % K`, walking the segment sequentially.
  // Consecutive threads read consecutive columns of the same src row.
  // It turned out that this is more efficient than using shared memory,
  // because it avoids synchronization barriers.
  //
  // Launch: 1-D grid with at least `N * K` threads.

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int lane_idx = thread_idx % K;

  if (thread_idx < N * K) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    // Keep segment bounds in 64 bit, matching the int64 `indptr` storage.
    int64_t row_start = __ldg(indptr_info.data + offset);
    int64_t row_end = __ldg(indptr_info.data + offset +
                            indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = Reducer<scalar_t, REDUCE>::init();
    // Start from a valid position of this segment, so a non-empty segment
    // never reports an uninitialized argument.
    int64_t arg = row_start;

    // Start of this row's (E x K) slice of `src`; 64-bit to avoid overflow.
    int64_t src_offset =
        static_cast<int64_t>(row_idx /
                             (indptr_info.sizes[indptr_info.dims - 1] - 1)) *
        static_cast<int64_t>(E) * static_cast<int64_t>(K);
    for (int64_t src_idx = row_start; src_idx < row_end; src_idx++) {
      Reducer<scalar_t, REDUCE>::update(
          &val, src_data[src_offset + K * src_idx + lane_idx], &arg, src_idx);
    }

    Reducer<scalar_t, REDUCE>::write(out_data + thread_idx, val,
                                     arg_out_data + thread_idx, arg,
                                     static_cast<int>(row_end - row_start));
  }
}

rusty1s's avatar
rusty1s committed
170
171
172
// Segment reduction of `src` along dimension `indptr.dim() - 1`. The segments
// are given in compressed (CSR-style) form: segment `j` covers src positions
// [indptr[..., j], indptr[..., j + 1]) along the reduced dimension.
//
// Args:
//   src:     values to reduce; made contiguous before use.
//   indptr:  int64 segment boundaries; its leading `indptr.dim() - 1` dims
//            must match `src`.
//   out_opt: optional output; must match `src` in every dim except the
//            reduced one, whose size must be `indptr.size(reduce_dim) - 1`.
//            It is made contiguous, so a non-contiguous tensor is written
//            through a copy (the returned `out`).
//   reduce:  "add", "mean", "min" or "max"; anything else is rejected.
// Returns (out, arg_out). `arg_out` is only allocated for "min"/"max". It holds
// the src position of the selected value, or `src.size(reduce_dim)` for
// empty segments.
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr_cuda(at::Tensor src, at::Tensor indptr,
                 at::optional<at::Tensor> out_opt, std::string reduce) {
  // The reduction dispatcher silently skips unknown names, which would
  // return an uninitialized `out`; reject them up front instead.
  AT_ASSERTM(reduce == "add" || reduce == "mean" || reduce == "min" ||
                 reduce == "max",
             "reduce must be one of \"add\", \"mean\", \"min\" or \"max\"");
  AT_ASSERTM(indptr.dim() >= 1, "indptr must have at least one dimension");
  AT_ASSERTM(src.dim() >= indptr.dim());
  for (int i = 0; i < indptr.dim() - 1; i++)
    AT_ASSERTM(src.size(i) == indptr.size(i));

  src = src.contiguous();
  auto reduce_dim = indptr.dim() - 1;

  at::Tensor out;
  if (out_opt.has_value()) {
    out = out_opt.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != reduce_dim)
        AT_ASSERTM(src.size(i) == out.size(i));
    AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1);
  } else {
    auto sizes = src.sizes().vec();
    sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
    out = at::empty(sizes, src.options());
  }

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  // Nothing to compute. Returning here also avoids dividing by N == 0 below
  // and launching an (invalid) zero-block grid.
  if (out.numel() == 0)
    return std::make_tuple(out, arg_out);

  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(reduce_dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (K == 1) {
        // Threads per segment. The grid size is derived from the same
        // constant, so no surplus threads are launched.
        constexpr int TB = 1;
        segment_csr_kernel<scalar_t, REDUCE, TB>
            <<<BLOCKS(TB, N), THREADS, 0, stream>>>(
                src_data, indptr_info, out_data, arg_out_data, N, E);
      } else {
        segment_csr_broadcast_kernel<scalar_t, REDUCE>
            <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
                src_data, indptr_info, out_data, arg_out_data, N, K, E);
      }
    });
  });

  // Kernel launches do not report configuration errors directly.
  cudaError_t launch_err = cudaGetLastError();
  AT_ASSERTM(launch_err == cudaSuccess, cudaGetErrorString(launch_err));

  return std::make_tuple(out, arg_out);
}

rusty1s's avatar
rusty1s committed
227
228
229
230
231
template <typename scalar_t, ReductionType REDUCE>
__global__ void
segment_coo_kernel(const scalar_t *src_data,
                   const at::cuda::detail::TensorInfo<int64_t, int> index_info,
                   scalar_t *out_data, int64_t *arg_out_data, size_t E) {
  // Each thread processes exactly one entry. Within a warp, we perform a
  // segmented inclusive scan across runs of equal indices. The lane that
  // closes a run writes the run's partial result via atomics.
  //
  // Requirements:
  //  * `index` is sorted ascending (checked by `assert`), so equal indices
  //    are contiguous within a warp.
  //  * blockDim.x is a multiple of 32: `lane_idx` is derived from the global
  //    thread index and the shuffles use FULL_MASK.
  //
  // NOTE(review): `REDUCE` and `arg_out_data` are currently unused -- every
  // reduction accumulates with `atomAdd`. `out_data + idx` applies no batch
  // offset, so this appears to assume a 1-D `index`; confirm.

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = row_idx & (32 - 1);
  bool active = row_idx < E;

  // Lanes past the end still take part in every shuffle (FULL_MASK requires
  // the whole warp). They carry a zero value and a sentinel index that is
  // >= any real index, which keeps the sortedness assertion valid. They
  // never write. Active lanes form a prefix of the warp, so they only ever
  // read (via shfl_up) from other active lanes.
  int idx = 0x7fffffff;
  scalar_t val = (scalar_t)0;
  if (active) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    idx = index_info.data[offset];
    val = src_data[row_idx];
  }

  int next_idx;
  scalar_t tmp;
#pragma unroll
  for (int i = 1; i < 32; i *= 2) {
    tmp = __shfl_up_sync(FULL_MASK, val, i);
    next_idx = __shfl_up_sync(FULL_MASK, idx, i);
    assert(idx >= next_idx);
    if (lane_idx >= i && idx == next_idx)
      val += tmp;
  }

  // A lane closes its run when it is the warp's last lane, the last valid
  // entry, or its right neighbour holds a different index.
  next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
  bool last_entry = static_cast<size_t>(row_idx) + 1 == E;
  if (active && (lane_idx == 32 - 1 || last_entry || idx != next_idx)) {
    atomAdd(out_data + idx, val);
  }
}

rusty1s's avatar
rusty1s committed
262
263
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void segment_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K) {
  // Broadcast COO reduction. The block's x-dimension spans columns and its
  // y-dimension spans groups of `TB` index entries. Each thread owns one
  // column and up to `TB` consecutive entries, so reads of `src` and atomic
  // writes to `out` are coalesced across threadIdx.x. Consecutive entries
  // sharing an index are summed in a register, and the partial sum is
  // flushed with `atomAdd` whenever the index changes and once at the end.
  //
  // NOTE(review): `REDUCE` and `arg_out_data` are unused (always a sum).
  // The TB entries are addressed as `start + step * last-dim stride`, which
  // presumes they stay inside one last-dim slice of `index` (or that `index`
  // is contiguous); confirm for batched indices.
  const int first_row = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
  const int col = blockIdx.y * blockDim.x + threadIdx.x;

  // Threads outside the (E x K) problem have nothing to do. No warp-level
  // primitives follow, so leaving early is safe.
  if (first_row >= E || col >= K)
    return;

  const int64_t *index_ptr =
      index_info.data +
      at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(first_row,
                                                              index_info);
  const int index_stride = index_info.strides[index_info.dims - 1];

  int cur_idx = __ldg(index_ptr);
  scalar_t acc = src_data[K * first_row + col];

#pragma unroll
  for (int step = 1; step < TB; step++) {
    const int row = first_row + step;
    if (row >= E)
      break;

    const int row_target = __ldg(index_ptr + step * index_stride);
    assert(cur_idx <= row_target);
    const size_t src_pos = K * row + col;

    if (row_target != cur_idx) {
      // A new segment starts: flush the finished partial sum first.
      atomAdd(out_data + K * cur_idx + col, acc);
      acc = src_data[src_pos];
    } else {
      acc += src_data[src_pos];
    }
    cur_idx = row_target;
  }

  // Flush the partial sum of the last segment this thread touched.
  atomAdd(out_data + K * cur_idx + col, acc);
}

rusty1s's avatar
rusty1s committed
303
304
305
// Segment reduction of `src` along dimension `index.dim() - 1`. Every src
// position carries an explicit (COO-style) output position in `index`.
// Partial results are accumulated into `out` with atomics, so `out` must
// already hold the starting values (presumably zeros, provided by the caller;
// confirm).
//
// Args:
//   src:    values; made contiguous. Its leading `index.dim()` dims must
//           match `index`.
//   index:  int64 output positions; expected sorted (the kernels assert it).
//   out:    destination; made contiguous. Must match `src` in every dim
//           except the reduced one.
//   reduce: "add", "mean", "min" or "max"; anything else is rejected.
// Returns (out, arg_out). `arg_out` is only allocated for "min"/"max" and is
// filled with `src.size(reduce_dim)`.
//
// NOTE(review): the kernels launched here accumulate with `atomAdd`
// regardless of `reduce`, and every K > 1 path uses the ADD kernel. Confirm
// whether "mean"/"min"/"max" are finished elsewhere.
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                 std::string reduce) {
  // Unknown names would otherwise silently fall through to the ADD path.
  AT_ASSERTM(reduce == "add" || reduce == "mean" || reduce == "min" ||
                 reduce == "max",
             "reduce must be one of \"add\", \"mean\", \"min\" or \"max\"");
  AT_ASSERTM(index.dim() >= 1, "index must have at least one dimension");
  AT_ASSERTM(src.dim() >= index.dim());
  for (int i = 0; i < index.dim(); i++)
    AT_ASSERTM(src.size(i) == index.size(i));

  src = src.contiguous();
  out = out.contiguous();
  auto reduce_dim = index.dim() - 1;

  for (int i = 0; i < out.dim(); i++)
    if (i != reduce_dim)
      AT_ASSERTM(src.size(i) == out.size(i));

  at::optional<at::Tensor> arg_out = at::nullopt;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
  }

  // Nothing to scatter. Bail out before `K` divides by zero and before
  // launching an (invalid) empty grid.
  auto E = index.numel();
  if (E == 0 || src.numel() == 0)
    return std::make_tuple(out, arg_out);

  auto K = src.numel() / E;
  auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    // Select the right kernel based on average row length (purely heuristic)
    // and whether we need broadcasting capabilities (K > 1):
    if (K == 1 && reduce == "add") {
      segment_coo_kernel<scalar_t, ADD><<<BLOCKS(1, E), THREADS, 0, stream>>>(
          src_data, index_info, out_data, nullptr, E);
    } else if (K == 1 && reduce == "mean") {
      segment_coo_kernel<scalar_t, MEAN><<<BLOCKS(1, E), THREADS, 0, stream>>>(
          src_data, index_info, out_data, nullptr, E);
    } else if (K == 1 && reduce == "min") {
      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
      segment_coo_kernel<scalar_t, MIN><<<BLOCKS(1, E), THREADS, 0, stream>>>(
          src_data, index_info, out_data, arg_out_data, E);
    } else if (K == 1 && reduce == "max") {
      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
      segment_coo_kernel<scalar_t, MAX><<<BLOCKS(1, E), THREADS, 0, stream>>>(
          src_data, index_info, out_data, arg_out_data, E);
    } else if (avg_len <= 8) {
      segment_coo_broadcast_kernel<scalar_t, ADD, 4>
          <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8), 0,
             stream>>>(src_data, index_info, out_data, nullptr, E, K);
    } else if (avg_len <= 16) {
      segment_coo_broadcast_kernel<scalar_t, ADD, 8>
          <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8), 0,
             stream>>>(src_data, index_info, out_data, nullptr, E, K);
    } else if (avg_len <= 32) {
      segment_coo_broadcast_kernel<scalar_t, ADD, 16>
          <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32), dim3(32, 8),
             0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
    } else {
      segment_coo_broadcast_kernel<scalar_t, ADD, 32>
          <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32), dim3(32, 8),
             0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
    }
  });

  // Kernel launches do not report configuration errors directly.
  cudaError_t launch_err = cudaGetLastError();
  AT_ASSERTM(launch_err == cudaSuccess, cudaGetErrorString(launch_err));

  return std::make_tuple(out, arg_out);
}