weighting_cuda.cu 8.83 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
#include "weighting_cuda.h"
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
4
#include <ATen/cuda/CUDAContext.h>

rusty1s's avatar
rusty1s committed
5
6
#include "utils.cuh"

// Launch configuration used by the kernels below: a 1-D grid of THREADS-thread
// blocks, one thread per output element.
#define THREADS 1024
// Ceil-division: number of THREADS-sized blocks needed to cover N elements.
// Both the argument and the full expression are parenthesized so the macro
// composes safely inside larger expressions (e.g. `3 * BLOCKS(n)`).
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Forward spline weighting.
//
// For every row e and output channel m_out computes
//   out[e, m_out] = sum_s basis[e, s] *
//                   sum_m_in x[e, m_in] * weight[weight_index[e, s], m_in, m_out]
//
// All buffers are indexed as dense row-major arrays:
//   x            [E, M_in]
//   weight       [K, M_in, M_out]   (K = number of weight slots)
//   basis        [E, S]
//   weight_index [E, S]
//   out          [E, M_out]
//
// Launch: 1-D grid with at least `numel` (= E * M_out) threads, one thread per
// element of `out`; surplus threads exit immediately.
template <typename scalar_t>
__global__ void
spline_weighting_fw_kernel(const scalar_t *x, const scalar_t *weight,
                           const scalar_t *basis, const int64_t *weight_index,
                           scalar_t *out, int64_t E, int64_t M_in,
                           int64_t M_out, int64_t S, int64_t numel) {
  // Widen before multiplying so the flat index cannot wrap in 32-bit
  // unsigned arithmetic on very large launches.
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (thread_idx >= numel)
    return;

  const int64_t e = thread_idx / M_out;
  const int64_t m_out = thread_idx % M_out;

  scalar_t v = (scalar_t)0.;
  for (int64_t s = 0; s < S; s++) {
    const scalar_t b = basis[e * S + s];
    const int64_t wi = weight_index[e * S + s];
    for (int64_t m_in = 0; m_in < M_in; m_in++) {
      scalar_t tmp = weight[wi * M_in * M_out + m_in * M_out + m_out];
      tmp *= b * x[e * M_in + m_in];
      v += tmp;
    }
  }
  out[thread_idx] = v;
}

// Host entry point of the forward spline weighting.
//
// Shapes (as read below): x [E, M_in], weight [K, M_in, M_out],
// basis [E, S], weight_index [E, S] (int64). All tensors must be CUDA tensors.
// Returns out [E, M_out] with the dtype/device of x.
torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
                                       torch::Tensor basis,
                                       torch::Tensor weight_index) {
  CHECK_CUDA(x);
  CHECK_CUDA(weight);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  cudaSetDevice(x.get_device());

  CHECK_INPUT(x.size(1) == weight.size(1));

  // The kernel indexes every tensor as a dense row-major buffer, so
  // materialise strided views first (no-op for already contiguous tensors).
  x = x.contiguous();
  weight = weight.contiguous();
  basis = basis.contiguous();
  weight_index = weight_index.contiguous();

  auto E = x.size(0);
  auto M_in = x.size(1);
  auto M_out = weight.size(2);
  auto S = basis.size(1);

  auto out = at::empty({E, M_out}, x.options());
  // Launching a grid of zero blocks is an invalid configuration; there is
  // nothing to compute for an empty output anyway.
  if (out.numel() == 0)
    return out;

  auto weight_index_data = weight_index.data_ptr<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
    auto x_data = x.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    spline_weighting_fw_kernel<scalar_t>
        <<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
            x_data, weight_data, basis_data, weight_index_data, out_data, E,
            M_in, M_out, S, out.numel());
  });

  return out;
}

// Gradient of the spline weighting w.r.t. x.
//
// For every row e and input channel m_in computes
//   grad_x[e, m_in] = sum_s basis[e, s] *
//                     sum_m_out grad_out[e, m_out] *
//                               weight[weight_index[e, s], m_out, m_in]
//
// NOTE: `weight` must be the TRANSPOSED weight, i.e. a dense [K, M_out, M_in]
// buffer (the host wrapper passes weight.transpose(1, 2).contiguous()).
// Element (wi, m_out, m_in) therefore lives at
//   wi * M_out * M_in + m_out * M_in + m_in.
// Other buffers: grad_out [E, M_out], basis/weight_index [E, S],
// grad_x [E, M_in], all dense row-major.
//
// Launch: 1-D grid with at least `numel` (= E * M_in) threads, one thread per
// element of grad_x; surplus threads exit immediately.
template <typename scalar_t>
__global__ void
spline_weighting_bw_x_kernel(const scalar_t *grad_out, const scalar_t *weight,
                             const scalar_t *basis, const int64_t *weight_index,
                             scalar_t *grad_x, int64_t E, int64_t M_in,
                             int64_t M_out, int64_t S, int64_t numel) {
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (thread_idx >= numel)
    return;

  const int64_t e = thread_idx / M_in;
  const int64_t m_in = thread_idx % M_in;

  scalar_t v = (scalar_t)0.;
  for (int64_t s = 0; s < S; s++) {
    const scalar_t b = basis[e * S + s];
    const int64_t wi = weight_index[e * S + s];

    for (int64_t m_out = 0; m_out < M_out; m_out++) {
      // Row stride of the transposed [K, M_out, M_in] weight is M_in.
      scalar_t tmp = weight[wi * M_out * M_in + m_out * M_in + m_in];
      tmp *= b * grad_out[e * M_out + m_out];
      v += tmp;
    }
  }
  grad_x[thread_idx] = v;
}

