// basis_kernel.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include "compat.cuh"

// Threads per block for every kernel launch in this file.
#define THREADS 1024
// Blocks needed to cover N work items at THREADS per block (ceiling
// division). Fully parenthesized so the macro is safe inside larger
// expressions and with compound arguments.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// 1-D uniform B-spline basis pieces of degree 1 (linear), 2 (quadratic) and
// 3 (cubic). `v` is the fractional position inside a knot interval and
// `k_mod` selects which of the (degree + 1) overlapping pieces to evaluate
// (0 .. degree); for a fixed v the pieces of one degree sum to 1.
//
// Marked __host__ __device__ so the same formulas can also be evaluated on
// the CPU (e.g. as a reference for testing). Constants are typed as scalar_t
// so float instantiations do not silently promote to double arithmetic.
template <typename scalar_t> struct BasisForward {
  // Linear piece: 1 - v for k_mod == 0, v for k_mod == 1.
  static inline __host__ __device__ scalar_t linear(scalar_t v,
                                                    int64_t k_mod) {
    return 1 - v - k_mod + 2 * v * k_mod;
  }

  // Quadratic pieces for k_mod == 0, 1 and (any other value) 2.
  static inline __host__ __device__ scalar_t quadratic(scalar_t v,
                                                       int64_t k_mod) {
    const scalar_t half = scalar_t(0.5);
    if (k_mod == 0)
      return half * v * v - v + half;
    else if (k_mod == 1)
      return -v * v + v + half;
    else
      return half * v * v;
  }

  // Cubic pieces for k_mod == 0, 1, 2 and (any other value) 3.
  static inline __host__ __device__ scalar_t cubic(scalar_t v, int64_t k_mod) {
    const scalar_t six = scalar_t(6);
    if (k_mod == 0)
      return (1 - v) * (1 - v) * (1 - v) / six;
    else if (k_mod == 1)
      return (3 * v * v * v - 6 * v * v + 4) / six;
    else if (k_mod == 2)
      return (-3 * v * v * v + 3 * v * v + 3 * v + 1) / six;
    else
      return v * v * v / six;
  }
};

// Host-side launcher shared by {linear,quadratic,cubic}_fw_cuda.
//
// With E = PSEUDO.size(0) and D = KERNEL_SIZE.size(0), it allocates
//   * `basis` (options of PSEUDO), and
//   * `weight_index` (options of KERNEL_SIZE),
// both of shape [E, S] with S = (M + 1)^D. It fills them by launching
// KERNEL_NAME<scalar_t> on the current CUDA stream, with enough THREADS-sized
// blocks to cover every element. It returns std::make_tuple(basis, weight_index).
// NOTE(review): KERNEL_SIZE is read as int64 and IS_OPEN_SPLINE as uint8 on the
// device; callers are presumed to pass CUDA tensors of exactly those dtypes.
#define BASIS_FORWARD(M, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, KERNEL_NAME)     \
  [&]() -> std::tuple<at::Tensor, at::Tensor> {                                \
    AT_CUDA_CHECK(cudaSetDevice(PSEUDO.get_device()));                         \
    auto E = PSEUDO.size(0);                                                   \
    /* S = (M + 1)^D in exact integer arithmetic (no powf rounding). */        \
    int64_t S = 1;                                                             \
    for (int64_t d_ = 0; d_ < KERNEL_SIZE.size(0); d_++)                       \
      S *= (M) + 1;                                                            \
    auto basis = at::empty({E, S}, PSEUDO.options());                          \
    auto weight_index = at::empty({E, S}, KERNEL_SIZE.options());              \
    /* Launching zero blocks is an invalid configuration; nothing to do. */    \
    if (basis.numel() == 0)                                                    \
      return std::make_tuple(basis, weight_index);                             \
                                                                               \
    AT_DISPATCH_FLOATING_TYPES(                                                \
        PSEUDO.scalar_type(), "basis_forward_" #M, [&] {                       \
          KERNEL_NAME<scalar_t><<<BLOCKS(basis.numel()), THREADS, 0,           \
                                  at::cuda::getCurrentCUDAStream()>>>(         \
              at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),       \
              at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index), \
              at::cuda::detail::getTensorInfo<scalar_t, int64_t>(PSEUDO),      \
              KERNEL_SIZE.DATA_PTR<int64_t>(),                                 \
              IS_OPEN_SPLINE.DATA_PTR<uint8_t>(), basis.numel());              \
        });                                                                    \
    /* Surface launch-configuration errors here, not at a later API call. */   \
    AT_CUDA_CHECK(cudaGetLastError());                                         \
    return std::make_tuple(basis, weight_index);                               \
  }()

