scatter_kernel.cu 9.09 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include "atomics.cuh"
#include "index.cuh"

// Threads per block used by every launch in this file.
#define THREADS 1024

// Number of THREADS-sized blocks needed to cover N work items (ceil-div).
// Fully parenthesised so it composes safely inside larger expressions.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Launches NAME<scalar_t, D> with a 1-D grid of BLOCKS(N) x THREADS threads.
// The element count N is evaluated once and appended as the kernel's last
// argument. D is 1, 2 or 3 when DIMS matches, otherwise -1 (the generic
// specialisation). Must be used where `scalar_t` is in scope, e.g. inside an
// AT_DISPATCH_* lambda.
//
// When N == 0 nothing is launched. A zero-sized grid is an invalid launch
// configuration, and the resulting error would otherwise surface in a later,
// unrelated CUDA call.
#define KERNEL_RUN(NAME, DIMS, N, ...)                                         \
  [&] {                                                                        \
    const auto kernel_run_numel = (N);                                         \
    if (kernel_run_numel == 0) {                                               \
      return;                                                                  \
    }                                                                          \
    switch (DIMS) {                                                            \
    case 1:                                                                    \
      NAME<scalar_t, 1><<<BLOCKS(kernel_run_numel), THREADS>>>(                \
          __VA_ARGS__, kernel_run_numel);                                      \
      break;                                                                   \
    case 2:                                                                    \
      NAME<scalar_t, 2><<<BLOCKS(kernel_run_numel), THREADS>>>(                \
          __VA_ARGS__, kernel_run_numel);                                      \
      break;                                                                   \
    case 3:                                                                    \
      NAME<scalar_t, 3><<<BLOCKS(kernel_run_numel), THREADS>>>(                \
          __VA_ARGS__, kernel_run_numel);                                      \
      break;                                                                   \
    default:                                                                   \
      NAME<scalar_t, -1><<<BLOCKS(kernel_run_numel), THREADS>>>(               \
          __VA_ARGS__, kernel_run_numel);                                      \
    }                                                                          \
  }()

// Multiplies `src` elements into the `out` slots selected by `index` along
// dimension `dim`.
//
// Launch: each thread visits the linear ids [0, numel) in a grid-stride loop
// (start at its global thread id, advance by the total thread count), so any
// grid size covers every id.
//
// src, index, out: ATen TensorInfo views. The per-id offsets come from
//   IndexToScatterOffsets3 (index.cuh). NOTE(review): presumably out's
//   coordinate along `dim` is read from index, the rest follows i — confirm.
// dim:   the scatter dimension.
// numel: number of linear ids to visit (the host wrapper passes
//        index.numel()).
//
// Several ids may target the same `out` slot. The update goes through atomMul
// (atomics.cuh), which is assumed to be an atomic read-modify-write.
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  // Widen before multiplying. blockIdx.x * blockDim.x is a 32-bit unsigned
  // product that wraps at 2^32 threads, and a wrapped stride of 0 would make
  // the loop below spin forever.
  const ptrdiff_t idx =
      static_cast<ptrdiff_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const ptrdiff_t stride = static_cast<ptrdiff_t>(blockDim.x) * gridDim.x;
  const ptrdiff_t n = static_cast<ptrdiff_t>(numel);
  for (ptrdiff_t i = idx; i < n; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomMul(&out.data[outOffset], src.data[srcOffset]);
  }
}

// Host entry point for scatter-multiply.
//
// Dispatches on src's dtype (AT_DISPATCH_ALL_TYPES) and launches
// scatter_mul_kernel through KERNEL_RUN. The launch covers index.numel() work
// items, and the kernel specialisation is chosen from index.dim().
void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      int64_t dim) {
  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_mul_kernel", [&] {
    auto src_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    auto index_view = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto out_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
    KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(), src_view,
               index_view, out_view, dim);
  });
}

// Divides the `out` slots selected by `index` along dimension `dim` by the
// corresponding `src` elements.
//
// Launch: each thread visits the linear ids [0, numel) in a grid-stride loop
// (start at its global thread id, advance by the total thread count), so any
// grid size covers every id.
//
// src, index, out: ATen TensorInfo views. The per-id offsets come from
//   IndexToScatterOffsets3 (index.cuh). NOTE(review): presumably out's
//   coordinate along `dim` is read from index, the rest follows i — confirm.
// dim:   the scatter dimension.
// numel: number of linear ids to visit (the host wrapper passes
//        index.numel()).
//
// Several ids may target the same `out` slot. The update goes through atomDiv
// (atomics.cuh), which is assumed to be an atomic read-modify-write.
// NOTE(review): for integer dtypes a zero in src is an integer division by zero
// on device — nothing here guards against it.
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  // Widen before multiplying. blockIdx.x * blockDim.x is a 32-bit unsigned
  // product that wraps at 2^32 threads, and a wrapped stride of 0 would make
  // the loop below spin forever.
  const ptrdiff_t idx =
      static_cast<ptrdiff_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const ptrdiff_t stride = static_cast<ptrdiff_t>(blockDim.x) * gridDim.x;
  const ptrdiff_t n = static_cast<ptrdiff_t>(numel);
  for (ptrdiff_t i = idx; i < n; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomDiv(&out.data[outOffset], src.data[srcOffset]);
  }
}

