weighting_kernel.cu 8.03 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include "atomics.cuh"

// Threads per block used for every kernel launch in this file.
#define THREADS 1024
// Number of THREADS-sized blocks needed to cover N elements (ceil-division).
// Both the argument and the whole expansion are parenthesized so the macro
// stays correct inside larger expressions and with compound arguments.
// NOTE: BLOCKS(0) == 0, and a zero-block launch is an invalid configuration;
// callers must skip the launch when there is no work.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Forward pass of the basis-weighted linear transform.
//
// For every output element (e, m_out) this computes
//   out[e, m_out] = sum_s basis[e, s] *
//                   sum_m_in weight[wi, m_in, m_out] * x[e, m_in]
// where wi = weight_index[e, s].
//
// Launch layout: 1-D grid. Each thread walks over output elements
// i = tid, tid + stride, ... (grid-stride loop), so any grid size >= 1 yields
// a correct result. numel is the number of output elements
// (out.sizes[0] * out.sizes[1]). All tensors are addressed through their
// strides, so non-contiguous inputs are handled correctly.
template <typename scalar_t>
__global__ void
weighting_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> x,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
                    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
                    size_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
  // evaluated in 32-bit unsigned arithmetic and wrap for huge launches.
  const int64_t index =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t n = static_cast<int64_t>(numel);
  const int64_t M_out = out.sizes[1], M_in = x.sizes[1], S = basis.sizes[1];

  for (int64_t i = index; i < n; i += stride) {
    const int64_t e = i / M_out, m_out = i % M_out;
    scalar_t v = 0;

    for (int64_t s = 0; s < S; s++) {
      // Use the real strides of basis / weight_index instead of assuming a
      // dense [E, S] layout.
      const scalar_t b =
          basis.data[e * basis.strides[0] + s * basis.strides[1]];
      const int64_t wi = weight_index.data[e * weight_index.strides[0] +
                                           s * weight_index.strides[1]];
      for (int64_t m_in = 0; m_in < M_in; m_in++) {
        scalar_t tmp =
            weight.data[wi * weight.strides[0] + m_in * weight.strides[1] +
                        m_out * weight.strides[2]];
        tmp *= b * x.data[e * x.strides[0] + m_in * x.strides[1]];
        v += tmp;
      }
    }
    out.data[e * out.strides[0] + m_out * out.strides[1]] = v;
  }
}

// Host entry point of the forward pass.
//
// The kernel indexes x as [E, M_in], weight as [K, M_in, M_out], and basis /
// weight_index as [E, S]. weight_index must hold int64 values
// (getTensorInfo<int64_t>). Returns a new [E, M_out] tensor with x's type.
// NOTE(review): the kernel is launched on the default stream and its launch
// status is not checked here.
at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
                             at::Tensor weight_index) {
  auto E = x.size(0), M_out = weight.size(2);
  auto out = at::empty({E, M_out}, x.type());

  // BLOCKS(0) == 0, and launching a grid with zero blocks is an invalid
  // configuration. Nothing has to be computed for an empty output.
  if (out.numel() == 0) {
    return out;
  }

  AT_DISPATCH_FLOATING_TYPES(out.type(), "weighting_fw", [&] {
    weighting_fw_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS>>>(
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(x),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
        out.numel());
  });
  return out;
}