// Device-side body shared by {linear,quadratic,cubic}_fw_kernel. It must be
// expanded inside a kernel templated on `scalar_t`.
//
// Walks the first NUMEL entries of BASIS with a grid-stride loop. Entry i
// maps to row e = i / BASIS.sizes[1] and column s = i % BASIS.sizes[1].
//
// The column s is decoded as PSEUDO.sizes[1] digits in base (M + 1), least
// significant first. Digit d (`k_mod`) picks one of the M + 1 basis pieces
// along dimension d. For every dimension the coordinate is scaled as
//   v = pseudo[e, d] * (KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d])
// and then:
//   * WEIGHT_INDEX[i] accumulates ((int64_t)v + k_mod) % KERNEL_SIZE[d] as a
//     mixed-radix number with radices KERNEL_SIZE[0], KERNEL_SIZE[1], ...;
//   * BASIS[i] is the product over d of CODE, evaluated after v has been
//     reduced to its fractional part v - floor(v).
// CODE is an expression over the in-scope locals `v` and `k_mod`.
//
// NOTE(review): BASIS and WEIGHT_INDEX are written as `.data[i]`, ignoring
// their strides, so both are assumed to be contiguous.
#define BASIS_FORWARD_KERNEL(M, BASIS, WEIGHT_INDEX, PSEUDO, KERNEL_SIZE,      \
                             IS_OPEN_SPLINE, NUMEL, CODE)                      \
  [&] {                                                                        \
    /* Widen to size_t before multiplying so very large grids cannot wrap */   \
    /* 32-bit unsigned arithmetic. */                                          \
    const size_t index = (size_t)blockIdx.x * blockDim.x + threadIdx.x;        \
    const size_t stride = (size_t)blockDim.x * gridDim.x;                      \
    for (ptrdiff_t i = index; i < NUMEL; i += stride) {                        \
      int64_t e = i / BASIS.sizes[1], s = i % BASIS.sizes[1];                  \
      int64_t k = s, wi = 0, wi_offset = 1;                                    \
      scalar_t b = 1;                                                          \
                                                                               \
      for (ptrdiff_t d = 0; d < PSEUDO.sizes[1]; d++) {                        \
        /* Next base-(M + 1) digit of s selects the basis piece. */            \
        auto k_mod = k % (M + 1);                                              \
        k /= M + 1;                                                            \
                                                                               \
        auto v = PSEUDO.data[e * PSEUDO.strides[0] + d * PSEUDO.strides[1]];   \
        v *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d];                           \
                                                                               \
        wi += (((int64_t)v + k_mod) % KERNEL_SIZE[d]) * wi_offset;             \
        wi_offset *= KERNEL_SIZE[d];                                           \
                                                                               \
        v -= floor(v);                                                         \
        v = CODE;                                                              \
        b *= v;                                                                \
      }                                                                        \
                                                                               \
      BASIS.data[i] = b;                                                       \
      WEIGHT_INDEX.data[i] = wi;                                               \
    }                                                                          \
  }()

// Forward kernel for degree-1 splines: the shared BASIS_FORWARD_KERNEL body,
// with BasisForward<scalar_t>::linear as the per-dimension basis piece.
// Launched from linear_fw_cuda through BASIS_FORWARD.
template <typename scalar_t>
__global__ void linear_fw_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis_out,
    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo_in,
    int64_t *kernel_sizes,
    uint8_t *open_spline_flags,
    size_t count) {
  BASIS_FORWARD_KERNEL(1, basis_out, weight_index_out, pseudo_in, kernel_sizes,
                       open_spline_flags, count,
                       BasisForward<scalar_t>::linear(v, k_mod));
}

// Host entry point for the degree-1 forward pass. Returns the
// (basis, weight_index) pair built by BASIS_FORWARD with linear_fw_kernel.
auto linear_fw_cuda(at::Tensor pseudo, at::Tensor kernel_size,
                    at::Tensor is_open_spline)
    -> std::tuple<at::Tensor, at::Tensor> {
  auto result = BASIS_FORWARD(1, pseudo, kernel_size, is_open_spline,
                              linear_fw_kernel);
  return result;
}

