// scatter_kernel.cu: CUDA implementations of scatter_mul, scatter_div,
// scatter_max and scatter_min (see the *_cuda host entry points below).
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include "atomics.cuh"
#include "index.cuh"

// Threads per block used by every kernel launch in this file.
#define THREADS 1024
// Number of THREADS-sized blocks needed to cover N items (ceiling division).
// Both the argument and the whole expansion are parenthesised so the macro
// composes correctly with surrounding operators (e.g. `x / BLOCKS(n)`) and
// with arguments that contain lower-precedence operators (e.g. `?:`).
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Kernels are launched on the caller's current CUDA stream, which KERNEL_RUN
// queries at each launch. No stream is cached at namespace scope: doing so
// would call into the CUDA runtime during static initialization (which can
// throw while the library is being loaded, e.g. when no usable GPU is present)
// and would pin whichever device happened to be current at load time.

// KERNEL_RUN(NAME, DIMS, N, args...)
//
// Launches the kernel template NAME<scalar_t, D>(args..., N) on the current
// CUDA stream with BLOCKS(N) blocks of THREADS threads (at least one thread per
// item). D is specialised to 1, 2 or 3 when DIMS equals that value and is -1
// otherwise; the kernels in this file forward it to their
// IndexToScatterOffsets* helper.
//
// Must be expanded inside an AT_DISPATCH_* lambda, which defines `scalar_t`.
// N is evaluated exactly once. When N is zero nothing is launched, because a
// grid with zero blocks is an invalid launch configuration. Launch errors are
// raised via AT_CUDA_CHECK instead of being left behind as a sticky CUDA error.
#define KERNEL_RUN(NAME, DIMS, N, ...)                                         \
  [&] {                                                                        \
    const int64_t kernel_run_numel_ = (N);                                     \
    if (kernel_run_numel_ == 0) {                                              \
      return;                                                                  \
    }                                                                          \
    auto stream = at::cuda::getCurrentCUDAStream();                            \
    switch (DIMS) {                                                            \
    case 1:                                                                    \
      NAME<scalar_t, 1><<<BLOCKS(kernel_run_numel_), THREADS, 0, stream>>>(    \
          __VA_ARGS__, kernel_run_numel_);                                     \
      break;                                                                   \
    case 2:                                                                    \
      NAME<scalar_t, 2><<<BLOCKS(kernel_run_numel_), THREADS, 0, stream>>>(    \
          __VA_ARGS__, kernel_run_numel_);                                     \
      break;                                                                   \
    case 3:                                                                    \
      NAME<scalar_t, 3><<<BLOCKS(kernel_run_numel_), THREADS, 0, stream>>>(    \
          __VA_ARGS__, kernel_run_numel_);                                     \
      break;                                                                   \
    default:                                                                   \
      NAME<scalar_t, -1><<<BLOCKS(kernel_run_numel_), THREADS, 0, stream>>>(   \
          __VA_ARGS__, kernel_run_numel_);                                     \
    }                                                                          \
    AT_CUDA_CHECK(cudaGetLastError());                                         \
  }()

// Combines every src element into the out slot selected by index along `dim`
// via atomMul (atomics.cuh; presumably an atomic `*=`), so src elements that
// map to the same out slot do not race.
//
// Launch: 1-D grid of 1-D blocks. Each thread starts at its global id and
// advances by the total number of launched threads, so any grid size covers
// all `numel` items (callers pass index.numel()).
// Dims: compile-time rank chosen by KERNEL_RUN (1, 2, 3, or -1 for the generic
// path), forwarded to IndexToScatterOffsets3 (index.cuh).
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x (and blockDim.x *
  // gridDim.x) would otherwise be computed in 32-bit unsigned arithmetic and
  // wrap once the launch covers 2^32 or more threads.
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomMul(&out.data[outOffset], src.data[srcOffset]);
  }
}

// Host entry point for scatter-multiply. Dispatches on src's scalar type via
// AT_DISPATCH_ALL_TYPES and launches scatter_mul_kernel over every element of
// `index` on the current CUDA stream; `out` is updated in place.
void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      int64_t dim) {
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul_kernel", [&] {
    // Raw (data, sizes, strides) views consumed by the device code.
    auto src_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    auto index_view = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto out_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);

    KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(), src_view,
               index_view, out_view, dim);
  });
}

// Combines every src element into the out slot selected by index along `dim`
// via atomDiv (atomics.cuh; presumably an atomic `/=`), so src elements that
// map to the same out slot do not race.
//
// Launch: 1-D grid of 1-D blocks. Each thread starts at its global id and
// advances by the total number of launched threads, so any grid size covers
// all `numel` items (callers pass index.numel()).
// Dims: compile-time rank chosen by KERNEL_RUN (1, 2, 3, or -1 for the generic
// path), forwarded to IndexToScatterOffsets3 (index.cuh).
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x (and blockDim.x *
  // gridDim.x) would otherwise be computed in 32-bit unsigned arithmetic and
  // wrap once the launch covers 2^32 or more threads.
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomDiv(&out.data[outOffset], src.data[srcOffset]);
  }
}

