// basis.cpp - CPU forward/backward basis kernels (linear, quadratic, cubic) exposed through pybind11.
#include <cmath>
#include <cstdint>
#include <tuple>

#include <torch/extension.h>

#include "compat.h"

// Degree-1 basis polynomial (used by linear_fw / linear_bw).
/// Evaluates the degree-1 basis polynomial selected by `k_mod` at `v`.
///
/// @param v      Evaluation point; the callers in this file pass the
///               fractional part of the scaled pseudo-coordinate.
/// @param k_mod  Polynomial selector: 0 yields `1 - v`, 1 yields `v`.
/// @return       `1 - v - k_mod + 2 * v * k_mod`.
template <typename scalar_t> inline scalar_t linear(scalar_t v, int64_t k_mod) {
  return 1 - v - k_mod + 2 * v * k_mod;
}

// Degree-2 basis polynomial (used by quadratic_fw / quadratic_bw).
/// Evaluates the degree-2 basis polynomial selected by `k_mod` at `v`.
///
/// @param v      Evaluation point; the callers in this file pass the
///               fractional part of the scaled pseudo-coordinate.
/// @param k_mod  Polynomial selector: 0 and 1 pick the first two polynomials,
///               every other value picks the last one (`0.5 * v * v`).
/// @return       The selected polynomial evaluated at `v`. The three
///               polynomials sum to 1 for any `v`.
template <typename scalar_t>
inline scalar_t quadratic(scalar_t v, int64_t k_mod) {
  switch (k_mod) {
  case 0:
    return 0.5 * v * v - v + 0.5;
  case 1:
    return -v * v + v + 0.5;
  default:
    return 0.5 * v * v;
  }
}

// Degree-3 basis polynomial (used by cubic_fw / cubic_bw).
/// Evaluates the degree-3 basis polynomial selected by `k_mod` at `v`.
///
/// @param v      Evaluation point; the callers in this file pass the
///               fractional part of the scaled pseudo-coordinate.
/// @param k_mod  Polynomial selector: 0, 1 and 2 pick the first three
///               polynomials, every other value picks the last one
///               (`v * v * v / 6`).
/// @return       The selected polynomial evaluated at `v`. The four
///               polynomials sum to 1 for any `v`.
template <typename scalar_t> inline scalar_t cubic(scalar_t v, int64_t k_mod) {
  switch (k_mod) {
  case 0:
    return (1 - v) * (1 - v) * (1 - v) / 6.0;
  case 1:
    return (3 * v * v * v - 6 * v * v + 4) / 6;
  case 2:
    return (-3 * v * v * v + 3 * v * v + 3 * v + 1) / 6;
  default:
    return v * v * v / 6;
  }
}

