#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include "atomics.cuh"
#include "compat.cuh"
#include "indptr.cuh"

#define THREADS 256
// Number of thread blocks needed to cover `N` work items when each item is
// handled by `TB` cooperating threads.
#define BLOCKS(TB, N) (((TB) * (N) + THREADS - 1) / THREADS)
#define FULL_MASK 0xffffffff

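// Supported reduction modes. AT_DISPATCH_REDUCTION_TYPES maps the textual
// `reduce` argument onto the compile-time constant `REDUCE` and invokes the
// given callback with it, e.g.:
//
//   AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { kernel<scalar_t, REDUCE>(...); });
//
// An unrecognized string silently skips the callback.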
enum ReductionType { ADD, MEAN, MIN, MAX };
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    if (reduce == "add") {                                                     \
      const ReductionType REDUCE = ADD;                                        \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "mean") {                                             \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "min") {                                              \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "max") {                                              \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
  }()

template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  static inline __host__ __device__ scalar_t init() {
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      // `lowest()`, not `min()`: for floating point types `min()` is the
      // smallest positive value, which would break MAX over negative inputs.
      return std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

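  // Folds `new_val` into the running value `*val`; for MIN/MAX the position
  // `new_arg` of the current extremum is tracked in `*arg` as well.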
  static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
                                                int64_t *arg, int64_t new_arg) {
    if (REDUCE == ADD || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

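  // Writes the final result (non-atomically). MEAN divides by the segment
  // size, and empty MIN/MAX segments are written as 0 while their argument
  // entry is left untouched.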
  static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
                                               int64_t *arg_address,
                                               int64_t arg, int count) {
    if (REDUCE == ADD) {
      *address = val;
    } else if (REDUCE == MEAN) {
      *address = val / (scalar_t)max(count, 1);
    } else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
    }
  }

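  // Atomic variant used by the COO kernels, where multiple warps may write to
  // the same output entry concurrently. The relational pre-checks for MIN/MAX
  // are only a cheap early-out; the atomics remain authoritative.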
  static inline __device__ void atomic_write(scalar_t *address, scalar_t val,
                                             int64_t *arg_address,
                                             int64_t arg) {
    if (REDUCE == ADD) {
      atomAdd(address, val);
    } else if (REDUCE == MEAN) {
      atomAdd(address, val);
    } else if (REDUCE == MIN && val < *address) {
      atomMin(address, val);
    } else if (REDUCE == MAX && val > *address) {
      atomMax(address, val);
    }

    if (REDUCE == MIN || REDUCE == MAX) {
      assert(false); // TODO: arg write-back for MIN/MAX is not yet race-free.
      __syncthreads();
      if (*address == val) {
        *arg_address = arg;
      }
    }
  }
};

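// CSR reduction kernel for an inner dimension of size K == 1: `TB` threads
// (a power of two, at most the warp size) cooperate on each row delimited by
// consecutive `indptr` entries.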
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void
segment_csr_kernel(const scalar_t *src_data,
                   const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
                   scalar_t *out_data, int64_t *arg_out_data, size_t N,
                   size_t E) {

  // Each warp processes exactly `32/TB` rows and aggregates all row values
  // via a parallel reduction.

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);

  if (row_idx < N) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = Reducer<scalar_t, REDUCE>::init(), tmp;
    int64_t arg, arg_tmp; // Only meaningful for MIN/MAX (argmin/argmax).

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
    for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
      Reducer<scalar_t, REDUCE>::update(&val, src_data[offset + src_idx], &arg,
                                        src_idx);
    }

#pragma unroll
    for (int i = TB / 2; i > 0; i /= 2) {
      // Parallel reduction inside a single warp.
      if (REDUCE == MIN || REDUCE == MAX) {
        tmp = __shfl_down_sync(FULL_MASK, val, i);
        arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
        if (row_start + lane_idx + i < row_end)
          Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
      } else {
        Reducer<scalar_t, REDUCE>::update(
            &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
      }
    }

    if (lane_idx == 0) {
      Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
                                       arg_out_data + row_idx, arg,
                                       row_end - row_start);
    }
  }
}

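// CSR reduction kernel for an inner dimension of size K > 1: the reduction is
// broadcast over the K trailing feature entries, one thread per output value.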
template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_csr_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {

  // Each thread processes exactly one output entry (one row for a single
  // feature column). This turned out to be more efficient than using shared
  // memory because it avoids synchronization barriers.

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int lane_idx = thread_idx % K;

  if (thread_idx < N * K) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = Reducer<scalar_t, REDUCE>::init();
    int64_t arg;

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
    for (int src_idx = row_start; src_idx < row_end; src_idx++) {
      Reducer<scalar_t, REDUCE>::update(
          &val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx);
    }

    Reducer<scalar_t, REDUCE>::write(out_data + thread_idx, val,
                                     arg_out_data + thread_idx, arg,
                                     row_end - row_start);
  }
}

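// Host entry point for CSR-style segment reduction over the last dimension of
// `indptr`. Illustrative example: src = [1, 2, 3, 4] with indptr = [0, 2, 4]
// and reduce = "add" yields out = [3, 7], i.e. one reduced value per pair of
// consecutive pointers. For "min"/"max", the indices of the selected elements
// along the reduced dimension are returned as the optional second output.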
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr_cuda(at::Tensor src, at::Tensor indptr,
                 at::optional<at::Tensor> out_opt, std::string reduce) {

  AT_ASSERTM(src.dim() >= indptr.dim(),
             "Input mismatch: src.dim() must be at least indptr.dim()");
  for (int i = 0; i < indptr.dim() - 1; i++)
    AT_ASSERTM(src.size(i) == indptr.size(i),
               "Input mismatch in the leading dimensions");

  src = src.contiguous();
  auto reduce_dim = indptr.dim() - 1;

  at::Tensor out;
  if (out_opt.has_value()) {
    out = out_opt.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != reduce_dim)
        AT_ASSERTM(src.size(i) == out.size(i), "Output mismatch");
    AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
               "Output mismatch in the reduced dimension");
  } else {
    auto sizes = src.sizes().vec();
    sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
    out = at::empty(sizes, src.options());
  }

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

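  // Fast path for reduce == "any": gather the first element of each segment
  // (the one at position indptr[i]) and zero out empty segments.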
  if (reduce == "any") {
    auto index = indptr.narrow(reduce_dim, 0, indptr.size(reduce_dim) - 1);
    auto index2 = indptr.narrow(reduce_dim, 1, indptr.size(reduce_dim) - 1);
    auto mask = (index2 - index) == 0;

    for (int i = reduce_dim + 1; i < src.dim(); i++) {
      index = index.unsqueeze(-1);
      mask = mask.unsqueeze(-1);
    }

    at::gather_out(out, src, reduce_dim, index.expand(out.sizes()));
    out.masked_fill_(mask.expand(out.sizes()), 0);

    return std::make_tuple(out, arg_out);
  }

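  // N: total number of segments across all leading (batch) dimensions,
  // K: number of trailing (broadcast) entries per output value,
  // E: size of the reduced dimension in `src`.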
  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(reduce_dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (K == 1) {
        segment_csr_kernel<scalar_t, REDUCE, 1>
            <<<BLOCKS(32, N), THREADS, 0, stream>>>(
                src_data, indptr_info, out_data, arg_out_data, N, E);
      } else {
        segment_csr_broadcast_kernel<scalar_t, REDUCE>
            <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
                src_data, indptr_info, out_data, arg_out_data, N, K, E);
      }
    });
  });

  return std::make_tuple(out, arg_out);
}

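// COO reduction kernel for an inner dimension of size K == 1. `index` is
// expected to be sorted along its last dimension (see the assert below). With
// HAS_VAL == false, the kernel counts how many entries map to each output.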
template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void
segment_coo_kernel(const scalar_t *src_data,
                   const at::cuda::detail::TensorInfo<int64_t, int> index_info,
                   scalar_t *out_data, int64_t *arg_out_data, size_t E) {

  // Each thread processes exactly one entry. Within a warp, we perform a
  // parallel reduction across equal indices, and write the intermediate
  // result via atomics.

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = row_idx & (32 - 1);

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int idx = index_info.data[offset], next_idx;

    // With HAS_VAL == false the kernel merely counts entries per index
    // (`src_data` may then be a nullptr).
    scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;
    int64_t arg, arg_tmp; // Only meaningful for MIN/MAX (argmin/argmax).

    if (REDUCE == MIN || REDUCE == MAX) {
      arg = row_idx % index_info.sizes[index_info.dims - 1];
    }

#pragma unroll
    for (int i = 1; i < 32; i *= 2) {
      // Parallel reduction inside a single warp.
      tmp = __shfl_up_sync(FULL_MASK, val, i);
      if (REDUCE == MIN || REDUCE == MAX) {
        arg_tmp = __shfl_up_sync(FULL_MASK, arg, i);
      }
      next_idx = __shfl_up_sync(FULL_MASK, idx, i);
      assert(idx >= next_idx); // `index` must be sorted along the last dim.
      if (lane_idx >= i && idx == next_idx)
        Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
    }

    next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
    if (lane_idx == 32 - 1 || idx != next_idx) {
      Reducer<scalar_t, REDUCE>::atomic_write(out_data + idx, val,
                                              arg_out_data + idx, arg);
    }
  }
}

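// COO reduction kernel for an inner dimension of size K > 1. Each thread
// accumulates up to `TB` consecutive (sorted) index entries for one feature
// column before flushing via atomics. Note that this variant accumulates by
// addition only, independent of the REDUCE template parameter.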
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void segment_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K) {

  // Each thread processes a single column and `TB` index entries. Coalesced
  // reads and writes are performed in column-major order. The intermediate
  // results are written via atomics.

  int row_start = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
  int col_idx = blockIdx.y * blockDim.x + threadIdx.x;

  if (row_start < E && col_idx < K) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_start, index_info);

    int idx1 = __ldg(index_info.data + offset);
    scalar_t val = src_data[K * row_start + col_idx];

#pragma unroll
    for (int i = 1; i < TB; i++) {
      if (row_start + i >= E)
        break;

      int idx2 = __ldg(index_info.data + offset +
                       i * index_info.strides[index_info.dims - 1]);
      assert(idx1 <= idx2); // `index` must be sorted along the last dim.
      if (idx1 == idx2) {
        val += src_data[K * (row_start + i) + col_idx];
      } else {
        atomAdd(out_data + K * idx1 + col_idx, val);
        val = src_data[K * (row_start + i) + col_idx];
      }
      idx1 = idx2;
    }

    atomAdd(out_data + K * idx1 + col_idx, val);
  }
}

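// Host entry point for COO-style segment reduction over the last dimension of
// `index`. The kernels accumulate atomically into `out`, so the caller is
// expected to pass a suitably pre-initialized tensor (e.g. zeros for "add").
// Illustrative example: src = [1, 2, 3, 4] with index = [0, 0, 1, 1] and
// reduce = "add" adds [3, 7] into the first two entries of `out`.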
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                 std::string reduce) {
  AT_ASSERTM(src.dim() >= index.dim(),
             "Input mismatch: src.dim() must be at least index.dim()");
  for (int i = 0; i < index.dim(); i++)
    AT_ASSERTM(src.size(i) == index.size(i),
               "Input mismatch in the leading dimensions");

  src = src.contiguous();
  out = out.contiguous();
  auto reduce_dim = index.dim() - 1;

  for (int i = 0; i < out.dim(); i++)
    if (i != reduce_dim)
      AT_ASSERTM(src.size(i) == out.size(i), "Output mismatch");

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

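  // Fast path for reduce == "any": a plain scatter, so an arbitrary element
  // per output position wins.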
  if (reduce == "any") {
    for (int i = reduce_dim + 1; i < src.dim(); i++) {
      index = index.unsqueeze(-1);
    }
    out.scatter_(reduce_dim, index.expand(src.sizes()), src);
    return std::make_tuple(out, arg_out);
  }

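  // E: number of index entries, K: number of trailing (broadcast) entries per
  // index entry. `avg_len`, the average segment length, selects how many rows
  // each thread of the broadcast kernel processes.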
  auto E = index.numel();
  auto K = src.numel() / E;
  auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (K == 1) {
        segment_coo_kernel<scalar_t, REDUCE, true>
            <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
                                                   out_data, arg_out_data, E);
      } else if (avg_len <= 8) {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
            <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8),
               0, stream>>>(src_data, index_info, out_data, arg_out_data, E, K);
      } else if (avg_len <= 16) {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
            <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8),
               0, stream>>>(src_data, index_info, out_data, arg_out_data, E, K);
      } else if (avg_len <= 32) {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
            <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
                                         arg_out_data, E, K);
      } else {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
            <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
                                         arg_out_data, E, K);
      }
    });
  });

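  // For "mean", count the entries per output position by re-running the COO
  // kernel with HAS_VAL == false, then divide. The count tensor is returned
  // as the optional second output.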
  if (reduce == "mean") {
    auto sizes = index.sizes().vec();
    sizes[reduce_dim] = out.size(reduce_dim);
    auto count = at::zeros(sizes, out.options());

    AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] {
      auto count_data = count.DATA_PTR<scalar_t>();
      segment_coo_kernel<scalar_t, ADD, false>
          <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
                                                 count_data, nullptr, E);
    });

    count.clamp_(1);
    arg_out = count;
    // `count` follows the layout of `index`; unsqueeze its trailing dims so
    // the division broadcasts correctly when `out` has additional inner dims.
    for (int i = reduce_dim + 1; i < out.dim(); i++)
      count = count.unsqueeze(-1);
    out.div_(count);
  }

  return std::make_tuple(out, arg_out);
}