#include "weighting_cuda.h"

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "utils.cuh"

// Launch configuration shared by every kernel in this file: fixed-size 1-D
// blocks and a ceil-div grid that covers at least N threads.
#define THREADS 1024
// Fully parenthesized (argument and result) so the macro composes safely with
// compound arguments and inside larger expressions, e.g. BLOCKS(a << b) or
// 1 / BLOCKS(n).
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Forward spline-weighting kernel. One thread per element of `out`
// (row-major [E, M_out], numel == E * M_out):
//   out[e, m_out] = sum_s basis[e, s] *
//       sum_m_in weight[weight_index[e, s], m_in, m_out] * x[e, m_in]
// Buffers are indexed as contiguous row-major: x [E, M_in],
// weight [K, M_in, M_out], basis and weight_index [E, S].
// Any 1-D launch covering at least `numel` threads is valid; surplus threads
// exit immediately.
template <typename scalar_t>
__global__ void
spline_weighting_fw_kernel(const scalar_t *__restrict__ x,
                           const scalar_t *__restrict__ weight,
                           const scalar_t *__restrict__ basis,
                           const int64_t *__restrict__ weight_index,
                           scalar_t *__restrict__ out, int64_t E, int64_t M_in,
                           int64_t M_out, int64_t S, int64_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x is evaluated in 32-bit
  // unsigned arithmetic and would wrap once numel exceeds 2^32.
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (thread_idx >= numel)
    return;

  const int64_t e = thread_idx / M_out;
  const int64_t m_out = thread_idx % M_out;

  scalar_t v = (scalar_t)0.;
  for (int64_t s = 0; s < S; s++) {
    const scalar_t b = basis[e * S + s];
    const int64_t wi = weight_index[e * S + s];
    for (int64_t m_in = 0; m_in < M_in; m_in++) {
      const scalar_t w = weight[wi * M_in * M_out + m_in * M_out + m_out];
      v += w * (b * x[e * M_in + m_in]);
    }
  }
  out[thread_idx] = v;
}

// Host entry point for the forward spline weighting:
//   out[e, m_out] = sum_s basis[e, s] *
//       sum_m_in x[e, m_in] * weight[weight_index[e, s], m_in, m_out]
//
// Inputs (shapes validated below):
//   x            [E, M_in]          CUDA, float or double
//   weight       [K, M_in, M_out]   same dtype as x
//   basis        [E, S]             same dtype as x
//   weight_index [E, S]             int64
// The entries of weight_index are used as indices into weight's first
// dimension and are NOT range-checked here; presumably the caller guarantees
// they lie in [0, K) — TODO confirm.
// Returns a newly allocated [E, M_out] tensor with x's options.
// Throws (via CHECK_CUDA / CHECK_INPUT / data_ptr / AT_CUDA_CHECK) on non-CUDA
// inputs, shape or dtype mismatches, or a failed kernel launch.
torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
                                       torch::Tensor basis,
                                       torch::Tensor weight_index) {
  CHECK_CUDA(x);
  CHECK_CUDA(weight);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  // Scoped device switch: unlike a bare cudaSetDevice() it restores the
  // caller's current device when this function returns.
  c10::cuda::CUDAGuard device_guard(x.device());

  CHECK_INPUT(x.dim() == 2 && weight.dim() == 3);
  CHECK_INPUT(x.size(1) == weight.size(1));
  // The kernel indexes basis and weight_index as [E, S]; a mismatch would read
  // out of bounds.
  CHECK_INPUT(basis.dim() == 2 && basis.size(0) == x.size(0));
  CHECK_INPUT(weight_index.sizes() == basis.sizes());

  const auto E = x.size(0);
  const auto M_in = x.size(1);
  const auto M_out = weight.size(2);
  const auto S = basis.size(1);

  auto out = at::empty({E, M_out}, x.options());
  const auto numel = out.numel();
  // Nothing to compute; also avoids an invalid zero-block launch.
  if (numel == 0)
    return out;

  // The kernel indexes raw row-major buffers, so strided views must be
  // materialized first (no copy when already contiguous).
  x = x.contiguous();
  weight = weight.contiguous();
  basis = basis.contiguous();
  weight_index = weight_index.contiguous();

  auto weight_index_data = weight_index.data_ptr<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
    auto x_data = x.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    spline_weighting_fw_kernel<scalar_t>
        <<<BLOCKS(numel), THREADS, 0, stream>>>(
            x_data, weight_data, basis_data, weight_index_data, out_data, E,
            M_in, M_out, S, numel);
    // Launches do not return errors; surface a bad launch here rather than
    // at some unrelated later CUDA call.
    AT_CUDA_CHECK(cudaGetLastError());
  });

  return out;
}