// Host entry point for scatter-divide. Dispatches on src's scalar type via
// AT_DISPATCH_ALL_TYPES and launches scatter_div_kernel over every element of
// `index` on the current CUDA stream; `out` is updated in place.
void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      int64_t dim) {
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div_kernel", [&] {
    // Raw (data, sizes, strides) views consumed by the device code.
    auto src_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    auto index_view = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto out_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);

    KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(), src_view,
               index_view, out_view, dim);
  });
}

// Second pass of scatter_max / scatter_min. For every element of `index`,
// if the src value equals the already-reduced value in the corresponding out
// slot, record that element's position along `dim` of src into `arg` (at the
// slot aligned with out).
//
// The position is recovered from the src storage offset as
// (srcOffset / src.strides[dim]) % src.sizes[dim].
// NOTE(review): that formula assumes src has no gaps between dimensions
// (contiguous or a pure permutation); for a strided sub-view (e.g. sizes
// {2, 2}, strides {5, 1}) it can return a wrong position — confirm callers
// pass a contiguous src.
// NOTE(review): when several src elements tie for the same out slot, they all
// store into `arg` without atomics, so which position survives is unspecified.
//
// Launch: 1-D grid of 1-D blocks. Each thread starts at its global id and
// advances by the total number of launched threads, so any grid size covers
// all `numel` items (callers pass index.numel()).
template <typename scalar_t, int64_t Dims>
__global__ void arg_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                           at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                           at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                           at::cuda::detail::TensorInfo<int64_t, int64_t> arg,
                           int64_t dim, size_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x (and blockDim.x *
  // gridDim.x) would otherwise be computed in 32-bit unsigned arithmetic and
  // wrap once the launch covers 2^32 or more threads.
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0, argOffset = 0;
    IndexToScatterOffsets4<scalar_t, scalar_t, int64_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset, arg,
        &argOffset);
    if (src.data[srcOffset] == out.data[outOffset]) {
      arg.data[argOffset] = (srcOffset / src.strides[dim]) % src.sizes[dim];
    }
  }
}

// First pass of scatter_max: combines every src element into the out slot
// selected by index along `dim` via atomMax (atomics.cuh; presumably an atomic
// maximum), so src elements that map to the same out slot do not race.
//
// Launch: 1-D grid of 1-D blocks. Each thread starts at its global id and
// advances by the total number of launched threads, so any grid size covers
// all `numel` items (callers pass index.numel()).
// Dims: compile-time rank chosen by KERNEL_RUN (1, 2, 3, or -1 for the generic
// path), forwarded to IndexToScatterOffsets3 (index.cuh).
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x (and blockDim.x *
  // gridDim.x) would otherwise be computed in 32-bit unsigned arithmetic and
  // wrap once the launch covers 2^32 or more threads.
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomMax(&out.data[outOffset], src.data[srcOffset]);
  }
}

// Host entry point for scatter-max. Dispatches on src's scalar type via
// AT_DISPATCH_ALL_TYPES and issues two launches on the current CUDA stream:
//   1. scatter_max_kernel reduces src into `out` (in place);
//   2. arg_kernel then writes into `arg` the position along `dim` of a src
//      element equal to each reduced value.
// Because both launches share one stream, the second sees the first's result.
void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      at::Tensor arg, int64_t dim) {
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max_kernel", [&] {
    // Raw (data, sizes, strides) views consumed by the device code.
    auto src_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    auto index_view = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto out_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
    auto arg_view = at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg);

    const auto rank = index.dim();
    const auto count = index.numel();

    KERNEL_RUN(scatter_max_kernel, rank, count, src_view, index_view,
               out_view, dim);
    KERNEL_RUN(arg_kernel, rank, count, src_view, index_view, out_view,
               arg_view, dim);
  });
}

// First pass of scatter_min: combines every src element into the out slot
// selected by index along `dim` via atomMin (atomics.cuh; presumably an atomic
// minimum), so src elements that map to the same out slot do not race.
//
// Launch: 1-D grid of 1-D blocks. Each thread starts at its global id and
// advances by the total number of launched threads, so any grid size covers
// all `numel` items (callers pass index.numel()).
// Dims: compile-time rank chosen by KERNEL_RUN (1, 2, 3, or -1 for the generic
// path), forwarded to IndexToScatterOffsets3 (index.cuh).
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x (and blockDim.x *
  // gridDim.x) would otherwise be computed in 32-bit unsigned arithmetic and
  // wrap once the launch covers 2^32 or more threads.
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomMin(&out.data[outOffset], src.data[srcOffset]);
  }
}

// Host entry point for scatter-min. Dispatches on src's scalar type via
// AT_DISPATCH_ALL_TYPES and issues two launches on the current CUDA stream:
//   1. scatter_min_kernel reduces src into `out` (in place);
//   2. arg_kernel then writes into `arg` the position along `dim` of a src
//      element equal to each reduced value.
// Because both launches share one stream, the second sees the first's result.
void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      at::Tensor arg, int64_t dim) {
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min_kernel", [&] {
    // Raw (data, sizes, strides) views consumed by the device code.
    auto src_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    auto index_view = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto out_view = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
    auto arg_view = at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg);

    const auto rank = index.dim();
    const auto count = index.numel();

    KERNEL_RUN(scatter_min_kernel, rank, count, src_view, index_view,
               out_view, dim);
    KERNEL_RUN(arg_kernel, rank, count, src_view, index_view, out_view,
               arg_view, dim);
  });
}