"tests/vscode:/vscode.git/clone" did not exist on "994360f7a57e3f10e961cc8ab3bbcad2fa37634b"
segment_kernel.cu 17.7 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
rusty1s's avatar
rusty1s committed
3
4
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
rusty1s's avatar
rusty1s committed
5

#include "atomics.cuh"
#include "compat.cuh"
#include "indptr.cuh"

#define THREADS 256
#define BLOCKS(TB, N) (((TB) * (N) + THREADS - 1) / THREADS)
#define FULL_MASK 0xffffffff

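// Supported reduction modes; the accepted `reduce` strings are mapped to
// these compile-time constants below ("add" is an alias for "sum").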
enum ReductionType { SUM, MEAN, MIN, MAX };

const std::map<std::string, ReductionType> reduce2REDUCE = {
    {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
};

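// Dispatches the runtime `reduce` string to a compile-time `REDUCE` constant,
// analogous to AT_DISPATCH_ALL_TYPES, so that every reduction type gets its
// own kernel instantiation.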
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    switch (reduce2REDUCE.at(reduce)) {                                        \
    case SUM: {                                                                \
      const ReductionType REDUCE = SUM;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MEAN: {                                                               \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MIN: {                                                                \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case MAX: {                                                                \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    }                                                                          \
  }()

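// Compile-time specialized reduction helper: provides the identity value, the
// (arg-)update rule, and the plain/atomic write-back for each reduction type.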
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  static inline __host__ __device__ scalar_t init() {
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      return std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

  static inline __host__ __device__ void update(scalar_t *val,
                                                scalar_t new_val) {
    if (REDUCE == SUM || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
    }
  }

  static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
                                                int64_t *arg, int64_t new_arg) {
    if (REDUCE == SUM || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

  static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
                                               int64_t *arg_address,
                                               int64_t arg, int count) {
    if (REDUCE == SUM) {
      *address = val;
    } else if (REDUCE == MEAN) {
      *address = val / (scalar_t)max(count, 1);
    } else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
    }
  }

  static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
    if (REDUCE == SUM || REDUCE == MEAN) {
      atomAdd(address, val);
    } else if (REDUCE == MIN && val < *address) {
      atomMin(address, val);
    } else if (REDUCE == MAX && val > *address) {
      atomMax(address, val);
    }
  }
};

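// CSR reduction without trailing feature dimensions: `N` is the total number
// of output rows, `E` the length of the reduced `src` dimension, and `TB` the
// number of threads that cooperate on a single row.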
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void
segment_csr_kernel(const scalar_t *src_data,
                   const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
                   scalar_t *out_data, int64_t *arg_out_data, size_t N,
                   size_t E) {

  // Each warp processes exactly `32/TB` rows and aggregates all row values
  // via a parallel reduction.

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);

  if (row_idx < N) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    int64_t row_start = __ldg(indptr_info.data + offset);
    int64_t row_end = __ldg(indptr_info.data + offset +
                            indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = Reducer<scalar_t, REDUCE>::init();
    int64_t arg, arg_tmp;

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
    for (int64_t src_idx = row_start + lane_idx; src_idx < row_end;
         src_idx += TB) {
      Reducer<scalar_t, REDUCE>::update(&val, src_data[offset + src_idx], &arg,
                                        src_idx);
    }

#pragma unroll
    for (int i = TB / 2; i > 0; i /= 2) {
      // Parallel reduction inside a single warp.
      if (REDUCE == MIN || REDUCE == MAX)
        arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
      Reducer<scalar_t, REDUCE>::update(
          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
    }

    if (lane_idx == 0) {
      Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
                                       arg_out_data + row_idx, arg,
                                       row_end - row_start);
    }
  }
}

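// CSR reduction with `K` trailing feature columns: one thread per output
// element (row, column) that loops sequentially over its row's CSR slice.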
template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_csr_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {

  // Each thread processes exactly one row. It turned out that this is more
  // efficient than using shared memory due to avoiding synchronization
  // barriers.

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int lane_idx = thread_idx % K;

  if (thread_idx < N * K) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    int64_t row_start = __ldg(indptr_info.data + offset);
    int64_t row_end = __ldg(indptr_info.data + offset +
                            indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = Reducer<scalar_t, REDUCE>::init();
    int64_t arg;

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
    for (int64_t src_idx = row_start; src_idx < row_end; src_idx++) {
      Reducer<scalar_t, REDUCE>::update(
          &val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx);
    }

    Reducer<scalar_t, REDUCE>::write(out_data + thread_idx, val,
                                     arg_out_data + thread_idx, arg,
                                     row_end - row_start);
  }
}

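// Host entry point for segment_csr: broadcasts `indptr` over the leading
// dimensions of `src`, allocates or validates `out`, prepares `arg_out` for
// "min"/"max", and dispatches on scalar type, reduction type and feature
// count `K`.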
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr_cuda(at::Tensor src, at::Tensor indptr,
                 at::optional<at::Tensor> out_opt, std::string reduce) {

  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");

  // Broadcasting `indptr` via `expand`.
  auto sizes = indptr.sizes().vec();
  for (int i = 0; i < indptr.dim() - 1; i++) {
    sizes[i] = src.size(i);
  }
  indptr = indptr.expand(sizes);

  src = src.contiguous();
  auto reduce_dim = indptr.dim() - 1;

  at::Tensor out;
  if (out_opt.has_value()) {
    out = out_opt.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != reduce_dim)
        AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
    AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
               "Input mismatch");
  } else {
    sizes = src.sizes().vec();
    sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
    out = at::empty(sizes, src.options());
  }

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

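  // N: total number of output rows, K: number of trailing feature columns per
  // row, E: length of the reduced dimension in `src`.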
  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(reduce_dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

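    // K == 1: reduce each row with segment_csr_kernel; otherwise the
    // broadcast kernel assigns one thread to every (row, column) pair.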
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (K == 1) {
        segment_csr_kernel<scalar_t, REDUCE, 1>
            <<<BLOCKS(32, N), THREADS, 0, stream>>>(
                src_data, indptr_info, out_data, arg_out_data, N, E);
      } else {
        segment_csr_broadcast_kernel<scalar_t, REDUCE>
            <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
                src_data, indptr_info, out_data, arg_out_data, N, K, E);
      }
    });
  });

  return std::make_tuple(out, arg_out);
}

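// COO reduction without trailing feature dimensions. `index` is assumed to be
// sorted along its last dimension (checked with a device-side assert below).
// With HAS_VAL == false the kernel ignores `src_data` and simply counts the
// entries hitting each output slot, which is reused for "mean" normalization.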
template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void
segment_coo_kernel(const scalar_t *src_data,
                   const at::cuda::detail::TensorInfo<int64_t, int> index_info,
                   scalar_t *out_data, size_t E, size_t N) {

  // Each thread processes exactly one entry. Within a warp, we perform a
  // parallel reduction across equal indices, and write the intermediate
  // result via atomics.

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = row_idx & (32 - 1);
  int D = index_info.sizes[index_info.dims - 1];

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int64_t idx = index_info.data[offset], next_idx;
    int out_idx = (row_idx / D) * N + idx;

    scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;

#pragma unroll
    for (int i = 1; i < 32; i *= 2) {
      // Parallel reduction inside a single warp.
      tmp = __shfl_up_sync(FULL_MASK, val, i);
      next_idx = __shfl_up_sync(FULL_MASK, idx, i);
      if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
        assert(idx >= next_idx);
        if (idx == next_idx)
          Reducer<scalar_t, REDUCE>::update(&val, tmp);
      }
    }

    next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
    if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D ||
        idx != next_idx)
      Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
  }
}

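// Second pass for "min"/"max": every entry whose value equals the reduced
// output writes its position (modulo the index dimension) as argmin/argmax.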
template <typename scalar_t>
__global__ void segment_coo_arg_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) {

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int D = index_info.sizes[index_info.dims - 1];

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int64_t idx = index_info.data[offset];
    int out_idx = (row_idx / D) * N + idx;

    scalar_t val = __ldg(out_data + out_idx);
    if (src_data[row_idx] == val)
      arg_out_data[out_idx] = row_idx % D;
  }
}

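// COO reduction with `K` trailing feature columns: each thread covers one
// column and up to `TB` consecutive index entries, merging runs of equal
// indices locally before writing back via atomics.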
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void segment_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, size_t E, size_t K, size_t N) {

  // Each thread processes a single column and `TB` index entries. Coalesced
  // reads and writes are performed in column-major order. The intermediate
  // results are written via atomics.

  int D = index_info.sizes[index_info.dims - 1];
  int E_1 = E / D;
  int E_2 = D + TB - (D % TB);

  int row_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int col_idx = blockIdx.y * blockDim.x + threadIdx.x;

  int dim_start = (row_idx * TB) / E_2;
  int row_start = (row_idx * TB) % E_2;

  if (dim_start < E_1 && col_idx < K) {

    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        dim_start * D + row_start, index_info);
    int idx1 = __ldg(index_info.data + offset), idx2;

    scalar_t val = src_data[K * (dim_start * D + row_start) + col_idx];

#pragma unroll
    for (int i = 1; i < TB; i++) {
      if (row_start + i >= D)
        break;

      idx2 = __ldg(index_info.data + offset +
                   i * index_info.strides[index_info.dims - 1]);
      assert(idx1 <= idx2);
      if (idx1 == idx2) {
        Reducer<scalar_t, REDUCE>::update(
            &val, src_data[K * (dim_start * D + row_start + i) + col_idx]);
      } else {
        Reducer<scalar_t, REDUCE>::atomic_write(
            out_data + (dim_start * N + idx1) * K + col_idx, val);
        val = src_data[K * (dim_start * D + row_start + i) + col_idx];
      }

      idx1 = idx2;
    }

    Reducer<scalar_t, REDUCE>::atomic_write(
        out_data + (dim_start * N + idx1) * K + col_idx, val);
  }
}

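// Broadcast variant of the "min"/"max" argument pass: one thread per
// (entry, column) pair.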
template <typename scalar_t>
__global__ void segment_coo_arg_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int col_idx = thread_idx % K;
  int D = index_info.sizes[index_info.dims - 1];

  if (row_idx < E && col_idx < K) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int idx = __ldg(index_info.data + offset);
    int out_idx = ((row_idx / D) * N + idx) * K + col_idx;

    scalar_t val = __ldg(out_data + out_idx);
    if (src_data[thread_idx] == val)
      arg_out_data[out_idx] = row_idx % D;
  }
}

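// Host entry point for segment_coo: broadcasts `index` over the leading
// dimensions of `src`, validates `out`, and selects a kernel based on the
// feature count `K` and the average segment length. For "mean", a second
// pass counts the entries per output index and normalizes `out` in-place.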
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                 std::string reduce) {

  AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");

  // Broadcasting `index` via `expand`.
  auto sizes = index.sizes().vec();
  for (int i = 0; i < index.dim(); i++) {
    sizes[i] = src.size(i);
  }
  index = index.expand(sizes);

  src = src.contiguous();
  out = out.contiguous();
  auto reduce_dim = index.dim() - 1;

  for (int i = 0; i < out.dim(); i++)
    if (i != reduce_dim)
      AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  auto E = index.numel();
  auto E_2 = index.size(reduce_dim);
  auto E_1 = index.numel() / E_2;
  auto K = src.numel() / E;
  auto N = out.size(reduce_dim);
  auto avg_len = (float)E_2 / (float)N;

  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

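    // The average number of entries per output index (avg_len) picks the
    // per-thread tile size `TB` of the broadcast kernel.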
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (K == 1) {
        segment_coo_kernel<scalar_t, REDUCE, true>
            <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
                                                   out_data, E, N);
      } else if (avg_len <= 8) {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
            <<<dim3((E_1 * ((E_2 + 3) / 4) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);
      } else if (avg_len <= 16) {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
            <<<dim3((E_1 * ((E_2 + 7) / 8) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);
      } else if (avg_len <= 32) {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
            <<<dim3((E_1 * ((E_2 + 15) / 16) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);
      } else {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
            <<<dim3((E_1 * ((E_2 + 31) / 32) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);
      }

      if (REDUCE == MIN || REDUCE == MAX) {
        if (K == 1) {
          segment_coo_arg_kernel<scalar_t>
              <<<BLOCKS(1, E), THREADS, 0, stream>>>(
                  src_data, index_info, out_data, arg_out_data, E, N);
        } else {
          segment_coo_arg_broadcast_kernel<scalar_t>
              <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(
                  src_data, index_info, out_data, arg_out_data, E, K, N);
        }
      }
    });
  });

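  // For "mean": count the entries per output index via a SUM reduction of
  // ones (HAS_VAL == false), clamp to one to avoid division by zero, and
  // normalize `out` in-place. The counts are returned in place of `arg_out`.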
  if (reduce2REDUCE.at(reduce) == MEAN) {
    auto sizes = index.sizes().vec();
    sizes[reduce_dim] = out.size(reduce_dim);
    auto count = at::zeros(sizes, out.options());

    AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] {
      auto count_data = count.DATA_PTR<scalar_t>();
      segment_coo_kernel<scalar_t, SUM, false>
          <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
                                                 count_data, E, N);
    });

    count.clamp_(1);
    arg_out = count;

    for (int i = reduce_dim + 1; i < out.dim(); i++) {
      count = count.unsqueeze(-1);
    }

    out.div_(count);
  }

  return std::make_tuple(out, arg_out);
}