// Host entry point for scatter-divide.
//
// Dispatches on src's dtype (AT_DISPATCH_ALL_TYPES) and launches
// scatter_div_kernel through KERNEL_RUN. The launch covers index.numel() work
// items, and the kernel specialisation is chosen from index.dim().
void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      int64_t dim) {
  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_div_kernel", [&] {
    auto src_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    auto index_view = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto out_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
    KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(), src_view,
               index_view, out_view, dim);
  });
}

// Second pass of scatter_max / scatter_min, run after `out` holds the reduced
// values. Wherever a `src` element equals its reduced `out` value, this kernel
// records that element's position along `dim` in `arg`.
//
// Launch: each thread visits the linear ids [0, numel) in a grid-stride loop,
// so any grid size covers every id.
//
// src, index, out, arg: ATen TensorInfo views. Offsets come from
//   IndexToScatterOffsets4 (index.cuh). NOTE(review): presumably arg is
//   addressed like out (slot chosen via index) — confirm in index.cuh.
// dim:   the scatter dimension.
// numel: number of linear ids to visit (the host wrappers pass
//        index.numel()).
//
// Ties: every thread whose src value equals out writes arg with a plain,
// unsynchronised store. When several positions tie, which one ends up in arg is
// therefore unspecified. A slot that no src element matches keeps whatever arg
// held before.
template <typename scalar_t, int64_t Dims>
__global__ void arg_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                           at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                           at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                           at::cuda::detail::TensorInfo<int64_t, int64_t> arg,
                           int64_t dim, size_t numel) {
  // Widen before multiplying. blockIdx.x * blockDim.x is a 32-bit unsigned
  // product that wraps at 2^32 threads, and a wrapped stride of 0 would make
  // the loop below spin forever.
  const ptrdiff_t idx =
      static_cast<ptrdiff_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const ptrdiff_t stride = static_cast<ptrdiff_t>(blockDim.x) * gridDim.x;
  const ptrdiff_t n = static_cast<ptrdiff_t>(numel);
  for (ptrdiff_t i = idx; i < n; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0, argOffset = 0;
    IndexToScatterOffsets4<scalar_t, scalar_t, int64_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset, arg,
        &argOffset);
    if (src.data[srcOffset] == out.data[outOffset]) {
      // Recover srcOffset's coordinate along `dim` from the linear offset.
      // NOTE(review): exact for dense (possibly permuted) layouts. It is not
      // exact for non-dense strides, and it divides by zero when
      // src.strides[dim] == 0.
      arg.data[argOffset] = (srcOffset / src.strides[dim]) % src.sizes[dim];
    }
  }
}

// Folds `src` elements into the `out` slots selected by `index` along
// dimension `dim`, keeping the maximum.
//
// Launch: each thread visits the linear ids [0, numel) in a grid-stride loop,
// so any grid size covers every id.
//
// src, index, out: ATen TensorInfo views. The per-id offsets come from
//   IndexToScatterOffsets3 (index.cuh). NOTE(review): presumably out's
//   coordinate along `dim` is read from index, the rest follows i — confirm.
// dim:   the scatter dimension.
// numel: number of linear ids to visit (the host wrapper passes
//        index.numel()).
//
// `out` is updated in place and is not initialised here (presumably the
// caller pre-fills it — confirm). Updates go through atomMax (atomics.cuh),
// which is assumed to be an atomic read-modify-write.
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  // Widen before multiplying. blockIdx.x * blockDim.x is a 32-bit unsigned
  // product that wraps at 2^32 threads, and a wrapped stride of 0 would make
  // the loop below spin forever.
  const ptrdiff_t idx =
      static_cast<ptrdiff_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const ptrdiff_t stride = static_cast<ptrdiff_t>(blockDim.x) * gridDim.x;
  const ptrdiff_t n = static_cast<ptrdiff_t>(numel);
  for (ptrdiff_t i = idx; i < n; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomMax(&out.data[outOffset], src.data[srcOffset]);
  }
}

// Host entry point for scatter-max.
//
// For each dtype in AT_DISPATCH_ALL_TYPES it issues two launches, in order,
// on the same stream (no stream argument is passed):
//   1. scatter_max_kernel reduces src into out.
//   2. arg_kernel reads the reduced out and writes, into arg, the position
//      along `dim` of a src element equal to each maximum.
// Both launches cover index.numel() work items.
void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      at::Tensor arg, int64_t dim) {
  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_max_kernel", [&] {
    auto src_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    auto index_view = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto out_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);

    // Pass 1: reduce.
    KERNEL_RUN(scatter_max_kernel, index.dim(), index.numel(), src_view,
               index_view, out_view, dim);

    // Pass 2: locate the winners.
    auto arg_view = at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg);
    KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_view, index_view,
               out_view, arg_view, dim);
  });
}