// BASIS_FORWARD(M, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, FUNC)
//
// Expands to an immediately-invoked lambda returning
// std::tuple<at::Tensor, at::Tensor> = (basis, weight_index). Both tensors
// have shape {E, S} with E = PSEUDO.size(0) and
// S = (M + 1) ^ KERNEL_SIZE.size(0).
//
//   M              - spline degree; every dimension contributes M + 1 terms.
//                    Also stringified into the dispatch name.
//   PSEUDO         - floating-point tensor, read as [e, d] through its strides.
//   KERNEL_SIZE    - tensor read as int64_t, one entry per dimension d.
//   IS_OPEN_SPLINE - tensor read as uint8_t, one entry per dimension d.
//   FUNC           - template callable as FUNC<scalar_t>(scalar_t v,
//                    int64_t k_mod).
//
// For each row e and combination index s, s is decoded as a base-(M + 1)
// number whose d-th digit (k_mod) selects the polynomial for dimension d.
// basis[e, s] is the product of FUNC over all D = PSEUDO.size(1) dimensions;
// weight_index[e, s] is the mixed-radix kernel index, wrapped per dimension
// modulo KERNEL_SIZE[d].
//
// NOTE(review): S is derived from KERNEL_SIZE.size(0) while the inner loop
// runs over PSEUDO.size(1); the two are presumably equal - confirm at callers.
#define BASIS_FORWARD(M, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, FUNC)            \
  [&]() -> std::tuple<at::Tensor, at::Tensor> {                                \
    auto E = (PSEUDO).size(0), D = (PSEUDO).size(1);                           \
    /* + 0.5 guards against pow() landing just below the integer. */           \
    auto S = static_cast<int64_t>(                                             \
        std::pow((M) + 1, (KERNEL_SIZE).size(0)) + 0.5);                       \
    auto basis = at::empty({E, S}, (PSEUDO).options());                        \
    auto weight_index = at::empty({E, S}, (KERNEL_SIZE).options());            \
                                                                               \
    AT_DISPATCH_FLOATING_TYPES(                                                \
        (PSEUDO).scalar_type(), "basis_forward_" #M, [&] {                     \
          auto pseudo_data = (PSEUDO).DATA_PTR<scalar_t>();                    \
          auto kernel_size_data = (KERNEL_SIZE).DATA_PTR<int64_t>();           \
          auto is_open_spline_data = (IS_OPEN_SPLINE).DATA_PTR<uint8_t>();     \
          auto basis_data = basis.DATA_PTR<scalar_t>();                        \
          auto weight_index_data = weight_index.DATA_PTR<int64_t>();           \
                                                                               \
          for (int64_t e = 0; e < E; e++) {                                    \
            for (int64_t s = 0; s < S; s++) {                                  \
              int64_t k = s;                                                   \
              int64_t wi = 0;                                                  \
              int64_t wi_offset = 1;                                           \
              scalar_t b = 1;                                                  \
              for (int64_t d = 0; d < D; d++) {                                \
                /* d-th base-(M + 1) digit of s. */                            \
                const int64_t k_mod = k % ((M) + 1);                           \
                k /= (M) + 1;                                                  \
                                                                               \
                auto v = pseudo_data[e * (PSEUDO).stride(0) +                  \
                                     d * (PSEUDO).stride(1)];                  \
                v *= kernel_size_data[d] - (M) * is_open_spline_data[d];       \
                                                                               \
                /* Integer part of v picks the kernel slot for this dim. */    \
                wi += ((static_cast<int64_t>(v) + k_mod) %                     \
                       kernel_size_data[d]) *                                  \
                      wi_offset;                                               \
                wi_offset *= kernel_size_data[d];                              \
                                                                               \
                /* Fractional part of v is the polynomial argument. */         \
                v -= std::floor(v);                                            \
                b *= FUNC<scalar_t>(v, k_mod);                                 \
              }                                                                \
              basis_data[e * S + s] = b;                                       \
              weight_index_data[e * S + s] = wi;                               \
            }                                                                  \
          }                                                                    \
        });                                                                    \
    return std::make_tuple(basis, weight_index);                               \
  }()

// Forward entry points: one BASIS_FORWARD instantiation per degree (M = 1, 2, 3).
/// Forward pass for the degree-1 basis.
///
/// Instantiates BASIS_FORWARD with M = 1 and `linear` as the polynomial and
/// returns the resulting (basis, weight_index) pair.
std::tuple<at::Tensor, at::Tensor> linear_fw(at::Tensor pseudo,
                                             at::Tensor kernel_size,
                                             at::Tensor is_open_spline) {
  auto result = BASIS_FORWARD(1, pseudo, kernel_size, is_open_spline, linear);
  return result;
}

/// Forward pass for the degree-2 basis.
///
/// Instantiates BASIS_FORWARD with M = 2 and `quadratic` as the polynomial and
/// returns the resulting (basis, weight_index) pair.
std::tuple<at::Tensor, at::Tensor> quadratic_fw(at::Tensor pseudo,
                                                at::Tensor kernel_size,
                                                at::Tensor is_open_spline) {
  auto result =
      BASIS_FORWARD(2, pseudo, kernel_size, is_open_spline, quadratic);
  return result;
}

/// Forward pass for the degree-3 basis.
///
/// Instantiates BASIS_FORWARD with M = 3 and `cubic` as the polynomial and
/// returns the resulting (basis, weight_index) pair.
std::tuple<at::Tensor, at::Tensor>
cubic_fw(at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  auto result = BASIS_FORWARD(3, pseudo, kernel_size, is_open_spline, cubic);
  return result;
}

/// Derivative with respect to `v` of the degree-1 polynomial chosen by `k_mod`.
///
/// @param v      Unused; the derivative of a linear polynomial is constant.
/// @param k_mod  Polynomial selector.
/// @return       `2 * k_mod - 1`, i.e. -1 for k_mod == 0 and +1 for k_mod == 1.
template <typename scalar_t>
inline scalar_t grad_linear(scalar_t v, int64_t k_mod) {
  (void)v;
  const int64_t slope = 2 * k_mod - 1;
  return static_cast<scalar_t>(slope);
}

