encoding_kernel.cu 12.5 KB
Newer Older
Hang Zhang's avatar
Hang Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
#include <ATen/ATen.h>
#include <vector>

#include "common.h"
#include "device_tensor.h"

namespace {

// Per-element term of the aggregation
//   E[b][k][d] = sum_i A[b][i][k] * (X[b][i][d] - C[k][d]).
// Each operand is converted to Acctype *before* any arithmetic, so when
// Acctype is wider than DType the difference and product are formed at
// accumulation precision instead of being rounded to DType first.
template<typename DType, typename Acctype>
struct AggOp {
  __device__ AggOp(DeviceTensor<DType, 3> a,
                   DeviceTensor<DType, 3> x,
                   DeviceTensor<DType, 2> c) : A(a), X(x), C(c) {}
  __device__ __forceinline__ Acctype operator()(int b, int i, int k, int d) {
    const Acctype weight = ScalarConvert<DType, Acctype>::to(A[b][i][k]);
    const Acctype xv = ScalarConvert<DType, Acctype>::to(X[b][i][d]);
    const Acctype cv = ScalarConvert<DType, Acctype>::to(C[k][d]);
    return weight * (xv - cv);
  }
  DeviceTensor<DType, 3> A;
  DeviceTensor<DType, 3> X;
  DeviceTensor<DType, 2> C;
};

// Per-element term of the assignment-weight gradient
//   GA[b][i][k] = sum_d G[b][k][d] * (X[b][i][d] - C[k][d]).
// Operands are promoted to Acctype before the subtraction and multiply so
// that a wider accumulator type actually buys precision.
template<typename DType, typename Acctype>
struct AggBackOp {
  __device__ AggBackOp(DeviceTensor<DType, 3> g,
                       DeviceTensor<DType, 3> x,
                       DeviceTensor<DType, 2> c) : G(g), X(x), C(c) {}
  __device__ __forceinline__ Acctype operator()(int b, int i, int k, int d) {
    const Acctype grad = ScalarConvert<DType, Acctype>::to(G[b][k][d]);
    const Acctype xv = ScalarConvert<DType, Acctype>::to(X[b][i][d]);
    const Acctype cv = ScalarConvert<DType, Acctype>::to(C[k][d]);
    return grad * (xv - cv);
  }
  DeviceTensor<DType, 3> G;
  DeviceTensor<DType, 3> X;
  DeviceTensor<DType, 2> C;
};

// Per-element squared residual (X[b][i][d] - C[k][d])^2, summed over d by the
// caller to form the (unscaled) squared L2 distance between X[b][i] and C[k].
// The residual is squared in Acctype, not DType, so a narrow DType cannot
// overflow or lose precision in r*r before accumulation.
template<typename DType, typename Acctype>
struct SL2Op {
  __device__ SL2Op(DeviceTensor<DType, 3> x,
                   DeviceTensor<DType, 2> c) : X(x), C(c) {}
  __device__ __forceinline__ Acctype operator()(int b, int i, int k, int d)
  {
    const Acctype r = ScalarConvert<DType, Acctype>::to(X[b][i][d]) -
                      ScalarConvert<DType, Acctype>::to(C[k][d]);
    return r * r;
  }
  DeviceTensor<DType, 3> X;
  DeviceTensor<DType, 2> C;
};

// Per-element term 2 * S[k] * GSL[b][i][k] * (X[b][i][d] - C[k][d]): the
// derivative of S[k] * (X[b][i][d] - C[k][d])^2 w.r.t. X[b][i][d], weighted by
// the incoming gradient GSL. Summed over k it gives dX; negated and summed over
// (b, i) it gives dC.
// All factors are promoted to Acctype before multiplying (same left-to-right
// order as before), so a wider accumulator is not defeated by DType rounding.
template<typename DType, typename Acctype>
struct SL2GradXOp {
  __device__ SL2GradXOp(
    DeviceTensor<DType, 3> gsl,
    DeviceTensor<DType, 3> x,
    DeviceTensor<DType, 2> c,
    DeviceTensor<DType, 1> s
  ) : GSL(gsl), X(x), C(c), S(s) {}
  __device__ __forceinline__ Acctype operator()(int b, int i, int k, int d)
  {
    const Acctype scale = ScalarConvert<DType, Acctype>::to(S[k]);
    const Acctype grad = ScalarConvert<DType, Acctype>::to(GSL[b][i][k]);
    const Acctype xv = ScalarConvert<DType, Acctype>::to(X[b][i][d]);
    const Acctype cv = ScalarConvert<DType, Acctype>::to(C[k][d]);
    return Acctype(2) * scale * grad * (xv - cv);
  }
  DeviceTensor<DType, 3> GSL;
  DeviceTensor<DType, 3> X;
  DeviceTensor<DType, 2> C;
  DeviceTensor<DType, 1> S;
};


// Block-wide sum of op(b, i, k, d) over i in [0, N) for a fixed (b, k, d).
//
// Every thread of the block must call this, because it contains
// __syncthreads(). All threads get the same total back. Each thread first
// accumulates a strided slice of i (step blockDim.x). Lanes are then combined
// with warpSum, and one total per warp is staged in a 32-slot __shared__
// buffer. Warp 0 folds those slots into the final value, which lane 0
// publishes in slot 0.
// NOTE(review): the zero-fill of unused slots uses floor(blockDim.x /
// WARP_SIZE), so this presumably expects blockDim.x to be a multiple of
// WARP_SIZE — confirm against getNumThreads().
template<typename T, typename Op>
__device__ T reduceN(
    Op op, int b, int k, int d, int N) {
  __shared__ T warpTotals[32];
  const int lane = threadIdx.x % WARP_SIZE;
  const int warp = threadIdx.x / WARP_SIZE;

  // Per-thread partial over a blockDim.x-strided subset of i.
  T partial = 0;
  int i = threadIdx.x;
  while (i < N) {
    partial += op(b, i, k, d);
    i += blockDim.x;
  }
  partial = warpSum(partial);

  // Make sure no thread is still reading warpTotals from an earlier use.
  __syncthreads();
  if (lane == 0 && warp < 32) {
    warpTotals[warp] = partial;
  }
  if (warp == 0 && threadIdx.x >= blockDim.x / WARP_SIZE) {
    // Slots with no contributing warp must read as zero in the final fold.
    warpTotals[threadIdx.x] = (T) 0;
  }
  __syncthreads();

  if (warp == 0) {
    const T total = warpSum(warpTotals[lane]);
    if (lane == 0) {
      warpTotals[0] = total;
    }
  }
  __syncthreads();

  // Broadcast: every thread reads the block total.
  return warpTotals[0];
}

// Block-wide sum of op(b, i, k, d) over d in [0, D) for a fixed (b, i, k).
//
// Every thread of the block must call this, because it contains
// __syncthreads(). All threads get the same total back. Each thread first
// accumulates a strided slice of d (step blockDim.x). Lanes are then combined
// with warpSum, and one total per warp is staged in a 32-slot __shared__
// buffer. Warp 0 folds those slots into the final value, which lane 0
// publishes in slot 0.
// NOTE(review): the zero-fill of unused slots uses floor(blockDim.x /
// WARP_SIZE), so this presumably expects blockDim.x to be a multiple of
// WARP_SIZE — confirm against getNumThreads().
template<typename T, typename Op>
__device__ T reduceD(
    Op op, int b, int i, int k, int D) {
  __shared__ T warpTotals[32];
  const int lane = threadIdx.x % WARP_SIZE;
  const int warp = threadIdx.x / WARP_SIZE;

  // Per-thread partial over a blockDim.x-strided subset of d.
  T partial = 0;
  int d = threadIdx.x;
  while (d < D) {
    partial += op(b, i, k, d);
    d += blockDim.x;
  }
  partial = warpSum(partial);

  // Make sure no thread is still reading warpTotals from an earlier use.
  __syncthreads();
  if (lane == 0 && warp < 32) {
    warpTotals[warp] = partial;
  }
  if (warp == 0 && threadIdx.x >= blockDim.x / WARP_SIZE) {
    // Slots with no contributing warp must read as zero in the final fold.
    warpTotals[threadIdx.x] = (T) 0;
  }
  __syncthreads();

  if (warp == 0) {
    const T total = warpSum(warpTotals[lane]);
    if (lane == 0) {
      warpTotals[0] = total;
    }
  }
  __syncthreads();

  // Broadcast: every thread reads the block total.
  return warpTotals[0];
}

// Block-wide sum of op(b, i, k, d) over k in [0, K) for a fixed (b, i, d).
//
// Every thread of the block must call this, because it contains
// __syncthreads(). All threads get the same total back. Each thread first
// accumulates a strided slice of k (step blockDim.x). Lanes are then combined
// with warpSum, and one total per warp is staged in a 32-slot __shared__
// buffer. Warp 0 folds those slots into the final value, which lane 0
// publishes in slot 0.
// NOTE(review): the zero-fill of unused slots uses floor(blockDim.x /
// WARP_SIZE), so this presumably expects blockDim.x to be a multiple of
// WARP_SIZE — confirm against getNumThreads().
template<typename T, typename Op>
__device__ T reduceK(
    Op op, int b, int i, int d, int K) {
  __shared__ T warpTotals[32];
  const int lane = threadIdx.x % WARP_SIZE;
  const int warp = threadIdx.x / WARP_SIZE;

  // Per-thread partial over a blockDim.x-strided subset of k.
  T partial = 0;
  int k = threadIdx.x;
  while (k < K) {
    partial += op(b, i, k, d);
    k += blockDim.x;
  }
  partial = warpSum(partial);

  // Make sure no thread is still reading warpTotals from an earlier use.
  __syncthreads();
  if (lane == 0 && warp < 32) {
    warpTotals[warp] = partial;
  }
  if (warp == 0 && threadIdx.x >= blockDim.x / WARP_SIZE) {
    // Slots with no contributing warp must read as zero in the final fold.
    warpTotals[threadIdx.x] = (T) 0;
  }
  __syncthreads();

  if (warp == 0) {
    const T total = warpSum(warpTotals[lane]);
    if (lane == 0) {
      warpTotals[0] = total;
    }
  }
  __syncthreads();

  // Broadcast: every thread reads the block total.
  return warpTotals[0];
}

// Block-wide sum of op(b, i, k, d) over all b in [0, B) and i in [0, N) for a
// fixed (k, d).
//
// Every thread of the block must call this, because it contains
// __syncthreads(). All threads get the same total back. For each batch index
// b, every thread accumulates a strided slice of i (step blockDim.x). Lanes
// are then combined with warpSum, and one total per warp is staged in a
// 32-slot __shared__ buffer. Warp 0 folds those slots into the final value,
// which lane 0 publishes in slot 0.
// NOTE(review): the zero-fill of unused slots uses floor(blockDim.x /
// WARP_SIZE), so this presumably expects blockDim.x to be a multiple of
// WARP_SIZE — confirm against getNumThreads().
template<typename T, typename Op>
__device__ T reduceBN(
    Op op,
    int k, int d, int B, int N) {
  __shared__ T warpTotals[32];
  const int lane = threadIdx.x % WARP_SIZE;
  const int warp = threadIdx.x / WARP_SIZE;

  // Per-thread partial: all batches, blockDim.x-strided subset of i.
  T partial = 0;
  for (int b = 0; b < B; ++b) {
    int i = threadIdx.x;
    while (i < N) {
      partial += op(b, i, k, d);
      i += blockDim.x;
    }
  }
  partial = warpSum(partial);

  // Make sure no thread is still reading warpTotals from an earlier use.
  __syncthreads();
  if (lane == 0 && warp < 32) {
    warpTotals[warp] = partial;
  }
  if (warp == 0 && threadIdx.x >= blockDim.x / WARP_SIZE) {
    // Slots with no contributing warp must read as zero in the final fold.
    warpTotals[threadIdx.x] = (T) 0;
  }
  __syncthreads();

  if (warp == 0) {
    const T total = warpSum(warpTotals[lane]);
    if (lane == 0) {
      warpTotals[0] = total;
    }
  }
  __syncthreads();

  // Broadcast: every thread reads the block total.
  return warpTotals[0];
}

// E[b][k][d] = sum_i A[b][i][k] * (X[b][i][d] - C[k][d]).
// Launch layout: grid = (D, K, B), one block per output element; the block's
// threads split the sum over i (N = X.getSize(1)) inside reduceN.
template<typename DType, typename Acctype>
__global__ void Aggregate_Forward_kernel (
    DeviceTensor<DType, 3> E,
    DeviceTensor<DType, 3> A,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C) {
  const int d = blockIdx.x;
  const int k = blockIdx.y;
  const int b = blockIdx.z;
  const int numPoints = X.getSize(1);

  AggOp<DType, Acctype> weightedResidual(A, X, C);
  E[b][k][d] = reduceN<Acctype>(weightedResidual, b, k, d, numPoints);
}

// GA[b][i][k] = sum_d GE[b][k][d] * (X[b][i][d] - C[k][d]).
// Launch layout: grid = (K, N, B), one block per output element; the block's
// threads split the sum over d (D = GE.getSize(2)) inside reduceD.
// A is part of the signature but is not read here.
template<typename DType, typename Acctype>
__global__ void Aggregate_Backward_kernel (
    DeviceTensor<DType, 3> GA,
    DeviceTensor<DType, 3> GE,
    DeviceTensor<DType, 3> A,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C) {
  const int k = blockIdx.x;
  const int i = blockIdx.y;
  const int b = blockIdx.z;
  const int featureDim = GE.getSize(2);

  AggBackOp<DType, Acctype> gradTerm(GE, X, C);
  GA[b][i][k] = reduceD<Acctype>(gradTerm, b, i, k, featureDim);
}

// SL[b][i][k] = S[k] * sum_d (X[b][i][d] - C[k][d])^2.
// Launch layout: grid = (K, N, B), one block per output element; the block's
// threads split the sum over d (D = X.getSize(2)) inside reduceD.
template<typename DType, typename Acctype>
__global__ void ScaledL2_Forward_kernel (
    DeviceTensor<DType, 3> SL,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C,
    DeviceTensor<DType, 1> S) {
  const int k = blockIdx.x;
  const int i = blockIdx.y;
  const int b = blockIdx.z;
  const int featureDim = X.getSize(2);

  SL2Op<DType, Acctype> squaredResidual(X, C);
  const Acctype sqDist =
      reduceD<Acctype>(squaredResidual, b, i, k, featureDim);
  SL[b][i][k] = S[k] * sqDist;
}

// GX[b][i][d] = sum_k 2 * S[k] * GSL[b][i][k] * (X[b][i][d] - C[k][d]).
// Launch layout: grid = (D, N, B), one block per output element; the block's
// threads split the sum over k (K = C.getSize(0)) inside reduceK.
template<typename DType, typename Acctype>
__global__ void ScaledL2_GradX_kernel (
    DeviceTensor<DType, 3> GSL,
    DeviceTensor<DType, 3> GX,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C,
    DeviceTensor<DType, 1> S) {
  const int d = blockIdx.x;
  const int i = blockIdx.y;
  const int b = blockIdx.z;
  const int numCodes = C.getSize(0);

  SL2GradXOp<DType, Acctype> gradTerm(GSL, X, C, S);
  GX[b][i][d] = reduceK<Acctype>(gradTerm, b, i, d, numCodes);
}

// GC[k][d] = -sum_{b,i} 2 * S[k] * GSL[b][i][k] * (X[b][i][d] - C[k][d]).
// Launch layout: 2-D grid = (D, K), one block per output element; the block's
// threads split the sum over i and loop over every batch b inside reduceBN.
template<typename DType, typename Acctype>
__global__ void ScaledL2_GradC_kernel (
    DeviceTensor<DType, 3> GSL,
    DeviceTensor<DType, 2> GC,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C,
    DeviceTensor<DType, 1> S) {
  const int d = blockIdx.x;
  const int k = blockIdx.y;
  const int batchSize = X.getSize(0);
  const int numPoints = X.getSize(1);

  // Same per-element term as dX, but C enters with a minus sign.
  SL2GradXOp<DType, Acctype> gradTerm(GSL, X, C, S);
  const Acctype total =
      reduceBN<Acctype>(gradTerm, k, d, batchSize, numPoints);
  GC[k][d] = -total;
}

}// namespace