// Forward kernel for degree-2 splines: the shared BASIS_FORWARD_KERNEL body,
// with BasisForward<scalar_t>::quadratic as the per-dimension basis piece.
// Launched from quadratic_fw_cuda through BASIS_FORWARD.
template <typename scalar_t>
__global__ void quadratic_fw_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis_out,
    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo_in,
    int64_t *kernel_sizes,
    uint8_t *open_spline_flags,
    size_t count) {
  BASIS_FORWARD_KERNEL(2, basis_out, weight_index_out, pseudo_in, kernel_sizes,
                       open_spline_flags, count,
                       BasisForward<scalar_t>::quadratic(v, k_mod));
}

// Host entry point for the degree-2 forward pass. Returns the
// (basis, weight_index) pair built by BASIS_FORWARD with quadratic_fw_kernel.
auto quadratic_fw_cuda(at::Tensor pseudo, at::Tensor kernel_size,
                       at::Tensor is_open_spline)
    -> std::tuple<at::Tensor, at::Tensor> {
  auto result = BASIS_FORWARD(2, pseudo, kernel_size, is_open_spline,
                              quadratic_fw_kernel);
  return result;
}

// Forward kernel for degree-3 splines: the shared BASIS_FORWARD_KERNEL body,
// with BasisForward<scalar_t>::cubic as the per-dimension basis piece.
// Launched from cubic_fw_cuda through BASIS_FORWARD.
template <typename scalar_t>
__global__ void cubic_fw_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis_out,
    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo_in,
    int64_t *kernel_sizes,
    uint8_t *open_spline_flags,
    size_t count) {
  BASIS_FORWARD_KERNEL(3, basis_out, weight_index_out, pseudo_in, kernel_sizes,
                       open_spline_flags, count,
                       BasisForward<scalar_t>::cubic(v, k_mod));
}

// Host entry point for the degree-3 forward pass. Returns the
// (basis, weight_index) pair built by BASIS_FORWARD with cubic_fw_kernel.
auto cubic_fw_cuda(at::Tensor pseudo, at::Tensor kernel_size,
                   at::Tensor is_open_spline)
    -> std::tuple<at::Tensor, at::Tensor> {
  auto result = BASIS_FORWARD(3, pseudo, kernel_size, is_open_spline,
                              cubic_fw_kernel);
  return result;
}

// Derivatives d/dv of the 1-D uniform B-spline basis pieces of degree 1, 2
// and 3. `v` is the fractional position inside a knot interval and `k_mod`
// selects the piece. For a fixed v, the derivatives of one degree sum to 0.
//
// Marked __host__ __device__ so the derivatives can also be evaluated on the
// CPU (e.g. as a reference for testing), like the forward basis pieces.
template <typename scalar_t> struct BasisBackward {
  // Derivative of the linear pieces: -1 for k_mod == 0, +1 for k_mod == 1.
  // `v` is unused because the pieces are affine in v.
  static inline __host__ __device__ scalar_t linear(scalar_t v,
                                                    int64_t k_mod) {
    return 2 * k_mod - 1;
  }

  // Derivative of the quadratic pieces for k_mod == 0, 1 and otherwise 2.
  static inline __host__ __device__ scalar_t quadratic(scalar_t v,
                                                       int64_t k_mod) {
    if (k_mod == 0)
      return v - 1;
    else if (k_mod == 1)
      return -2 * v + 1;
    else
      return v;
  }

  // Derivative of the cubic pieces for k_mod == 0, 1, 2 and otherwise 3.
  static inline __host__ __device__ scalar_t cubic(scalar_t v, int64_t k_mod) {
    if (k_mod == 0)
      return (-v * v + 2 * v - 1) / 2;
    else if (k_mod == 1)
      return (3 * v * v - 4 * v) / 2;
    else if (k_mod == 2)
      return (-3 * v * v + 2 * v + 1) / 2;
    else
      return v * v / 2;
  }
};

