"model.properties" did not exist on "584fbdd5decfe74fd29d9497f8306c7c0c727074"
weighting_cuda.cu 8.7 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
#include "weighting_cuda.h"
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
4
#include <ATen/cuda/CUDAContext.h>

rusty1s's avatar
rusty1s committed
5
#include "atomics.cuh"
rusty1s's avatar
rusty1s committed
6
7
#include "utils.cuh"

rusty1s's avatar
rusty1s committed
8
9
10
11
12
13
14
#define THREADS 1024
// Ceil-division launch helper: number of blocks needed to cover N elements.
// Argument and full expansion are parenthesized so that expressions like
// BLOCKS(a + b) or x * BLOCKS(n) expand correctly.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Forward kernel: one thread per element (e, m_out) of the E x M_out output.
//
//   out[e, m_out] = sum_s sum_{m_in}
//       basis[e, s] * x[e, m_in] * weight[weight_index[e, s], m_in, m_out]
//
// Expects a 1D launch with at least `numel` (= E * M_out) threads; excess
// threads fall through the bounds guard.
template <typename scalar_t>
__global__ void
spline_weighting_fw_kernel(const scalar_t *x, const scalar_t *weight,
                           const scalar_t *basis, const int64_t *weight_index,
                           scalar_t *out, int64_t E, int64_t M_in,
                           int64_t M_out, int64_t S, int64_t numel) {

  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated
  // in 32-bit unsigned arithmetic and can overflow on very large launches.
  const int64_t thread_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t e = thread_idx / M_out;
  const int64_t m_out = thread_idx % M_out;

  if (thread_idx < numel) {
    scalar_t v = (scalar_t)0.;

    // int64_t loop index (was ptrdiff_t) for consistency with the other
    // kernels in this file.
    for (int64_t s = 0; s < S; s++) {
      const scalar_t b = basis[e * S + s];
      const int64_t wi = weight_index[e * S + s];
      for (int64_t m_in = 0; m_in < M_in; m_in++) {
        scalar_t tmp = weight[wi * M_in * M_out + m_in * M_out + m_out];
        tmp *= b * x[e * M_in + m_in];
        v += tmp;
      }
    }
    out[thread_idx] = v;
  }
}

rusty1s's avatar
rusty1s committed
38
39
40
// Host entry point for the forward pass: validates devices and shapes,
// derives sizes, and launches spline_weighting_fw_kernel on the current
// CUDA stream. Returns a freshly allocated [E, M_out] tensor.
torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
                                       torch::Tensor basis,
                                       torch::Tensor weight_index) {
  CHECK_CUDA(x);
  CHECK_CUDA(weight);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  cudaSetDevice(x.get_device());

  // Input feature dimension of x must match dim 1 of the weight tensor.
  CHECK_INPUT(x.size(1) == weight.size(1));

  const auto E = x.size(0);          // rows of x
  const auto M_in = x.size(1);       // input feature channels
  const auto M_out = weight.size(2); // output feature channels
  const auto S = basis.size(1);      // basis/weight-index pairs per row

  auto out = at::empty({E, M_out}, x.options());

  auto stream = at::cuda::getCurrentCUDAStream();
  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
    spline_weighting_fw_kernel<scalar_t>
        <<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
            x.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
            basis.data_ptr<scalar_t>(), weight_index_data,
            out.data_ptr<scalar_t>(), E, M_in, M_out, S, out.numel());
  });

  return out;
}

// Backward kernel w.r.t. the input features x: one thread per element
// (e, m_in) of the E x M_in gradient matrix.
//
// NOTE: `weight` is expected here in transposed [K, M_out, M_in] layout
// (the host wrapper transposes it) so the inner loop reads contiguously.
template <typename scalar_t>
__global__ void
spline_weighting_bw_x_kernel(const scalar_t *grad_out, const scalar_t *weight,
                             const scalar_t *basis, const int64_t *weight_index,
                             scalar_t *grad_x, int64_t E, int64_t M_in,
                             int64_t M_out, int64_t S, int64_t numel) {

  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (thread_idx >= numel)
    return;

  const int64_t e = thread_idx / M_in;
  const int64_t m_in = thread_idx % M_in;

  scalar_t acc = (scalar_t)0.;
  for (int64_t s = 0; s < S; s++) {
    const scalar_t b = basis[e * S + s];
    const int64_t wi = weight_index[e * S + s];

    for (int64_t m_out = 0; m_out < M_out; m_out++) {
      // Same grouping as the original: (b * grad) applied to the weight.
      acc += weight[wi * M_out * M_in + m_out * M_in + m_in] *
             (b * grad_out[e * M_out + m_out]);
    }
  }

  grad_x[thread_idx] = acc;
}

// Host entry point for the input-feature gradient: transposes `weight` to
// [K, M_out, M_in] for contiguous kernel reads and launches
// spline_weighting_bw_x_kernel. Returns a new [E, M_in] tensor.
torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
                                         torch::Tensor weight,
                                         torch::Tensor basis,
                                         torch::Tensor weight_index) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(weight);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  cudaSetDevice(grad_out.get_device());

  // Output feature dimension of grad_out must match dim 2 of weight.
  CHECK_INPUT(grad_out.size(1) == weight.size(2));

  const auto E = grad_out.size(0);
  const auto M_in = weight.size(1);
  const auto M_out = grad_out.size(1);
  const auto S = basis.size(1);

  auto grad_x = at::zeros({E, M_in}, grad_out.options());
  weight = weight.transpose(1, 2).contiguous(); // Contiguous memory-access.

  auto stream = at::cuda::getCurrentCUDAStream();
  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
    spline_weighting_bw_x_kernel<scalar_t>
        <<<BLOCKS(grad_x.numel()), THREADS, 0, stream>>>(
            grad_out.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
            basis.data_ptr<scalar_t>(), weight_index_data,
            grad_x.data_ptr<scalar_t>(), E, M_in, M_out, S, grad_x.numel());
  });

  return grad_x;
}