// Backward kernel w.r.t. x. One thread per element of grad_x
// (row-major [E, M_in], numel == E * M_in):
//   grad_x[e, m_in] = sum_s basis[e, s] *
//       sum_m_out weight[weight_index[e, s], m_out, m_in] * grad_out[e, m_out]
// NOTE: `weight` is indexed in the TRANSPOSED layout [K, M_out, M_in]
// (spline_weighting_bw_x_cuda passes weight.transpose(1, 2).contiguous()),
// so for a fixed e and s, threads with consecutive m_in read consecutive
// addresses. grad_out is [E, M_out]; basis and weight_index are [E, S].
// Any 1-D launch covering at least `numel` threads is valid.
template <typename scalar_t>
__global__ void
spline_weighting_bw_x_kernel(const scalar_t *__restrict__ grad_out,
                             const scalar_t *__restrict__ weight,
                             const scalar_t *__restrict__ basis,
                             const int64_t *__restrict__ weight_index,
                             scalar_t *__restrict__ grad_x, int64_t E,
                             int64_t M_in, int64_t M_out, int64_t S,
                             int64_t numel) {
  // Widen before multiplying to avoid 32-bit wrap-around for huge numel.
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (thread_idx >= numel)
    return;

  const int64_t e = thread_idx / M_in;
  const int64_t m_in = thread_idx % M_in;

  scalar_t v = (scalar_t)0.;
  for (int64_t s = 0; s < S; s++) {
    const scalar_t b = basis[e * S + s];
    const int64_t wi = weight_index[e * S + s];
    for (int64_t m_out = 0; m_out < M_out; m_out++) {
      const scalar_t w = weight[wi * M_out * M_in + m_out * M_in + m_in];
      v += w * (b * grad_out[e * M_out + m_out]);
    }
  }
  grad_x[thread_idx] = v;
}

// Host entry point for the gradient of the spline weighting w.r.t. x:
//   grad_x[e, m_in] = sum_s basis[e, s] *
//       sum_m_out weight[weight_index[e, s], m_in, m_out] * grad_out[e, m_out]
//
// Inputs (shapes validated below):
//   grad_out     [E, M_out]         CUDA, float or double
//   weight       [K, M_in, M_out]   same dtype as grad_out
//   basis        [E, S]             same dtype as grad_out
//   weight_index [E, S]             int64 (entries not range-checked here)
// Returns a newly allocated [E, M_in] tensor with grad_out's options.
// Throws on non-CUDA inputs, shape or dtype mismatches, or a failed launch.
torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
                                         torch::Tensor weight,
                                         torch::Tensor basis,
                                         torch::Tensor weight_index) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(weight);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  // Scoped device switch that restores the caller's device on return.
  c10::cuda::CUDAGuard device_guard(grad_out.device());

  CHECK_INPUT(grad_out.dim() == 2 && weight.dim() == 3);
  CHECK_INPUT(grad_out.size(1) == weight.size(2));
  // basis / weight_index are indexed as [E, S] by the kernel.
  CHECK_INPUT(basis.dim() == 2 && basis.size(0) == grad_out.size(0));
  CHECK_INPUT(weight_index.sizes() == basis.sizes());

  const auto E = grad_out.size(0);
  const auto M_in = weight.size(1);
  const auto M_out = grad_out.size(1);
  const auto S = basis.size(1);

  // The kernel writes every element, so no zero-fill is needed.
  auto grad_x = at::empty({E, M_in}, grad_out.options());
  const auto numel = grad_x.numel();
  // Nothing to compute; also avoids an invalid zero-block launch.
  if (numel == 0)
    return grad_x;

  // Raw row-major indexing in the kernel requires contiguous buffers.
  grad_out = grad_out.contiguous();
  // Re-layout weight as [K, M_out, M_in] so threads with consecutive m_in
  // read consecutive addresses in the kernel's inner loop.
  weight = weight.transpose(1, 2).contiguous();
  basis = basis.contiguous();
  weight_index = weight_index.contiguous();

  auto weight_index_data = weight_index.data_ptr<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto grad_x_data = grad_x.data_ptr<scalar_t>();

    spline_weighting_bw_x_kernel<scalar_t>
        <<<BLOCKS(numel), THREADS, 0, stream>>>(
            grad_out_data, weight_data, basis_data, weight_index_data,
            grad_x_data, E, M_in, M_out, S, numel);
    // Surface launch-configuration errors immediately.
    AT_CUDA_CHECK(cudaGetLastError());
  });

  return grad_x;
}