// Host entry for the aggregation forward pass:
//   E[b][k][d] = sum_i A[b][i][k] * (X[b][i][d] - C[k][d])
// The kernel's indexing implies A is (B, N, K), X is (B, N, D) and C is (K, D).
// Returns E of shape (B, K, D) with A's dtype, computed on the current CUDA
// stream.
at::Tensor Aggregate_Forward_CUDA(
    const at::Tensor A_,
    const at::Tensor X_,
    const at::Tensor C_) {
  // NOTE(review): devicetensor() wraps each tensor for nested [] indexing in
  // the kernel and presumably assumes a dense row-major layout (confirm in
  // device_tensor.h). Force contiguity so permuted/sliced inputs are read
  // correctly; contiguous() returns the tensor itself when already contiguous.
  const at::Tensor A_c = A_.contiguous();
  const at::Tensor X_c = X_.contiguous();
  const at::Tensor C_c = C_.contiguous();

  /* Output: (B, K, D) */
  auto E_ = A_c.type().tensor({A_c.size(0), C_c.size(0), C_c.size(1)}).zero_();
  cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
  // One block per output element: grid = (D, K, B); threads split the sum over N.
  dim3 blocks(C_c.size(1), C_c.size(0), X_c.size(0));
  dim3 threads(getNumThreads(X_c.size(1)));

  AT_DISPATCH_FLOATING_TYPES(A_c.type(), "Aggregate_Forward_CUDA", ([&] {
    DeviceTensor<scalar_t, 3> E = devicetensor<scalar_t, 3>(E_);
    DeviceTensor<scalar_t, 3> A = devicetensor<scalar_t, 3>(A_c);
    DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_c);
    DeviceTensor<scalar_t, 2> C = devicetensor<scalar_t, 2>(C_c);
    /* kernel function */
    Aggregate_Forward_kernel<scalar_t, scalar_t>
      <<<blocks, threads, 0, stream>>>(E, A, X, C);
  }));
  // Catches launch-configuration errors (e.g. a grid dimension over the limit).
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return E_;
}