// Host entry point for the gradient w.r.t. x.
//
// Shapes (as read below): grad_out [E, M_out], weight [K, M_in, M_out],
// basis [E, S], weight_index [E, S] (int64). All tensors must be CUDA tensors.
// Returns grad_x [E, M_in] with the dtype/device of grad_out.
torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
                                         torch::Tensor weight,
                                         torch::Tensor basis,
                                         torch::Tensor weight_index) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(weight);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  cudaSetDevice(grad_out.get_device());

  CHECK_INPUT(grad_out.size(1) == weight.size(2));

  // The kernel indexes every tensor as a dense row-major buffer.
  grad_out = grad_out.contiguous();
  basis = basis.contiguous();
  weight_index = weight_index.contiguous();

  auto E = grad_out.size(0);
  auto M_in = weight.size(1);
  auto M_out = grad_out.size(1);
  auto S = basis.size(1);

  auto grad_x = at::zeros({E, M_in}, grad_out.options());
  // Launching a grid of zero blocks is an invalid configuration.
  if (grad_x.numel() == 0)
    return grad_x;

  // The kernel expects a dense [K, M_out, M_in] weight so that the inner
  // m_out loop walks it with a fixed stride.
  weight = weight.transpose(1, 2).contiguous();

  auto weight_index_data = weight_index.data_ptr<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto grad_x_data = grad_x.data_ptr<scalar_t>();

    spline_weighting_bw_x_kernel<scalar_t>
        <<<BLOCKS(grad_x.numel()), THREADS, 0, stream>>>(
            grad_out_data, weight_data, basis_data, weight_index_data,
            grad_x_data, E, M_in, M_out, S, grad_x.numel());
  });

  return grad_x;
}

// Gradient of the spline weighting w.r.t. the weight.
//
// Accumulates, for every row e, slot s, input channel m_in and output
// channel m_out:
//   grad_weight[weight_index[e, s], m_in, m_out] +=
//       grad_out[e, m_out] * basis[e, s] * x[e, m_in]
//
// Buffers are dense row-major: grad_out [E, M_out], x [E, M_in],
// basis/weight_index [E, S], grad_weight [K, M_in, M_out]. grad_weight must be
// zero-initialised by the caller because contributions are added atomically
// (different rows may map to the same weight slot).
//
// Launch: 1-D grid with at least `numel` (= E * M_out) threads, one thread per
// element of grad_out; surplus threads exit immediately.
// atomicAdd on double requires compute capability 6.0+.
template <typename scalar_t>
__global__ void spline_weighting_bw_weight_kernel(
    const scalar_t *grad_out, const scalar_t *x, const scalar_t *basis,
    const int64_t *weight_index, scalar_t *grad_weight, int64_t E,
    int64_t M_in, int64_t M_out, int64_t S, int64_t numel) {
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (thread_idx >= numel)
    return;

  const int64_t e = thread_idx / M_out;
  const int64_t m_out = thread_idx % M_out;

  const scalar_t g = grad_out[e * M_out + m_out];
  for (int64_t s = 0; s < S; s++) {
    const scalar_t b = basis[e * S + s];
    const int64_t wi = weight_index[e * S + s];

    for (int64_t m_in = 0; m_in < M_in; m_in++) {
      const scalar_t v = g * b * x[e * M_in + m_in];
      atomicAdd(&grad_weight[wi * M_in * M_out + m_in * M_out + m_out], v);
    }
  }
}