// Backward kernel w.r.t. weight. One thread per element (e, m_out) of
// grad_out (row-major [E, M_out], numel == E * M_out). For every s and m_in
// the thread scatters
//   grad_weight[weight_index[e, s], m_in, m_out] +=
//       grad_out[e, m_out] * basis[e, s] * x[e, m_in]
// into grad_weight ([K, M_in, M_out]). Different rows e may share a
// weight_index, hence atomicAdd. Consequences:
// - grad_weight must be zero-initialized by the caller.
// - The floating-point summation order is not deterministic.
// NOTE: atomicAdd on double requires compute capability >= 6.0, and
// AT_DISPATCH_FLOATING_TYPES instantiates this kernel for double.
template <typename scalar_t>
__global__ void spline_weighting_bw_weight_kernel(
    const scalar_t *__restrict__ grad_out, const scalar_t *__restrict__ x,
    const scalar_t *__restrict__ basis,
    const int64_t *__restrict__ weight_index, scalar_t *grad_weight,
    int64_t E, int64_t M_in, int64_t M_out, int64_t S, int64_t numel) {
  // Widen before multiplying to avoid 32-bit wrap-around for huge numel.
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (thread_idx >= numel)
    return;

  const int64_t e = thread_idx / M_out;
  const int64_t m_out = thread_idx % M_out;

  const scalar_t g = grad_out[e * M_out + m_out];
  for (int64_t s = 0; s < S; s++) {
    const scalar_t b = basis[e * S + s];
    const int64_t wi = weight_index[e * S + s];
    // Loop-invariant; identical to the left-associative g * b * x[...].
    const scalar_t gb = g * b;
    for (int64_t m_in = 0; m_in < M_in; m_in++) {
      atomicAdd(&grad_weight[wi * M_in * M_out + m_in * M_out + m_out],
                gb * x[e * M_in + m_in]);
    }
  }
}

// Host entry point for the gradient of the spline weighting w.r.t. weight:
//   grad_weight[k, m_in, m_out] = sum over (e, s) with weight_index[e, s] == k
//       of grad_out[e, m_out] * basis[e, s] * x[e, m_in]
//
// Inputs (shapes validated below):
//   grad_out     [E, M_out]   CUDA, float or double
//   x            [E, M_in]    same dtype as grad_out
//   basis        [E, S]       same dtype as grad_out
//   weight_index [E, S]       int64
//   kernel_size  K, the first dimension of the returned gradient.
// The entries of weight_index are not range-checked here; presumably the
// caller guarantees they lie in [0, kernel_size) — TODO confirm.
// Returns a newly allocated [K, M_in, M_out] tensor with grad_out's options.
// Throws on non-CUDA inputs, shape or dtype mismatches, or a failed launch.
torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
                                              torch::Tensor x,
                                              torch::Tensor basis,
                                              torch::Tensor weight_index,
                                              int64_t kernel_size) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  // Scoped device switch that restores the caller's device on return.
  c10::cuda::CUDAGuard device_guard(grad_out.device());

  CHECK_INPUT(grad_out.dim() == 2 && x.dim() == 2);
  CHECK_INPUT(x.size(0) == grad_out.size(0));
  // basis / weight_index are indexed as [E, S] by the kernel.
  CHECK_INPUT(basis.dim() == 2 && basis.size(0) == grad_out.size(0));
  CHECK_INPUT(weight_index.sizes() == basis.sizes());

  const auto E = grad_out.size(0);
  const auto M_in = x.size(1);
  const auto M_out = grad_out.size(1);
  const auto S = basis.size(1);

  // Zero-initialized: the kernel accumulates into it with atomicAdd.
  auto grad_weight = at::zeros({kernel_size, M_in, M_out}, grad_out.options());
  const auto numel = grad_out.numel();
  // No contributions to scatter; also avoids an invalid zero-block launch.
  if (numel == 0)
    return grad_weight;

  // Raw row-major indexing in the kernel requires contiguous buffers.
  grad_out = grad_out.contiguous();
  x = x.contiguous();
  basis = basis.contiguous();
  weight_index = weight_index.contiguous();

  auto weight_index_data = weight_index.data_ptr<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto grad_weight_data = grad_weight.data_ptr<scalar_t>();

    spline_weighting_bw_weight_kernel<scalar_t>
        <<<BLOCKS(numel), THREADS, 0, stream>>>(
            grad_out_data, x_data, basis_data, weight_index_data,
            grad_weight_data, E, M_in, M_out, S, numel);
    // Surface launch-configuration errors immediately.
    AT_CUDA_CHECK(cudaGetLastError());
  });

  return grad_weight;
}