// Backward pass of the weighting with respect to x.
//
// For every element (e, m_in) of grad_x this computes
//   grad_x[e, m_in] = sum_s basis[e, s] *
//                     sum_m_out weight[wi, m_out, m_in] * grad_out[e, m_out]
// where wi = weight_index[e, s]. Note that weight is indexed as
// [wi, m_out, m_in] here: weighting_bw_x_cuda passes the forward weight with
// its last two dimensions swapped.
//
// Launch layout: 1-D grid with a grid-stride loop over numel = number of
// grad_x elements. All tensors are addressed through their strides.
template <typename scalar_t>
__global__ void weighting_bw_x_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_x,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index, size_t numel) {
  // 64-bit thread index/stride to avoid 32-bit wrap-around on huge launches.
  const int64_t index =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t n = static_cast<int64_t>(numel);
  const int64_t M_in = grad_x.sizes[1], M_out = grad_out.sizes[1];
  const int64_t S = basis.sizes[1];

  for (int64_t i = index; i < n; i += stride) {
    const int64_t e = i / M_in, m_in = i % M_in;
    scalar_t v = 0;

    for (int64_t s = 0; s < S; s++) {
      const scalar_t b =
          basis.data[e * basis.strides[0] + s * basis.strides[1]];
      const int64_t wi = weight_index.data[e * weight_index.strides[0] +
                                           s * weight_index.strides[1]];
      for (int64_t m_out = 0; m_out < M_out; m_out++) {
        scalar_t tmp =
            weight.data[wi * weight.strides[0] + m_out * weight.strides[1] +
                        m_in * weight.strides[2]];
        tmp *= b * grad_out.data[e * grad_out.strides[0] +
                                 m_out * grad_out.strides[1]];
        v += tmp;
      }
    }
    grad_x.data[e * grad_x.strides[0] + m_in * grad_x.strides[1]] = v;
  }
}

// Host entry point of the gradient with respect to x.
//
// grad_out is indexed as [E, M_out] and weight as [K, M_in, M_out]. Returns a
// new [E, M_in] tensor with grad_out's type. weight_index must hold int64
// values.
at::Tensor weighting_bw_x_cuda(at::Tensor grad_out, at::Tensor weight,
                               at::Tensor basis, at::Tensor weight_index) {
  auto E = grad_out.size(0), M_in = weight.size(1);
  auto grad_x = at::empty({E, M_in}, grad_out.type());

  // A zero-block launch is an invalid configuration. Bail out before the
  // (then useless) transpose copy below.
  if (grad_x.numel() == 0) {
    return grad_x;
  }

  // Materialize weight as [K, M_out, M_in] so that m_in (which varies
  // fastest across neighbouring threads) has unit stride; presumably done
  // for coalesced reads. The kernel indexes weight in this swapped order.
  weight = weight.transpose(1, 2).contiguous();
  AT_DISPATCH_FLOATING_TYPES(grad_x.type(), "weighting_bw_x", [&] {
    weighting_bw_x_kernel<scalar_t><<<BLOCKS(grad_x.numel()), THREADS>>>(
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_x),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_out),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
        grad_x.numel());
  });
  return grad_x;
}

// Backward pass of the weighting with respect to weight.
//
// Each thread handles one element (e, m_out) of grad_out and scatters
//   grad_weight[wi, m_in, m_out] += grad_out[e, m_out] * basis[e, s] * x[e, m_in]
// for every s and m_in, with wi = weight_index[e, s]. Several edges can map
// to the same wi, hence atomicAdd. grad_weight must therefore be
// zero-initialized (weighting_bw_w_cuda allocates it with at::zeros).
// NOTE(review): atomicAdd on double needs SM60+; older devices presumably
// rely on an overload from "atomics.cuh" — confirm.
//
// Launch layout: 1-D grid with a grid-stride loop over numel = number of
// grad_out elements. All tensors are addressed through their strides.
template <typename scalar_t>
__global__ void weighting_bw_w_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_weight,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> x,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index, size_t numel) {
  // 64-bit thread index/stride to avoid 32-bit wrap-around on huge launches.
  const int64_t index =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t n = static_cast<int64_t>(numel);
  const int64_t S = basis.sizes[1], M_in = x.sizes[1],
                M_out = grad_out.sizes[1];

  for (int64_t i = index; i < n; i += stride) {
    const int64_t e = i / M_out, m_out = i % M_out;
    const scalar_t g =
        grad_out.data[e * grad_out.strides[0] + m_out * grad_out.strides[1]];

    for (int64_t s = 0; s < S; s++) {
      const scalar_t b =
          basis.data[e * basis.strides[0] + s * basis.strides[1]];
      const int64_t wi = weight_index.data[e * weight_index.strides[0] +
                                           s * weight_index.strides[1]];
      for (int64_t m_in = 0; m_in < M_in; m_in++) {
        const scalar_t v =
            g * b * x.data[e * x.strides[0] + m_in * x.strides[1]];
        atomicAdd(&grad_weight.data[wi * grad_weight.strides[0] +
                                    m_in * grad_weight.strides[1] +
                                    m_out * grad_weight.strides[2]],
                  v);
      }
    }
  }
}