// Host entry point for the gradient w.r.t. the weight.
//
// Shapes (as read below): grad_out [E, M_out], x [E, M_in], basis [E, S],
// weight_index [E, S] (int64). All tensors must be CUDA tensors.
// Returns grad_weight [kernel_size, M_in, M_out] with grad_out's dtype/device.
// NOTE(review): values of weight_index are assumed to lie in
// [0, kernel_size); this is not checked here.
torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
                                              torch::Tensor x,
                                              torch::Tensor basis,
                                              torch::Tensor weight_index,
                                              int64_t kernel_size) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  cudaSetDevice(grad_out.get_device());

  // The kernel indexes every tensor as a dense row-major buffer.
  grad_out = grad_out.contiguous();
  x = x.contiguous();
  basis = basis.contiguous();
  weight_index = weight_index.contiguous();

  auto E = grad_out.size(0);
  auto M_in = x.size(1);
  auto M_out = grad_out.size(1);
  auto S = basis.size(1);

  // Must start at zero: the kernel accumulates into it with atomicAdd.
  auto grad_weight = at::zeros({kernel_size, M_in, M_out}, grad_out.options());
  // Launching a grid of zero blocks is an invalid configuration; with no
  // upstream gradient the weight gradient is all zeros.
  if (grad_out.numel() == 0)
    return grad_weight;

  auto weight_index_data = weight_index.data_ptr<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto grad_weight_data = grad_weight.data_ptr<scalar_t>();

    spline_weighting_bw_weight_kernel<scalar_t>
        <<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
            grad_out_data, x_data, basis_data, weight_index_data,
            grad_weight_data, E, M_in, M_out, S, grad_out.numel());
  });

  return grad_weight;
}

// Gradient of the spline weighting w.r.t. the basis.
//
// Accumulates, for every row e, slot s and output channel m_out:
//   grad_basis[e, s] += grad_out[e, m_out] *
//                       sum_m_in weight[weight_index[e, s], m_in, m_out] *
//                                x[e, m_in]
//
// Buffers are dense row-major: grad_out [E, M_out], x [E, M_in],
// weight [K, M_in, M_out], weight_index [E, S], grad_basis [E, S].
// grad_basis must be zero-initialised by the caller because the M_out threads
// of a row all add into the same grad_basis[e, s] atomically.
//
// Launch: 1-D grid with at least `numel` (= E * M_out) threads, one thread per
// element of grad_out; surplus threads exit immediately.
// atomicAdd on double requires compute capability 6.0+.
template <typename scalar_t>
__global__ void
spline_weighting_bw_basis_kernel(const scalar_t *grad_out, const scalar_t *x,
                                 const scalar_t *weight,
                                 const int64_t *weight_index,
                                 scalar_t *grad_basis, int64_t E, int64_t M_in,
                                 int64_t M_out, int64_t S, int64_t numel) {
  // Signed 64-bit index so the comparison with `numel` is well-defined and the
  // product cannot wrap in 32-bit unsigned arithmetic.
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (thread_idx >= numel)
    return;

  const int64_t e = thread_idx / M_out;
  const int64_t m_out = thread_idx % M_out;
  const scalar_t g = grad_out[e * M_out + m_out];

  for (int64_t s = 0; s < S; s++) {
    scalar_t v = (scalar_t)0.;
    const int64_t wi = weight_index[e * S + s];

    for (int64_t m_in = 0; m_in < M_in; m_in++) {
      const scalar_t w = weight[wi * M_in * M_out + m_in * M_out + m_out];
      v += g * w * x[e * M_in + m_in];
    }
    atomicAdd(&grad_basis[e * S + s], v);
  }
}

// Host entry point for the gradient w.r.t. the basis.
//
// Shapes (as read below): grad_out [E, M_out], x [E, M_in],
// weight [K, M_in, M_out], weight_index [E, S] (int64). All tensors must be
// CUDA tensors. Returns grad_basis [E, S] with grad_out's dtype/device.
torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
                                             torch::Tensor x,
                                             torch::Tensor weight,
                                             torch::Tensor weight_index) {
  // These tensors are read by a CUDA kernel, so they must live on the GPU.
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(weight);
  CHECK_CUDA(weight_index);
  cudaSetDevice(grad_out.get_device());

  CHECK_INPUT(x.size(1) == weight.size(1));
  CHECK_INPUT(grad_out.size(1) == weight.size(2));

  // The kernel indexes every tensor as a dense row-major buffer.
  grad_out = grad_out.contiguous();
  x = x.contiguous();
  weight = weight.contiguous();
  weight_index = weight_index.contiguous();

  auto E = grad_out.size(0);
  auto M_in = x.size(1);
  auto M_out = grad_out.size(1);
  auto S = weight_index.size(1);

  // Must start at zero: the kernel accumulates into it with atomicAdd.
  auto grad_basis = at::zeros({E, S}, grad_out.options());
  // Launching a grid of zero blocks is an invalid configuration; with no
  // upstream gradient the basis gradient is all zeros.
  if (grad_out.numel() == 0)
    return grad_basis;

  auto weight_index_data = weight_index.data_ptr<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto grad_basis_data = grad_basis.data_ptr<scalar_t>();

    spline_weighting_bw_basis_kernel<scalar_t>
        <<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
            grad_out_data, x_data, weight_data, weight_index_data,
            grad_basis_data, E, M_in, M_out, S, grad_out.numel());
  });

  return grad_basis;
}