rusty1s's avatar
rusty1s committed
140
// Backward kernel w.r.t. the weights: one thread per (e, m_out) entry of
// grad_out. Each thread scatters its contributions into grad_weight with
// atomAdd, since multiple rows can map onto the same weight_index entry.
template <typename scalar_t>
__global__ void spline_weighting_bw_weight_kernel(
    const scalar_t *grad_out, const scalar_t *x, const scalar_t *basis,
    const int64_t *weight_index, scalar_t *grad_weight, int64_t E, int64_t M_in,
    int64_t M_out, int64_t S, int64_t numel) {

  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (thread_idx >= numel)
    return;

  const int64_t e = thread_idx / M_out;
  const int64_t m_out = thread_idx % M_out;
  const scalar_t g = grad_out[e * M_out + m_out];

  for (int64_t s = 0; s < S; s++) {
    const scalar_t b = basis[e * S + s];
    const int64_t wi = weight_index[e * S + s];

    for (int64_t m_in = 0; m_in < M_in; m_in++) {
      const scalar_t v = g * b * x[e * M_in + m_in];
      atomAdd(&grad_weight[wi * M_in * M_out + m_in * M_out + m_out], v);
    }
  }
}

rusty1s's avatar
rusty1s committed
164
165
166
167
168
// Host entry point for the weight gradient: allocates a zero-initialized
// [kernel_size, M_in, M_out] tensor and scatter-accumulates into it via
// spline_weighting_bw_weight_kernel.
torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
                                              torch::Tensor x,
                                              torch::Tensor basis,
                                              torch::Tensor weight_index,
                                              int64_t kernel_size) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  cudaSetDevice(grad_out.get_device());

  // Shape-consistency check (matching the other wrappers in this file):
  // the kernel indexes both grad_out and x by the same row index `e`, so
  // they must agree on the number of rows.
  CHECK_INPUT(x.size(0) == grad_out.size(0));

  auto E = grad_out.size(0);
  auto M_in = x.size(1);
  auto M_out = grad_out.size(1);
  auto S = basis.size(1);

  // Zero-initialized because the kernel accumulates with atomAdd.
  auto grad_weight = at::zeros({kernel_size, M_in, M_out}, grad_out.options());

  auto weight_index_data = weight_index.data_ptr<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto grad_weight_data = grad_weight.data_ptr<scalar_t>();

    spline_weighting_bw_weight_kernel<scalar_t>
        <<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
            grad_out_data, x_data, basis_data, weight_index_data,
            grad_weight_data, E, M_in, M_out, S, grad_out.numel());
  });

  return grad_weight;
}

// Backward kernel w.r.t. the basis values: one thread per (e, m_out) entry
// of grad_out.
//
//   grad_basis[e, s] += sum_{m_in}
//       grad_out[e, m_out] * weight[weight_index[e, s], m_in, m_out]
//                          * x[e, m_in]
//
// All M_out threads of a row accumulate into the same grad_basis[e, s]
// slot, hence the atomAdd.
template <typename scalar_t>
__global__ void spline_weighting_bw_basis_kernel(
    const scalar_t *grad_out, const scalar_t *x, const scalar_t *weight,
    const int64_t *weight_index, scalar_t *grad_basis, int64_t E, int64_t M_in,
    int64_t M_out, int64_t S, int64_t numel) {

  // int64_t (was size_t) for consistency with the other kernels and to avoid
  // a signed/unsigned comparison against `numel` below; widened before the
  // multiply to dodge 32-bit overflow on very large launches.
  const int64_t thread_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t e = thread_idx / M_out;
  const int64_t m_out = thread_idx % M_out;

  if (thread_idx < numel) {
    const scalar_t g = grad_out[e * M_out + m_out];

    for (int64_t s = 0; s < S; s++) {
      scalar_t v = (scalar_t)0.;
      const int64_t wi = weight_index[e * S + s];

      for (int64_t m_in = 0; m_in < M_in; m_in++) {
        const scalar_t w = weight[wi * M_in * M_out + m_in * M_out + m_out];
        v += g * w * x[e * M_in + m_in];
      }
      atomAdd(&grad_basis[e * S + s], v);
    }
  }
}

// Host entry point for the basis gradient: validates devices and shapes,
// allocates a zeroed [E, S] tensor, and launches
// spline_weighting_bw_basis_kernel on the current CUDA stream.
torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
                                             torch::Tensor x,
                                             torch::Tensor weight,
                                             torch::Tensor weight_index) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(weight);
  CHECK_CUDA(weight_index);
  cudaSetDevice(grad_out.get_device());

  // x and grad_out must match the weight tensor's in/out feature dims.
  CHECK_INPUT(x.size(1) == weight.size(1));
  CHECK_INPUT(grad_out.size(1) == weight.size(2));

  const auto E = grad_out.size(0);
  const auto M_in = x.size(1);
  const auto M_out = grad_out.size(1);
  const auto S = weight_index.size(1);

  // Zero-initialized because the kernel accumulates with atomAdd.
  auto grad_basis = at::zeros({E, S}, grad_out.options());

  auto stream = at::cuda::getCurrentCUDAStream();
  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
    spline_weighting_bw_basis_kernel<scalar_t>
        <<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
            grad_out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(), weight_index_data,
            grad_basis.data_ptr<scalar_t>(), E, M_in, M_out, S,
            grad_out.numel());
  });

  return grad_basis;
}