// Backward of the aggregation E[b][k][d] = sum_i A[b][i][k] * (X[b][i][d] - C[k][d]).
// Given GE = dLoss/dE, returns {gradA, gradX, gradC}:
//   gradA[b][i][k] = sum_d GE[b][k][d] * (X[b][i][d] - C[k][d])   (CUDA kernel)
//   gradX          = bmm(A, GE)                                    (ATen)
//   gradC[k][d]    = -sum_b GE[b][k][d] * sum_i A[b][i][k]         (ATen)
std::vector<at::Tensor> Aggregate_Backward_CUDA(
    const at::Tensor GE_,
    const at::Tensor A_,
    const at::Tensor X_,
    const at::Tensor C_) {
  // NOTE(review): devicetensor() presumably assumes a dense row-major layout
  // (confirm in device_tensor.h). GE_ comes from autograd and may well be
  // non-contiguous, so normalize every kernel input. contiguous() is free for
  // tensors that already are contiguous.
  const at::Tensor GE_c = GE_.contiguous();
  const at::Tensor A_c = A_.contiguous();
  const at::Tensor X_c = X_.contiguous();
  const at::Tensor C_c = C_.contiguous();

  auto gradA_ = at::zeros_like(A_c);
  auto gradX_ = at::bmm(A_c, GE_c);
  auto gradC_ = (-GE_c * A_c.sum(1).unsqueeze(2)).sum(0);
  cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
  // One block per gradA element: grid = (K, N, B); threads split the sum over D.
  dim3 blocks(C_c.size(0), X_c.size(1), X_c.size(0));
  dim3 threads(getNumThreads(C_c.size(1)));
  AT_DISPATCH_FLOATING_TYPES(A_c.type(), "Aggregate_Backward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> GA = devicetensor<scalar_t, 3>(gradA_);
    DeviceTensor<scalar_t, 3> GE = devicetensor<scalar_t, 3>(GE_c);
    DeviceTensor<scalar_t, 3> A = devicetensor<scalar_t, 3>(A_c);
    DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_c);
    DeviceTensor<scalar_t, 2> C = devicetensor<scalar_t, 2>(C_c);
    Aggregate_Backward_kernel<scalar_t, scalar_t>
      <<<blocks, threads, 0, stream>>> (GA, GE, A, X, C);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return {gradA_, gradX_, gradC_};
}