// Folds `src` elements into the `out` slots selected by `index` along
// dimension `dim`, keeping the minimum.
//
// Launch: each thread visits the linear ids [0, numel) in a grid-stride loop,
// so any grid size covers every id.
//
// src, index, out: ATen TensorInfo views. The per-id offsets come from
//   IndexToScatterOffsets3 (index.cuh). NOTE(review): presumably out's
//   coordinate along `dim` is read from index, the rest follows i — confirm.
// dim:   the scatter dimension.
// numel: number of linear ids to visit (the host wrapper passes
//        index.numel()).
//
// `out` is updated in place and is not initialised here (presumably the
// caller pre-fills it — confirm). Updates go through atomMin (atomics.cuh),
// which is assumed to be an atomic read-modify-write.
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  // Widen before multiplying. blockIdx.x * blockDim.x is a 32-bit unsigned
  // product that wraps at 2^32 threads, and a wrapped stride of 0 would make
  // the loop below spin forever.
  const ptrdiff_t idx =
      static_cast<ptrdiff_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const ptrdiff_t stride = static_cast<ptrdiff_t>(blockDim.x) * gridDim.x;
  const ptrdiff_t n = static_cast<ptrdiff_t>(numel);
  for (ptrdiff_t i = idx; i < n; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomMin(&out.data[outOffset], src.data[srcOffset]);
  }
}

// Host entry point for scatter-min.
//
// For each dtype in AT_DISPATCH_ALL_TYPES it issues two launches, in order,
// on the same stream (no stream argument is passed):
//   1. scatter_min_kernel reduces src into out.
//   2. arg_kernel reads the reduced out and writes, into arg, the position
//      along `dim` of a src element equal to each minimum.
// Both launches cover index.numel() work items.
void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      at::Tensor arg, int64_t dim) {
  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_min_kernel", [&] {
    auto src_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    auto index_view = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto out_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);

    // Pass 1: reduce.
    KERNEL_RUN(scatter_min_kernel, index.dim(), index.numel(), src_view,
               index_view, out_view, dim);

    // Pass 2: locate the winners.
    auto arg_view = at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg);
    KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_view, index_view,
               out_view, arg_view, dim);
  });
}

// Gradient routing for scatter_max / scatter_min. For each linear id, copy
// grad into out only when arg names out's position along `dim` as the winner.
// Every other out element is left untouched.
//
// Launch: each thread visits the linear ids [0, numel) in a grid-stride loop,
// so any grid size covers every id.
//
// grad, index, arg, out: ATen TensorInfo views. Offsets come from
//   IndexToScatterOffsets4 (index.cuh), with `out` passed in the position
//   arg_kernel uses for `src`. NOTE(review): presumably arg/grad are
//   addressed via index like the reduced tensor — confirm in index.cuh.
// dim:   the scatter dimension.
// numel: number of linear ids to visit (the host wrapper passes
//        index.numel()).
template <typename scalar_t, int64_t Dims>
__global__ void
index_backward_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad,
                      at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                      at::cuda::detail::TensorInfo<int64_t, int64_t> arg,
                      at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                      int64_t dim, size_t numel) {
  // Widen before multiplying. blockIdx.x * blockDim.x is a 32-bit unsigned
  // product that wraps at 2^32 threads, and a wrapped stride of 0 would make
  // the loop below spin forever.
  const ptrdiff_t idx =
      static_cast<ptrdiff_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const ptrdiff_t stride = static_cast<ptrdiff_t>(blockDim.x) * gridDim.x;
  const ptrdiff_t n = static_cast<ptrdiff_t>(numel);
  for (ptrdiff_t i = idx; i < n; i += stride) {
    int64_t gradOffset = 0, indexOffset = 0, argOffset = 0, outOffset = 0;
    IndexToScatterOffsets4<scalar_t, int64_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, out, &outOffset, arg, &argOffset, grad,
        &gradOffset);
    // Compare the recorded winner with out's coordinate along `dim`.
    // NOTE(review): this recovery is exact for dense (possibly permuted)
    // layouts only, and divides by zero if out.strides[dim] == 0.
    if (arg.data[argOffset] ==
        (outOffset / out.strides[dim]) % out.sizes[dim]) {
      out.data[outOffset] = grad.data[gradOffset];
    }
  }
}

// Host entry point for the scatter_max / scatter_min backward pass.
//
// Dispatches on grad's dtype (AT_DISPATCH_ALL_TYPES) and launches
// index_backward_kernel through KERNEL_RUN. The launch covers index.numel()
// work items, and the kernel specialisation is chosen from index.dim().
void index_backward_cuda(at::Tensor grad, at::Tensor index, at::Tensor arg,
                         at::Tensor out, int64_t dim) {
  AT_DISPATCH_ALL_TYPES(grad.type(), "index_backward_kernel", [&] {
    auto grad_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad);
    auto index_view = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto arg_view = at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg);
    auto out_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
    KERNEL_RUN(index_backward_kernel, index.dim(), index.numel(), grad_view,
               index_view, arg_view, out_view, dim);
  });
}