/// Derivative with respect to `v` of the degree-2 polynomial chosen by `k_mod`.
///
/// @param v      Evaluation point.
/// @param k_mod  Polynomial selector; values other than 0 and 1 select the
///               last polynomial.
/// @return       `v - 1`, `-2 * v + 1` or `v` for k_mod 0, 1 and otherwise.
template <typename scalar_t>
inline scalar_t grad_quadratic(scalar_t v, int64_t k_mod) {
  switch (k_mod) {
  case 0:
    return v - 1;
  case 1:
    return -2 * v + 1;
  default:
    return v;
  }
}

/// Derivative with respect to `v` of the degree-3 polynomial chosen by `k_mod`.
///
/// @param v      Evaluation point.
/// @param k_mod  Polynomial selector; values other than 0, 1 and 2 select the
///               last polynomial.
/// @return       The selected derivative evaluated at `v`.
template <typename scalar_t>
inline scalar_t grad_cubic(scalar_t v, int64_t k_mod) {
  switch (k_mod) {
  case 0:
    return (-v * v + 2 * v - 1) / 2;
  case 1:
    return (3 * v * v - 4 * v) / 2;
  case 2:
    return (-3 * v * v + 2 * v + 1) / 2;
  default:
    return v * v / 2;
  }
}

// Backward pass: gradient with respect to PSEUDO, built from FUNC and GRAD_FUNC.
// BASIS_BACKWARD(M, GRAD_BASIS, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, FUNC,
//                GRAD_FUNC)
//
// Expands to an immediately-invoked lambda returning an at::Tensor
// grad_pseudo of shape {E, D}, with E = PSEUDO.size(0) and D = PSEUDO.size(1).
//
//   M              - spline degree; also stringified into the dispatch name.
//   GRAD_BASIS     - tensor read as scalar_t at [e, s] through its strides;
//                    S = GRAD_BASIS.size(1).
//   PSEUDO         - floating-point tensor, read as [e, d] through its strides.
//   KERNEL_SIZE    - tensor read as int64_t, one entry per dimension d.
//   IS_OPEN_SPLINE - tensor read as uint8_t, one entry per dimension d.
//   FUNC           - basis polynomial, FUNC<scalar_t>(v, k_mod).
//   GRAD_FUNC      - its derivative, GRAD_FUNC<scalar_t>(v, k_mod).
//
// For every (e, d):
//   grad_pseudo[e, d] = scale_d * sum_s GRAD_BASIS[e, s]
//                       * GRAD_FUNC(frac_d, digit_d(s))
//                       * prod_{d' != d} FUNC(frac_d', digit_d'(s))
// where scale_d = KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d], frac_d is the
// fractional part of PSEUDO[e, d] * scale_d, and digit_d(s) is the d-th
// base-(M + 1) digit of s.
#define BASIS_BACKWARD(M, GRAD_BASIS, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE,     \
                       FUNC, GRAD_FUNC)                                        \
  [&]() -> at::Tensor {                                                        \
    auto E = (PSEUDO).size(0), D = (PSEUDO).size(1);                           \
    auto S = (GRAD_BASIS).size(1);                                             \
    auto grad_pseudo = at::empty({E, D}, (PSEUDO).options());                  \
                                                                               \
    AT_DISPATCH_FLOATING_TYPES(                                                \
        (PSEUDO).scalar_type(), "basis_backward_" #M, [&] {                    \
          auto grad_basis_data = (GRAD_BASIS).DATA_PTR<scalar_t>();            \
          auto pseudo_data = (PSEUDO).DATA_PTR<scalar_t>();                    \
          auto kernel_size_data = (KERNEL_SIZE).DATA_PTR<int64_t>();           \
          auto is_open_spline_data = (IS_OPEN_SPLINE).DATA_PTR<uint8_t>();     \
          auto grad_pseudo_data = grad_pseudo.DATA_PTR<scalar_t>();            \
                                                                               \
          for (int64_t e = 0; e < E; e++) {                                    \
            for (int64_t d = 0; d < D; d++) {                                  \
              scalar_t g = 0;                                                  \
              for (int64_t s = 0; s < S; s++) {                                \
                /* Derivative factor for the differentiated dimension d. */    \
                int64_t k_mod =                                                \
                    (s / static_cast<int64_t>(std::pow((M) + 1, d) + 0.5)) %   \
                    ((M) + 1);                                                 \
                auto v = pseudo_data[e * (PSEUDO).stride(0) +                  \
                                     d * (PSEUDO).stride(1)];                  \
                v *= kernel_size_data[d] - (M) * is_open_spline_data[d];       \
                v -= std::floor(v);                                            \
                scalar_t tmp = GRAD_FUNC<scalar_t>(v, k_mod);                  \
                                                                               \
                /* Plain basis factors for every other dimension: d_new */     \
                /* walks 0..D-1 while skipping d. */                           \
                for (int64_t d_it = 1; d_it < D; d_it++) {                     \
                  const int64_t d_new = d_it - (d >= d_it ? 1 : 0);            \
                  k_mod = (s / static_cast<int64_t>(                           \
                                   std::pow((M) + 1, d_new) + 0.5)) %          \
                          ((M) + 1);                                           \
                  v = pseudo_data[e * (PSEUDO).stride(0) +                     \
                                  d_new * (PSEUDO).stride(1)];                 \
                  v *= kernel_size_data[d_new] -                               \
                       (M) * is_open_spline_data[d_new];                       \
                  v -= std::floor(v);                                          \
                  tmp *= FUNC<scalar_t>(v, k_mod);                             \
                }                                                              \
                g += tmp * grad_basis_data[e * (GRAD_BASIS).stride(0) +        \
                                           s * (GRAD_BASIS).stride(1)];        \
              }                                                                \
              /* Chain rule for the scaling applied to the pseudo value. */    \
              g *= kernel_size_data[d] - (M) * is_open_spline_data[d];         \
              grad_pseudo_data[e * D + d] = g;                                 \
            }                                                                  \
          }                                                                    \
        });                                                                    \
    return grad_pseudo;                                                        \
  }()