// Host entry point of the gradient with respect to weight.
//
// grad_out is indexed as [E, M_out] and x as [E, M_in]. K is the number of
// weight slices. Returns a new zero-initialized [K, M_in, M_out] tensor into
// which the kernel accumulates. weight_index must hold int64 values.
at::Tensor weighting_bw_w_cuda(at::Tensor grad_out, at::Tensor x,
                               at::Tensor basis, at::Tensor weight_index,
                               int64_t K) {
  auto M_in = x.size(1), M_out = grad_out.size(1);
  auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.type());

  // With no grad_out elements there is nothing to scatter; the zero tensor
  // is already the correct gradient. This also avoids an invalid zero-block
  // launch (BLOCKS(0) == 0).
  if (grad_out.numel() == 0) {
    return grad_weight;
  }

  AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_w", [&] {
    weighting_bw_w_kernel<scalar_t><<<BLOCKS(grad_out.numel()), THREADS>>>(
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_weight),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_out),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(x),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
        grad_out.numel());
  });
  return grad_weight;
}

// Backward pass of the weighting with respect to basis.
//
// Each thread handles one element (e, m_out) of grad_out and, for every s,
// adds its partial sum
//   sum_m_in grad_out[e, m_out] * weight[wi, m_in, m_out] * x[e, m_in]
// (wi = weight_index[e, s]) to grad_basis[e, s]. All threads that share an
// edge e contribute to the same entries, hence atomicAdd. grad_basis must
// therefore be zero-initialized (weighting_bw_b_cuda allocates it with
// at::zeros).
//
// Launch layout: 1-D grid with a grid-stride loop over numel = number of
// grad_out elements. All tensors are addressed through their strides.
template <typename scalar_t>
__global__ void weighting_bw_b_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> x,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index, size_t numel) {
  // 64-bit thread index/stride to avoid 32-bit wrap-around on huge launches.
  const int64_t index =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t n = static_cast<int64_t>(numel);
  const int64_t S = grad_basis.sizes[1], M_in = x.sizes[1],
                M_out = grad_out.sizes[1];

  for (int64_t i = index; i < n; i += stride) {
    const int64_t e = i / M_out, m_out = i % M_out;
    const scalar_t g =
        grad_out.data[e * grad_out.strides[0] + m_out * grad_out.strides[1]];

    for (int64_t s = 0; s < S; s++) {
      scalar_t v = 0;
      const int64_t wi = weight_index.data[e * weight_index.strides[0] +
                                           s * weight_index.strides[1]];
      for (int64_t m_in = 0; m_in < M_in; m_in++) {
        const scalar_t w =
            weight.data[wi * weight.strides[0] + m_in * weight.strides[1] +
                        m_out * weight.strides[2]];
        v += g * w * x.data[e * x.strides[0] + m_in * x.strides[1]];
      }
      atomicAdd(&grad_basis.data[e * grad_basis.strides[0] +
                                 s * grad_basis.strides[1]],
                v);
    }
  }
}

// Host entry point of the gradient with respect to basis.
//
// grad_out is indexed as [E, M_out], x as [E, M_in], weight as
// [K, M_in, M_out], and weight_index as [E, S] (int64 values). Returns a new
// zero-initialized [E, S] tensor into which the kernel accumulates.
at::Tensor weighting_bw_b_cuda(at::Tensor grad_out, at::Tensor x,
                               at::Tensor weight, at::Tensor weight_index) {
  auto E = x.size(0), S = weight_index.size(1);
  auto grad_basis = at::zeros({E, S}, grad_out.type());

  // With no grad_out elements every contribution is zero, so the zero tensor
  // is already the result. This also avoids an invalid zero-block launch
  // (BLOCKS(0) == 0).
  if (grad_out.numel() == 0) {
    return grad_basis;
  }

  AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_b", [&] {
    weighting_bw_b_kernel<scalar_t><<<BLOCKS(grad_out.numel()), THREADS>>>(
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_basis),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_out),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(x),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight),
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
        grad_out.numel());
  });
  return grad_basis;
}