// Host entry for the scaled squared-L2 distance:
//   SL[b][i][k] = S[k] * sum_d (X[b][i][d] - C[k][d])^2
// The kernel's indexing implies X is (B, N, D), C is (K, D) and S is (K).
// Returns SL of shape (B, N, K) with X's dtype, computed on the current CUDA
// stream.
at::Tensor ScaledL2_Forward_CUDA(
    const at::Tensor X_,
    const at::Tensor C_,
    const at::Tensor S_) {
  // NOTE(review): devicetensor() presumably assumes a dense row-major layout
  // (confirm in device_tensor.h); normalize inputs so strided views are read
  // correctly. contiguous() is free for tensors that already are contiguous.
  const at::Tensor X_c = X_.contiguous();
  const at::Tensor C_c = C_.contiguous();
  const at::Tensor S_c = S_.contiguous();

  auto SL_ = X_c.type().tensor({X_c.size(0), X_c.size(1), C_c.size(0)}).zero_();
  cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
  // One block per output element: grid = (K, N, B); threads split the sum over D.
  dim3 blocks(C_c.size(0), X_c.size(1), X_c.size(0));
  dim3 threads(getNumThreads(C_c.size(1)));

  AT_DISPATCH_FLOATING_TYPES(X_c.type(), "ScaledL2_Forward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> SL = devicetensor<scalar_t, 3>(SL_);
    DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_c);
    DeviceTensor<scalar_t, 2> C = devicetensor<scalar_t, 2>(C_c);
    DeviceTensor<scalar_t, 1> S = devicetensor<scalar_t, 1>(S_c);
    /* kernel function */
    ScaledL2_Forward_kernel<scalar_t, scalar_t>
      <<<blocks, threads, 0, stream>>> (SL, X, C, S);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return SL_;
}

