"vscode:/vscode.git/clone" did not exist on "34242f1c48676e44ba947259e239f529137b5a09"
segment_coo_cuda.cu 13 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include "segment_coo_cuda.h"

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <type_traits>

#include "reducer.cuh"
#include "utils.cuh"

#define THREADS 256
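// BLOCKS(TB, N): number of thread blocks needed to launch TB threads for each
// of the N elements (rounded up to a multiple of THREADS).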
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff

template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void
segment_coo_kernel(const scalar_t *src_data,
                   const at::cuda::detail::TensorInfo<int64_t, int> index_info,
                   scalar_t *out_data, size_t E, size_t N) {

  // Each thread processes exactly one entry. Within a warp, we perform a
  // parallel reduction across equal indices, and write the intermediate
  // result via atomics.
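  // Example (sum, D == E): for a sorted index [0, 0, 1, 1, 1, ...] the
  // shuffle loop below leaves each lane with the running reduction of its
  // segment, e.g. [x0, x0+x1, x2, x2+x3, x2+x3+x4, ...], and only the last
  // lane of every segment issues the atomic write into `out_data`.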

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = row_idx & (32 - 1);
  int D = index_info.sizes[index_info.dims - 1];

  using cuda_scalar_t =
      typename std::conditional<std::is_same<scalar_t, at::Half>::value, __half,
                                scalar_t>::type;

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int64_t idx = index_info.data[offset], next_idx;
    int out_idx = (row_idx / D) * N + idx;

    scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;

#pragma unroll
    for (int i = 1; i < 32; i *= 2) {
      // Parallel reduction inside a single warp.
      tmp = __shfl_up_sync(FULL_MASK, (cuda_scalar_t)val, i);
      next_idx = __shfl_up_sync(FULL_MASK, idx, i);
      if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
        assert(idx >= next_idx);
        if (idx == next_idx)
          Reducer<scalar_t, REDUCE>::update(&val, tmp);
      }
    }

    next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
    if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D ||
        idx != next_idx)
      Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
  }
}

template <typename scalar_t>
__global__ void segment_coo_arg_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) {
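
  // Recover the arg output for the min/max reductions: a thread writes its
  // position (row_idx % D) whenever its source value matches the already
  // reduced result stored in `out_data`.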

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int D = index_info.sizes[index_info.dims - 1];

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int64_t idx = index_info.data[offset];
    int out_idx = (row_idx / D) * N + idx;

    scalar_t val = __ldg(out_data + out_idx);
    if (src_data[row_idx] == val)
      arg_out_data[out_idx] = row_idx % D;
  }
}

template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void segment_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, size_t E, size_t K, size_t N) {

  // Each thread processes a single column and `TB` index entries. Coalesced
  // read and write is performed in column-major order. The intermediate
  // results are written via atomics.
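  // Example (TB = 4, D = 10): D is padded to E_2 = 12, so each row is covered
  // by three threads starting at row_start 0, 4 and 8; entries past D - 1 are
  // skipped by the early `break` in the loop below.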

  int D = index_info.sizes[index_info.dims - 1];
  int E_1 = E / D;
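  // D rounded up to the next multiple of TB, so that every row is covered by
  // a whole number of TB-sized chunks.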
  int E_2 = (D - 1) + TB - ((D - 1) % TB);

  int row_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int col_idx = blockIdx.y * blockDim.x + threadIdx.x;

  int dim_start = (row_idx * TB) / E_2;
  int row_start = (row_idx * TB) % E_2;

  if (dim_start < E_1 && col_idx < K) {

    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        dim_start * D + row_start, index_info);
    int idx1 = __ldg(index_info.data + offset), idx2;

    scalar_t val = src_data[K * (dim_start * D + row_start) + col_idx];

#pragma unroll
    for (int i = 1; i < TB; i++) {
      if (row_start + i >= D)
        break;

      idx2 = __ldg(index_info.data + offset +
                   i * index_info.strides[index_info.dims - 1]);
      assert(idx1 <= idx2);
      if (idx1 == idx2) {
        Reducer<scalar_t, REDUCE>::update(
            &val, src_data[K * (dim_start * D + row_start + i) + col_idx]);
      } else {
        Reducer<scalar_t, REDUCE>::atomic_write(
            out_data + (dim_start * N + idx1) * K + col_idx, val);
        val = src_data[K * (dim_start * D + row_start + i) + col_idx];
      }

      idx1 = idx2;
    }

    Reducer<scalar_t, REDUCE>::atomic_write(
        out_data + (dim_start * N + idx1) * K + col_idx, val);
  }
}

template <typename scalar_t>
__global__ void segment_coo_arg_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) {
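
  // Broadcast variant of the arg kernel: one thread per (entry, column) pair
  // claims the output slot if its source value matches the reduced result.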

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int col_idx = thread_idx % K;
  int D = index_info.sizes[index_info.dims - 1];

  if (row_idx < E && col_idx < K) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int idx = __ldg(index_info.data + offset);
    int out_idx = ((row_idx / D) * N + idx) * K + col_idx;

    scalar_t val = __ldg(out_data + out_idx);
    if (src_data[thread_idx] == val)
      arg_out_data[out_idx] = row_idx % D;
  }
}
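
// Reduces `src` into `out` along dimension `index.dim() - 1`, grouping entries
// by the (sorted) values of `index`. Returns the reduced tensor together with
// an optional companion tensor: argmin/argmax positions for "min"/"max", the
// per-bucket counts (clamped to at least 1) for "mean", and nullopt otherwise.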

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index,
                 torch::optional<torch::Tensor> optional_out,
                 torch::optional<int64_t> dim_size, std::string reduce) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  if (optional_out.has_value())
    CHECK_CUDA(optional_out.value());
  cudaSetDevice(src.get_device());

  CHECK_INPUT(src.dim() >= index.dim());

  auto sizes = index.sizes().vec();
  for (int i = 0; i < index.dim(); i++) {
    sizes[i] = src.size(i);
  }
  index = index.expand(sizes);

  auto dim = index.dim() - 1;

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
  } else {
    sizes = src.sizes().vec();
    if (dim_size.has_value())
      sizes[dim] = dim_size.value();
    else if (index.numel() == 0)
      sizes[dim] = 0;
    else {
      auto tmp = index.select(dim, index.size(dim) - 1);
      tmp = tmp.numel() > 1 ? tmp.max() : tmp;
      sizes[dim] = 1 + tmp.cpu().data_ptr<int64_t>()[0];
    }
    out = torch::zeros(sizes, src.options());
  }

  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = torch::full_like(out, src.size(dim), index.options());
    arg_out_data = arg_out.value().data_ptr<int64_t>();
  } else if (reduce2REDUCE.at(reduce) == MEAN) {
    auto sizes = index.sizes().vec();
    sizes[dim] = out.size(dim);
    arg_out = torch::zeros(sizes, out.options());
  }

  if (index.numel() == 0)
    return std::make_tuple(out, arg_out);

  auto E = index.numel();
  auto E_2 = index.size(dim);
  auto E_1 = index.numel() / E_2;
  auto K = src.numel() / E;
  auto N = out.size(dim);
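  // Average number of index entries per output bucket; used below to choose
  // the per-thread tile size (TB) of the broadcast kernel.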
  auto avg_len = (float)E_2 / (float)N;

  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (!optional_out.has_value())
        out.fill_(Reducer<scalar_t, REDUCE>::init());

      if (K == 1)
        segment_coo_kernel<scalar_t, REDUCE, true>
            <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
                                                   out_data, E, N);
      else if (avg_len <= 8)
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
            <<<dim3((E_1 * ((E_2 + 3) / 4) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);
      else if (avg_len <= 16)
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
            <<<dim3((E_1 * ((E_2 + 7) / 8) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);
      else if (avg_len <= 32)
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
            <<<dim3((E_1 * ((E_2 + 15) / 16) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);
      else
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
            <<<dim3((E_1 * ((E_2 + 31) / 32) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);

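      // Buckets that received no contribution still hold the reduction's
      // initial sentinel value (e.g. +inf for "min"); reset them to zero.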
      if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
        out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);

      if (REDUCE == MIN || REDUCE == MAX) {
        if (K == 1)
          segment_coo_arg_kernel<scalar_t>
              <<<BLOCKS(1, E), THREADS, 0, stream>>>(
                  src_data, index_info, out_data, arg_out_data, E, N);
        else
          segment_coo_arg_broadcast_kernel<scalar_t>
              <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(
                  src_data, index_info, out_data, arg_out_data, E, K, N);
      }

      if (REDUCE == MEAN) {
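        // Count the number of entries per bucket by running the sum kernel
        // with HAS_VAL == false (every entry contributes 1), then divide
        // `out` by the (clamped) counts.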
        auto count_data = arg_out.value().data_ptr<scalar_t>();
        segment_coo_kernel<scalar_t, SUM, false>
            <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
                                                   count_data, E, N);
        arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1,
                                     (scalar_t)1);
        auto count = arg_out.value();
        for (int i = dim + 1; i < out.dim(); i++)
          count = count.unsqueeze(-1);
        if (out.is_floating_point())
          out.true_divide_(count);
        else
          out.floor_divide_(count);
      }
    });
  });

  return std::make_tuple(out, arg_out);
}
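
// A minimal host-side sketch of calling the entry point above (illustrative
// only; the shapes, values and the "sum" reduction are assumptions):
//
//   auto src   = torch::randn({6, 8}, torch::device(torch::kCUDA));
//   auto index = torch::tensor({0, 0, 1, 1, 2, 2},
//                              torch::device(torch::kCUDA).dtype(torch::kLong));
//   auto result = segment_coo_cuda(src, index, torch::nullopt, 3, "sum");
//   // std::get<0>(result) has shape [3, 8]; row i sums the rows of `src`
//   // whose index entry equals i.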

template <typename scalar_t>
__global__ void
gather_coo_kernel(const scalar_t *src_data,
                  const at::cuda::detail::TensorInfo<int64_t, int> index_info,
                  scalar_t *out_data, size_t E, size_t N) {
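
  // Inverse of segment_coo: each thread reads the source element referenced
  // by its index entry and writes it to one output position.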

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int row = index_info.data[offset];

    offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N;
    scalar_t val = __ldg(src_data + offset + row);

    out_data[row_idx] = val;
  }
}

template <typename scalar_t>
__global__ void gather_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, size_t E, size_t K, size_t N) {
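
  // Broadcast variant: one thread per (index entry, column) pair copies the
  // referenced source element into the output.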

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int col_idx = thread_idx % K;

  if (thread_idx < E * K) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int row = index_info.data[offset];

    offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N * K;
    scalar_t val = __ldg(src_data + offset + K * row + col_idx);

    out_data[thread_idx] = val;
  }
}
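
// Gathers elements of `src` along dimension `index.dim() - 1` according to
// `index`, i.e. out[..., e, ...] = src[..., index[..., e], ...].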

torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
                              torch::optional<torch::Tensor> optional_out) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  if (optional_out.has_value())
    CHECK_CUDA(optional_out.value());
  cudaSetDevice(src.get_device());

  CHECK_INPUT(src.dim() >= index.dim());

  auto sizes = index.sizes().vec();
  for (auto i = 0; i < index.dim() - 1; i++)
    sizes[i] = src.size(i);
  index = index.expand(sizes);

  auto dim = index.dim() - 1;

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (auto i = 0; i < src.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
    CHECK_INPUT(index.size(dim) == out.size(dim));
  } else {
    auto sizes = src.sizes().vec();
    sizes[dim] = index.size(dim);
    out = torch::empty(sizes, src.options());
  }

  if (index.numel() == 0)
    return out;

  auto E = index.numel();
  auto K = out.numel() / E;
  auto N = src.size(dim);

  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    if (K == 1)
      gather_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
          src_data, index_info, out_data, E, N);
    else
      gather_coo_broadcast_kernel<scalar_t>
          <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(src_data, index_info,
                                                     out_data, E, K, N);
  });

  return out;
}