// Host-side launcher shared by {linear,quadratic,cubic}_bw_cuda.
//
// Allocates `grad_pseudo` with shape [PSEUDO.size(0), PSEUDO.size(1)] and the
// options of PSEUDO. It fills it by launching KERNEL_NAME<scalar_t> on the
// current CUDA stream, with enough THREADS-sized blocks to cover every
// element, and returns grad_pseudo.
// NOTE(review): dispatch is keyed on GRAD_BASIS's dtype while PSEUDO is read
// with the same scalar_t, so both are presumed to share a floating dtype.
#define BASIS_BACKWARD(M, GRAD_BASIS, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE,     \
                       KERNEL_NAME)                                            \
  [&]() -> at::Tensor {                                                        \
    AT_CUDA_CHECK(cudaSetDevice(GRAD_BASIS.get_device()));                     \
    auto E = PSEUDO.size(0);                                                   \
    auto D = PSEUDO.size(1);                                                   \
    auto grad_pseudo = at::empty({E, D}, PSEUDO.options());                    \
    /* Launching zero blocks is an invalid configuration; nothing to do. */    \
    if (grad_pseudo.numel() == 0)                                              \
      return grad_pseudo;                                                      \
                                                                               \
    AT_DISPATCH_FLOATING_TYPES(                                                \
        GRAD_BASIS.scalar_type(), "basis_backward_" #M, [&] {                  \
          KERNEL_NAME<scalar_t><<<BLOCKS(grad_pseudo.numel()), THREADS, 0,     \
                                  at::cuda::getCurrentCUDAStream()>>>(         \
              at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_pseudo), \
              at::cuda::detail::getTensorInfo<scalar_t, int64_t>(GRAD_BASIS),  \
              at::cuda::detail::getTensorInfo<scalar_t, int64_t>(PSEUDO),      \
              KERNEL_SIZE.DATA_PTR<int64_t>(),                                 \
              IS_OPEN_SPLINE.DATA_PTR<uint8_t>(), grad_pseudo.numel());        \
        });                                                                    \
    /* Surface launch-configuration errors here, not at a later API call. */   \
    AT_CUDA_CHECK(cudaGetLastError());                                         \
    return grad_pseudo;                                                        \
  }()

// Device-side body shared by {linear,quadratic,cubic}_bw_kernel. It must be
// expanded inside a kernel templated on `scalar_t`.
//
// Walks the first NUMEL entries of GRAD_PSEUDO with a grid-stride loop.
// Entry i maps to row e = i / GRAD_PSEUDO.sizes[1] and dimension
// d = i % GRAD_PSEUDO.sizes[1]. For every column s of GRAD_BASIS it:
//   * evaluates GRAD_CODE (derivative piece) along dimension d;
//   * multiplies it by CODE (plain basis piece) along every other dimension;
//   * weights the product by GRAD_BASIS[e, s].
// The sum is finally scaled by (KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d]), the
// factor applied to pseudo[e, d] in the forward pass (chain rule).
// The piece index k_mod for dimension d' is the d'-th base-(M + 1) digit of s.
// CODE and GRAD_CODE are expressions over the in-scope locals `v` and `k_mod`.
//
// NOTE(review): GRAD_PSEUDO is written as `.data[i]`, ignoring its strides, so
// it is assumed to be contiguous.
#define BASIS_BACKWARD_KERNEL(M, GRAD_PSEUDO, GRAD_BASIS, PSEUDO, KERNEL_SIZE, \
                              IS_OPEN_SPLINE, NUMEL, CODE, GRAD_CODE)          \
  [&] {                                                                        \
    /* Widen to size_t before multiplying so very large grids cannot wrap */   \
    /* 32-bit unsigned arithmetic. */                                          \
    const size_t index = (size_t)blockIdx.x * blockDim.x + threadIdx.x;        \
    const size_t stride = (size_t)blockDim.x * gridDim.x;                      \
    for (ptrdiff_t i = index; i < NUMEL; i += stride) {                        \
      int64_t e = i / GRAD_PSEUDO.sizes[1], d = i % GRAD_PSEUDO.sizes[1];      \
      scalar_t g = 0, tmp;                                                     \
                                                                               \
      for (ptrdiff_t s = 0; s < GRAD_BASIS.sizes[1]; s++) {                    \
        /* Derivative of the basis piece along the differentiated dim d. */    \
        auto k_mod = (s / (int64_t)(powf(M + 1, d) + 0.5)) % (M + 1);          \
        auto v = PSEUDO.data[e * PSEUDO.strides[0] + d * PSEUDO.strides[1]];   \
        v *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d];                           \
        v -= floor(v);                                                         \
        v = GRAD_CODE;                                                         \
        tmp = v;                                                               \
                                                                               \
        /* Times the plain basis piece of every dimension except d: */         \
        /* d_new runs over 0 .. D-1 with d skipped. */                         \
        for (ptrdiff_t d_it = 1; d_it < GRAD_PSEUDO.sizes[1]; d_it++) {        \
          auto d_new = d_it - (d >= d_it);                                     \
          k_mod = (s / (int64_t)(powf(M + 1, d_new) + 0.5)) % (M + 1);         \
          v = PSEUDO.data[e * PSEUDO.strides[0] + d_new * PSEUDO.strides[1]];  \
          v *= KERNEL_SIZE[d_new] - M * IS_OPEN_SPLINE[d_new];                 \
          v -= floor(v);                                                       \
          v = CODE;                                                            \
          tmp *= v;                                                            \
        }                                                                      \
        g += tmp *                                                             \
             GRAD_BASIS                                                        \
                 .data[e * GRAD_BASIS.strides[0] + s * GRAD_BASIS.strides[1]]; \
      }                                                                        \
      /* Chain rule for the forward scaling of pseudo[e, d]. */                \
      g *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d];                             \
      GRAD_PSEUDO.data[i] = g;                                                 \
    }                                                                          \
  }()