// Backward kernel w.r.t. basis. One thread per element (e, m_out) of grad_out
// (row-major [E, M_out], numel == E * M_out). For every s the thread computes
//   sum_m_in grad_out[e, m_out] * weight[weight_index[e, s], m_in, m_out]
//            * x[e, m_in]
// and atomically adds it into grad_basis[e, s] ([E, S]). All M_out threads of
// a row e contribute to the same S entries. Consequences:
// - grad_basis must be zero-initialized by the caller.
// - The summation order is not deterministic.
// weight is indexed as [K, M_in, M_out] and x as [E, M_in].
// NOTE: atomicAdd on double requires compute capability >= 6.0.
template <typename scalar_t>
__global__ void spline_weighting_bw_basis_kernel(
    const scalar_t *__restrict__ grad_out, const scalar_t *__restrict__ x,
    const scalar_t *__restrict__ weight,
    const int64_t *__restrict__ weight_index, scalar_t *grad_basis,
    int64_t E, int64_t M_in, int64_t M_out, int64_t S, int64_t numel) {
  // Signed 64-bit like the other kernels; widen before multiplying to avoid
  // 32-bit wrap-around for huge numel.
  const int64_t thread_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (thread_idx >= numel)
    return;

  const int64_t e = thread_idx / M_out;
  const int64_t m_out = thread_idx % M_out;

  const scalar_t g = grad_out[e * M_out + m_out];
  for (int64_t s = 0; s < S; s++) {
    const int64_t wi = weight_index[e * S + s];
    scalar_t v = (scalar_t)0.;
    for (int64_t m_in = 0; m_in < M_in; m_in++) {
      const scalar_t w = weight[wi * M_in * M_out + m_in * M_out + m_out];
      v += g * w * x[e * M_in + m_in];
    }
    atomicAdd(&grad_basis[e * S + s], v);
  }
}

// Host entry point for the gradient of the spline weighting w.r.t. basis:
//   grad_basis[e, s] = sum_m_out sum_m_in grad_out[e, m_out]
//       * weight[weight_index[e, s], m_in, m_out] * x[e, m_in]
//
// Inputs (shapes validated below):
//   grad_out     [E, M_out]         CUDA, float or double
//   x            [E, M_in]          same dtype as grad_out
//   weight       [K, M_in, M_out]   same dtype as grad_out
//   weight_index [E, S]             int64 (entries not range-checked here)
// Returns a newly allocated [E, S] tensor with grad_out's options.
// Throws on non-CUDA inputs, shape or dtype mismatches, or a failed launch.
torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
                                             torch::Tensor x,
                                             torch::Tensor weight,
                                             torch::Tensor weight_index) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(weight);
  CHECK_CUDA(weight_index);
  // Scoped device switch that restores the caller's device on return.
  c10::cuda::CUDAGuard device_guard(grad_out.device());

  CHECK_INPUT(grad_out.dim() == 2 && x.dim() == 2 && weight.dim() == 3);
  CHECK_INPUT(x.size(1) == weight.size(1));
  CHECK_INPUT(grad_out.size(1) == weight.size(2));
  // x and weight_index are indexed per row e of grad_out by the kernel.
  CHECK_INPUT(x.size(0) == grad_out.size(0));
  CHECK_INPUT(weight_index.dim() == 2 &&
              weight_index.size(0) == grad_out.size(0));

  const auto E = grad_out.size(0);
  const auto M_in = x.size(1);
  const auto M_out = grad_out.size(1);
  const auto S = weight_index.size(1);

  // Zero-initialized: the kernel accumulates into it with atomicAdd.
  auto grad_basis = at::zeros({E, S}, grad_out.options());
  const auto numel = grad_out.numel();
  // No contributions to accumulate; also avoids an invalid zero-block launch.
  if (numel == 0)
    return grad_basis;

  // Raw row-major indexing in the kernel requires contiguous buffers.
  grad_out = grad_out.contiguous();
  x = x.contiguous();
  weight = weight.contiguous();
  weight_index = weight_index.contiguous();

  auto weight_index_data = weight_index.data_ptr<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto grad_basis_data = grad_basis.data_ptr<scalar_t>();

    spline_weighting_bw_basis_kernel<scalar_t>
        <<<BLOCKS(numel), THREADS, 0, stream>>>(
            grad_out_data, x_data, weight_data, weight_index_data,
            grad_basis_data, E, M_in, M_out, S, numel);
    // Surface launch-configuration errors immediately.
    AT_CUDA_CHECK(cudaGetLastError());
  });

  return grad_basis;
}