/// Backward pass for the degree-1 basis.
///
/// Instantiates BASIS_BACKWARD with M = 1, `linear` as the polynomial and
/// `grad_linear` as its derivative, and returns the resulting tensor.
at::Tensor linear_bw(at::Tensor grad_basis, at::Tensor pseudo,
                     at::Tensor kernel_size, at::Tensor is_open_spline) {
  auto grad_pseudo = BASIS_BACKWARD(1, grad_basis, pseudo, kernel_size,
                                    is_open_spline, linear, grad_linear);
  return grad_pseudo;
}

/// Backward pass for the degree-2 basis.
///
/// Instantiates BASIS_BACKWARD with M = 2, `quadratic` as the polynomial and
/// `grad_quadratic` as its derivative, and returns the resulting tensor.
at::Tensor quadratic_bw(at::Tensor grad_basis, at::Tensor pseudo,
                        at::Tensor kernel_size, at::Tensor is_open_spline) {
  auto grad_pseudo = BASIS_BACKWARD(2, grad_basis, pseudo, kernel_size,
                                    is_open_spline, quadratic, grad_quadratic);
  return grad_pseudo;
}

/// Backward pass for the degree-3 basis.
///
/// Instantiates BASIS_BACKWARD with M = 3, `cubic` as the polynomial and
/// `grad_cubic` as its derivative, and returns the resulting tensor.
at::Tensor cubic_bw(at::Tensor grad_basis, at::Tensor pseudo,
                    at::Tensor kernel_size, at::Tensor is_open_spline) {
  auto grad_pseudo = BASIS_BACKWARD(3, grad_basis, pseudo, kernel_size,
                                    is_open_spline, cubic, grad_cubic);
  return grad_pseudo;
}

// Python bindings: registers the three forward and three backward entry
// points on the extension module under their C++ names.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Forward kernels.
  m.def("linear_fw", &linear_fw, "Linear Basis Forward (CPU)")
      .def("quadratic_fw", &quadratic_fw, "Quadratic Basis Forward (CPU)")
      .def("cubic_fw", &cubic_fw, "Cubic Basis Forward (CPU)");

  // Backward kernels.
  m.def("linear_bw", &linear_bw, "Linear Basis Backward (CPU)")
      .def("quadratic_bw", &quadratic_bw, "Quadratic Basis Backward (CPU)")
      .def("cubic_bw", &cubic_bw, "Cubic Basis Backward (CPU)");
}