"tests/vscode:/vscode.git/clone" did not exist on "d520d24fdf2eb7c2e76aa5ca3020cbfd07c42910"
basis_cuda.cu 6.86 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include "basis_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

#define THREADS 1024
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

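// Per-dimension B-spline basis functions for degrees 1-3. `v` is the local
// coordinate inside a knot interval (the fractional part of the scaled
// pseudo-coordinate) and `k_mod` in {0, ..., degree} selects the spline
// coefficient. `forward` returns the basis value, `backward` its derivative
// with respect to `v`; unsupported degrees return -1.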
template <typename scalar_t, int64_t degree> struct Basis {
  static inline __device__ scalar_t forward(scalar_t v, int64_t k_mod) {
    if (degree == 1) {
      return 1. - v - k_mod + 2. * v * k_mod;
    } else if (degree == 2) {
      if (k_mod == 0)
        return 0.5 * v * v - v + 0.5;
      else if (k_mod == 1)
        return -v * v + v + 0.5;
      else
        return 0.5 * v * v;
    } else if (degree == 3) {
      if (k_mod == 0)
        return (1. - v) * (1. - v) * (1. - v) / 6.;
      else if (k_mod == 1)
        return (3. * v * v * v - 6. * v * v + 4.) / 6.;
      else if (k_mod == 2)
        return (-3. * v * v * v + 3. * v * v + 3. * v + 1.) / 6.;
      else
        return v * v * v / 6.;
    } else {
      return (scalar_t)-1.;
    }
  }

  static inline __device__ scalar_t backward(scalar_t v, int64_t k_mod) {
    if (degree == 1) {
      return 2 * k_mod - 1;
    } else if (degree == 2) {
      if (k_mod == 0)
        return v - 1.;
      else if (k_mod == 1)
        return -2. * v + 1.;
      else
        return v;
    } else if (degree == 3) {
      if (k_mod == 0)
        return (-v * v + 2. * v - 1.) / 2.;
      else if (k_mod == 1)
        return (3. * v * v - 4. * v) / 2.;
      else if (k_mod == 2)
        return (-3. * v * v + 2. * v + 1.) / 2.;
      else
        return v * v / 2.;
    } else {
      return (scalar_t)-1.;
    }
  }
};

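// Forward kernel: one thread per (edge, spline coefficient) pair, with
// numel = E * S and S = (degree + 1)^D. Each thread decodes its coefficient
// index `s` into per-dimension offsets, multiplies the D one-dimensional
// basis values into `basis`, and writes the flattened index of the matching
// weight into `weight_index`.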
template <typename scalar_t, int64_t degree>
__global__ void
spline_basis_fw_kernel(const scalar_t *pseudo, const int64_t *kernel_size,
                       const uint8_t *is_open_spline, scalar_t *basis,
                       int64_t *weight_index, int64_t E, int64_t D, int64_t S,
                       int64_t numel) {

  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t e = thread_idx / S;
  const int64_t s = thread_idx % S;

  if (thread_idx < numel) {
    int64_t k = s, wi = 0, wi_offset = 1;
    scalar_t b = (scalar_t)1.;

    for (int64_t d = 0; d < D; d++) {
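      // Each dimension contributes one 1D basis factor to `b` and one digit
      // (in base kernel_size[d]) of the flattened weight index `wi`.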
      const int64_t k_mod = k % (degree + 1);
      k /= degree + 1;

      scalar_t v = pseudo[e * D + d];
      v *= kernel_size[d] - degree * is_open_spline[d];

      wi += (((int64_t)v + k_mod) % kernel_size[d]) * wi_offset;
      wi_offset *= kernel_size[d];

      v -= floor(v);
      v = Basis<scalar_t, degree>::forward(v, k_mod);
      b *= v;
    }

    basis[thread_idx] = b;
    weight_index[thread_idx] = wi;
  }
}

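// Forward host entry point: validates the inputs, allocates the [E, S]
// outputs with S = (degree + 1)^D, and launches the forward kernel on the
// current CUDA stream, dispatching over floating-point dtype and degree.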
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
                     torch::Tensor is_open_spline, int64_t degree) {
  CHECK_CUDA(pseudo);
  CHECK_CUDA(kernel_size);
  CHECK_CUDA(is_open_spline);
  cudaSetDevice(pseudo.get_device());

  CHECK_INPUT(kernel_size.dim() == 1);
  CHECK_INPUT(pseudo.size(1) == kernel_size.numel());
  CHECK_INPUT(is_open_spline.dim() == 1);
  CHECK_INPUT(pseudo.size(1) == is_open_spline.numel());

  auto E = pseudo.size(0);
  auto D = pseudo.size(1);
  auto S = (int64_t)(powf(degree + 1, D) + 0.5);

  auto basis = at::empty({E, S}, pseudo.options());
  auto weight_index = at::empty({E, S}, kernel_size.options());

  auto kernel_size_data = kernel_size.data_ptr<int64_t>();
  auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
  auto weight_index_data = weight_index.data_ptr<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_fw", [&] {
    auto pseudo_data = pseudo.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();

    AT_DISPATCH_DEGREE_TYPES(degree, [&] {
      spline_basis_fw_kernel<scalar_t, DEGREE>
          <<<BLOCKS(basis.numel()), THREADS, 0, stream>>>(
              pseudo_data, kernel_size_data, is_open_spline_data, basis_data,
              weight_index_data, E, D, S, basis.numel());
    });
  });

  return std::make_tuple(basis, weight_index);
}

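// Backward kernel: one thread per (edge, dimension) pair. For every spline
// coefficient `s`, the thread combines the basis derivative along its own
// dimension with the forward basis values of the remaining dimensions,
// weights the product by the incoming gradient, and accumulates the sum.
// The final factor applies the chain rule for scaling the pseudo-coordinate
// into knot space.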
template <typename scalar_t, int64_t degree>
__global__ void
spline_basis_bw_kernel(const scalar_t *grad_basis, const scalar_t *pseudo,
                       const int64_t *kernel_size,
                       const uint8_t *is_open_spline, scalar_t *grad_pseudo,
                       int64_t E, int64_t D, int64_t S, int64_t numel) {

  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t e = thread_idx / D;
  const int64_t d = thread_idx % D;

  if (thread_idx < numel) {
    scalar_t g = (scalar_t)0., tmp;

    for (ptrdiff_t s = 0; s < S; s++) {
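      // Derivative factor along dimension d for coefficient s, later
      // multiplied by the forward basis values of all other dimensions.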
      int64_t k_mod = (s / (int64_t)(powf(degree + 1, d) + 0.5)) % (degree + 1);

      scalar_t v = pseudo[e * D + d];
      v *= kernel_size[d] - degree * is_open_spline[d];
      v -= floor(v);
      v = Basis<scalar_t, degree>::backward(v, k_mod);
      tmp = v;

      for (int64_t d_it = 1; d_it < D; d_it++) {
        const int64_t d_new = d_it - (d >= d_it);
        k_mod = (s / (int64_t)(powf(degree + 1, d_new) + 0.5)) % (degree + 1);
        v = pseudo[e * D + d_new];
        v *= kernel_size[d_new] - degree * is_open_spline[d_new];
        v -= floor(v);
        v = Basis<scalar_t, degree>::forward(v, k_mod);
        tmp *= v;
      }
      g += tmp * grad_basis[e * S + s];
    }
    g *= kernel_size[d] - degree * is_open_spline[d];
    grad_pseudo[thread_idx] = g;
  }
}

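// Backward host entry point: validates the inputs, allocates the [E, D]
// gradient with respect to `pseudo`, and launches the backward kernel on the
// current CUDA stream, dispatching over floating-point dtype and degree.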
torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
                                   torch::Tensor pseudo,
                                   torch::Tensor kernel_size,
                                   torch::Tensor is_open_spline,
                                   int64_t degree) {
  CHECK_CUDA(grad_basis);
  CHECK_CUDA(pseudo);
  CHECK_CUDA(kernel_size);
  CHECK_CUDA(is_open_spline);
  cudaSetDevice(grad_basis.get_device());

  CHECK_INPUT(grad_basis.size(0) == pseudo.size(0));
  CHECK_INPUT(kernel_size.dim() == 1);
  CHECK_INPUT(pseudo.size(1) == kernel_size.numel());
  CHECK_INPUT(is_open_spline.dim() == 1);
  CHECK_INPUT(pseudo.size(1) == is_open_spline.numel());

  auto E = pseudo.size(0);
  auto D = pseudo.size(1);
  auto S = grad_basis.size(1);

  auto grad_pseudo = at::empty({E, D}, pseudo.options());

  auto kernel_size_data = kernel_size.data_ptr<int64_t>();
  auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_bw", [&] {
    auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
    auto pseudo_data = pseudo.data_ptr<scalar_t>();
    auto grad_pseudo_data = grad_pseudo.data_ptr<scalar_t>();

    AT_DISPATCH_DEGREE_TYPES(degree, [&] {
      spline_basis_bw_kernel<scalar_t, DEGREE>
          <<<BLOCKS(grad_pseudo.numel()), THREADS, 0, stream>>>(
              grad_basis_data, pseudo_data, kernel_size_data,
              is_open_spline_data, grad_pseudo_data, E, D, S,
              grad_pseudo.numel());
    });
  });

  return grad_pseudo;
}
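
// Illustrative usage sketch (an assumption, not part of this file's build):
// assumes libtorch, all tensors on the same CUDA device, and
// pseudo-coordinates already normalized to [0, 1). Dtypes follow the
// accessors above: floating-point `pseudo`, int64 `kernel_size`,
// uint8 `is_open_spline`.
//
//   auto f_opts = torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
//   auto i_opts = torch::TensorOptions().dtype(torch::kLong).device(torch::kCUDA);
//   auto pseudo = torch::rand({8, 2}, f_opts);                          // [E, D]
//   auto kernel_size = torch::tensor({5, 5}, i_opts);                   // [D]
//   auto is_open_spline = torch::ones({2}, i_opts.dtype(torch::kByte)); // [D]
//   auto out = spline_basis_fw_cuda(pseudo, kernel_size, is_open_spline, /*degree=*/1);
//   auto basis = std::get<0>(out);         // [E, (degree + 1)^D]
//   auto weight_index = std::get<1>(out);  // [E, (degree + 1)^D]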