// Backward of SL[b][i][k] = S[k] * sum_d (X[b][i][d] - C[k][d])^2.
// Given GSL = dLoss/dSL, returns {GX, GC, GS}:
//   GX[b][i][d] =  sum_k     2 * S[k] * GSL[b][i][k] * (X[b][i][d] - C[k][d])
//   GC[k][d]    = -sum_{b,i} 2 * S[k] * GSL[b][i][k] * (X[b][i][d] - C[k][d])
//   GS[k]       =  sum_{b,i} GSL[b][i][k] * SL[b][i][k] / S[k]
// SL_ is presumably the forward output for these same X, C, S; the GS formula
// relies on that.
std::vector<at::Tensor> ScaledL2_Backward_CUDA(
    const at::Tensor GSL_,
    const at::Tensor X_,
    const at::Tensor C_,
    const at::Tensor S_,
    const at::Tensor SL_) {
  // NOTE(review): devicetensor() presumably assumes a dense row-major layout
  // (confirm in device_tensor.h). GSL_ comes from autograd and may be
  // non-contiguous, so normalize every kernel input. contiguous() is free for
  // tensors that already are contiguous.
  const at::Tensor GSL_c = GSL_.contiguous();
  const at::Tensor X_c = X_.contiguous();
  const at::Tensor C_c = C_.contiguous();
  const at::Tensor S_c = S_.contiguous();

  auto GX_ = at::zeros_like(X_c);
  auto GC_ = at::zeros_like(C_c);
  cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
  // dX: one block per (d, i, b); threads split the sum over K.
  dim3 blocks1(X_c.size(2), X_c.size(1), X_c.size(0));
  dim3 threads1(getNumThreads(C_c.size(0)));
  // dC: one block per (d, k); threads split the sum over N, looping over B.
  dim3 blocks2(C_c.size(1), C_c.size(0));
  dim3 threads2(getNumThreads(X_c.size(1)));
  // dSL/dS[k] equals SL / S[k], so GS reuses the forward output instead of
  // recomputing distances. This yields inf/NaN if any S[k] == 0.
  auto GS_ = (GSL_c * (SL_ / S_c.view({1, 1, C_c.size(0)}))).sum(0).sum(0);
  AT_DISPATCH_FLOATING_TYPES(X_c.type(), "ScaledL2_Backward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> GSL = devicetensor<scalar_t, 3>(GSL_c);
    DeviceTensor<scalar_t, 3> GX = devicetensor<scalar_t, 3>(GX_);
    DeviceTensor<scalar_t, 2> GC = devicetensor<scalar_t, 2>(GC_);
    DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_c);
    DeviceTensor<scalar_t, 2> C = devicetensor<scalar_t, 2>(C_c);
    DeviceTensor<scalar_t, 1> S = devicetensor<scalar_t, 1>(S_c);
    ScaledL2_GradX_kernel<scalar_t, scalar_t>
      <<<blocks1, threads1, 0, stream>>> (GSL, GX, X, C, S);
    AT_ASSERT(cudaGetLastError() == cudaSuccess);
    ScaledL2_GradC_kernel<scalar_t, scalar_t>
      <<<blocks2, threads2, 0, stream>>> (GSL, GC, X, C, S);
    AT_ASSERT(cudaGetLastError() == cudaSuccess);
  }));
  return {GX_, GC_, GS_};
}