// Backward kernel for degree-1 splines: the shared BASIS_BACKWARD_KERNEL body.
// It uses BasisForward<scalar_t>::linear for the non-differentiated dimensions
// and BasisBackward<scalar_t>::linear for the differentiated one.
// Launched from linear_bw_cuda through BASIS_BACKWARD.
template <typename scalar_t>
__global__ void linear_bw_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_pseudo_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis_in,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
    int64_t *kernel_sizes,
    uint8_t *open_spline_flags,
    size_t count) {
  BASIS_BACKWARD_KERNEL(1, grad_pseudo_out, grad_basis_in, pseudo,
                        kernel_sizes, open_spline_flags, count,
                        BasisForward<scalar_t>::linear(v, k_mod),
                        BasisBackward<scalar_t>::linear(v, k_mod));
}

// Host entry point for the degree-1 backward pass. `grad_basis` is the
// incoming gradient for the [E, S] basis tensor. Returns the gradient w.r.t.
// `pseudo`, computed by BASIS_BACKWARD with linear_bw_kernel.
at::Tensor linear_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
                          at::Tensor kernel_size, at::Tensor is_open_spline) {
  return BASIS_BACKWARD(1, grad_basis, pseudo, kernel_size, is_open_spline,
                        linear_bw_kernel);
}

// Backward kernel for degree-2 splines: the shared BASIS_BACKWARD_KERNEL body.
// It uses BasisForward<scalar_t>::quadratic for the non-differentiated
// dimensions and BasisBackward<scalar_t>::quadratic for the differentiated one.
// Launched from quadratic_bw_cuda through BASIS_BACKWARD.
template <typename scalar_t>
__global__ void
quadratic_bw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_pseudo,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
                    int64_t *kernel_size, uint8_t *is_open_spline,
                    size_t numel) {
  BASIS_BACKWARD_KERNEL(2, grad_pseudo, grad_basis, pseudo, kernel_size,
                        is_open_spline, numel,
                        BasisForward<scalar_t>::quadratic(v, k_mod),
                        BasisBackward<scalar_t>::quadratic(v, k_mod));
}

// Host entry point for the degree-2 backward pass. `grad_basis` is the
// incoming gradient for the [E, S] basis tensor. Returns the gradient w.r.t.
// `pseudo`, computed by BASIS_BACKWARD with quadratic_bw_kernel.
at::Tensor quadratic_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
                             at::Tensor kernel_size,
                             at::Tensor is_open_spline) {
  return BASIS_BACKWARD(2, grad_basis, pseudo, kernel_size, is_open_spline,
                        quadratic_bw_kernel);
}

// Backward kernel for degree-3 splines: the shared BASIS_BACKWARD_KERNEL body.
// It uses BasisForward<scalar_t>::cubic for the non-differentiated dimensions
// and BasisBackward<scalar_t>::cubic for the differentiated one.
// Launched from cubic_bw_cuda through BASIS_BACKWARD.
template <typename scalar_t>
__global__ void
cubic_bw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_pseudo,
                at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
                at::cuda::detail::TensorInfo<scalar_t, int64_t> pseudo,
                int64_t *kernel_size, uint8_t *is_open_spline, size_t numel) {
  BASIS_BACKWARD_KERNEL(3, grad_pseudo, grad_basis, pseudo, kernel_size,
                        is_open_spline, numel,
                        BasisForward<scalar_t>::cubic(v, k_mod),
                        BasisBackward<scalar_t>::cubic(v, k_mod));
}

// Host entry point for the degree-3 backward pass. `grad_basis` is the
// incoming gradient for the [E, S] basis tensor. Returns the gradient w.r.t.
// `pseudo`, computed by BASIS_BACKWARD with cubic_bw_kernel.
at::Tensor cubic_bw_cuda(at::Tensor grad_basis, at::Tensor pseudo,
                         at::Tensor kernel_size, at::Tensor is_open_spline) {
  return BASIS_BACKWARD(3, grad_basis, pseudo, kernel_size, is_open_spline,
                        cubic_bw_kernel);
}