// layer_norm_cuda_kernel.cu — fused LayerNorm forward/backward CUDA kernels.
#include <cuda.h>
#include <cuda_runtime.h>

#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/DeviceUtils.cuh"

#include "type_shim.h"

// One step of Welford's online mean/variance update: fold sample `curr`
// into the running mean `mu`, the running sum of squared deviations
// `sigma2` (a.k.a. M2) and the sample count `count`.
template<typename U> __device__
void cuWelfordOnlineSum(
  const U curr,
  U& mu,
  U& sigma2,
  U& count)
{
  count = count + U(1);
  const U diffOld = curr - mu;      // deviation from the old mean
  const U newMean = mu + diffOld / count;
  mu = newMean;
  const U diffNew = curr - newMean; // deviation from the updated mean
  sigma2 = sigma2 + diffOld * diffNew;
}

// Merge two sets of Welford partial statistics (Chan et al.'s parallel
// variance update): fold partition B's (muB, sigma2B, countB) into the
// running (mu, sigma2, count) in place.
template<typename U> __device__
void cuChanOnlineSum(
  const U muB,
  const U sigma2B,
  const U countB,
  U& mu,
  U& sigma2,
  U& count)
{
  const U meanDiff = muB - mu;
  U cntA = count;
  U cntB = countB;
  count = count + countB;
  const U cntTotal = count;
  if (cntTotal > U(0)) {
    cntA = cntA / cntTotal;
    cntB = cntB / cntTotal;
    mu = cntA*mu + cntB*muB;
    sigma2 = sigma2 + sigma2B + meanDiff * meanDiff * cntA * cntB * cntTotal;
  } else {
    // both partitions empty: keep a well-defined state
    mu = U(0);
    sigma2 = U(0);
  }
}

// Compute the mean (mu) and biased variance (sigma2) of row i1 of a
// contiguous [n1, n2] tensor, cooperatively across one thread block.
// Assumptions:
//   1) blockDim.x == warpSize
//   2) the tensor is contiguous
//   3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory is
//      available via `buf`
// On exit every thread of the block holds the same mu and sigma2.
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  U& mu,
  U& sigma2,
  U* buf)
{
  U count = U(0);
  mu = U(0);
  sigma2 = U(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index; intra-warp synchronization is implicit
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const T* lvals = vals + i1*n2;
    // per-thread Welford accumulation, 4 elements per strided step
    int l = 4*thrx;
    for (;  l+3 < n2;  l+=4*numx) {
      for (int k = 0;  k < 4;  ++k) {
        U curr = static_cast<U>(lvals[l+k]);
        cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
      }
    }
    // tail elements
    for (;  l < n2;  ++l) {
      U curr = static_cast<U>(lvals[l]);
      cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
    }
    // intra-warp reduction: merge lane partials via shuffles
    for (int l = 0;  l <= 4;  ++l) {
      int srcLaneB = (threadIdx.x+(1<<l))&31;
      U muB = WARP_SHFL(mu, srcLaneB, 32);
      U countB = WARP_SHFL(count, srcLaneB, 32);
      U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
      cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
    }
    // threadIdx.x == 0 now holds each warp's merged values; reduce across warps
    if (blockDim.y > 1) {
      U* ubuf = (U*)buf;                  // (mu, sigma2) pairs
      U* ibuf = (U*)(ubuf + blockDim.y);  // counts
      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {
        // upper half of warps publish their partials to shared memory
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges them in
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          U muB = ubuf[2*threadIdx.y];
          U sigma2B = ubuf[2*threadIdx.y+1];
          U countB = ibuf[threadIdx.y];
          cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
        }
        __syncthreads();
      }
      // thread (0,0) holds the final result; broadcast it via shared memory
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1]/U(n2);
      // final value of count is irrelevant: it is known to equal n2
    } else {
      // single warp: lane 0 broadcasts the result
      mu = WARP_SHFL(mu, 0, 32);
      sigma2 = WARP_SHFL(sigma2/U(n2), 0, 32);
    }
  }
}

// at::Half specialization of cuWelfordMuSigma2: accumulates in float and
// uses vectorized __half2 loads (8 halves per thread per strided step).
// Same assumptions as the generic version:
//   1) blockDim.x == warpSize
//   2) the tensor is contiguous
//   3) 2*blockDim.y*sizeof(float)+blockDim.y*sizeof(int) shared memory
//      is available via `buf`
template<> __device__
void cuWelfordMuSigma2(
  const at::Half* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  float& mu,
  float& sigma2,
  float* buf)
{
  float count = 0.0f;
  mu= float(0);
  sigma2 = float(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index; intra-warp synchronization is implicit
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const at::Half* lvals = vals + i1*n2;
    int l = 8*thrx;
    if ((((size_t)lvals)&3) != 0) {
      // row start is only 2-byte aligned: thread 0 consumes the first
      // element so the rest is 4-byte aligned for __half2 loads
      if (thrx == 0) {
        float curr = static_cast<float>(lvals[0]);
        cuWelfordOnlineSum(curr,mu,sigma2,count);
      }
      ++l;
    }
    // from here on, lvals[l] is 32-bit aligned for all threads
    for (;  l+7 < n2;  l+=8*numx) {
      for (int k = 0;  k < 8;  k+=2) {
        float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
        cuWelfordOnlineSum(curr.x,mu,sigma2,count);
        cuWelfordOnlineSum(curr.y,mu,sigma2,count);
      }
    }
    // tail elements
    for (;  l < n2;  ++l) {
      float curr = static_cast<float>(lvals[l]);
      cuWelfordOnlineSum(curr,mu,sigma2,count);
    }
    // intra-warp reduction: merge lane partials via shuffles
    for (int l = 0;  l <= 4;  ++l) {
      int srcLaneB = (threadIdx.x+(1<<l))&31;
      float muB = WARP_SHFL(mu, srcLaneB, 32);
      float countB = WARP_SHFL(count, srcLaneB, 32);
      float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
      cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
    }
    // threadIdx.x == 0 now holds each warp's merged values; reduce across warps
    if (blockDim.y > 1) {
      float* ubuf = (float*)buf;                  // (mu, sigma2) pairs
      float* ibuf = (float*)(ubuf + blockDim.y);  // counts
      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {
        // upper half of warps publish their partials to shared memory
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges them in
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          float muB = ubuf[2*threadIdx.y];
          float sigma2B = ubuf[2*threadIdx.y+1];
          float countB = ibuf[threadIdx.y];
          cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
        }
        __syncthreads();
      }
      // thread (0,0) holds the final result; broadcast it via shared memory
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1]/float(n2);
      // final value of count is irrelevant: it is known to equal n2
    } else {
      // single warp: lane 0 broadcasts the result
      mu = WARP_SHFL(mu, 0, 32);
      sigma2 = WARP_SHFL(sigma2/float(n2), 0, 32);
    }
  }
}

// Reciprocal square root, with specializations mapping to the fast
// platform intrinsics where available.
template<typename U> U rsqrt(U v) {
  return U(1) / sqrt(v);
}
#if defined __HIP_PLATFORM_HCC__
__device__ float rsqrt(float v) {
  return rsqrtf(v);
}
#else
template<> float rsqrt(float v) {
  return rsqrtf(v);
}
#endif
template<> double rsqrt(double v) {
  // the unqualified call resolves to the non-template ::rsqrt(double)
  // math builtin (non-template functions are preferred over this
  // specialization in overload resolution), not to a recursive call
  return rsqrt(v);
}

namespace {
// Typed accessor for the block's dynamic shared memory.  Only the
// specializations below may be instantiated: the generic template is a
// pure forward declaration, so using an unsupported type fails at compile
// time (see https://github.com/NVIDIA/apex/issues/246 for why the older
// undefined-symbol-in-body trick was dropped).
template <typename T>
struct SharedMemory;

template <>
struct SharedMemory <float>
{
    // Returns the dynamic shared memory region as float*.
    __device__ float *getPointer()
    {
        extern __shared__ float s_float[];
        return s_float;
    }
};

template <>
struct SharedMemory <double>
{
    // Returns the dynamic shared memory region as double*.
    __device__ double *getPointer()
    {
        extern __shared__ double s_double[];
        return s_double;
    }
};
}

// Layer-normalize each row i1 of `vals` ([n1, n2], contiguous) into
// `output_vals`, recording per-row mean and inverse standard deviation
// (needed by the backward pass).  Rows are distributed over gridDim.y
// with a grid-stride loop.
// Assumptions:
//   1) blockDim.x == warpSize
//   2) tensors are contiguous
template<typename T, typename U, typename V> __device__
void cuApplyLayerNorm_(
  V* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const V* __restrict__ gamma,
  const V* __restrict__ beta
  )
{
  for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    U mu,sigma2;
    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
    const T* lvals = vals + i1*n2;
    V* ovals = output_vals + i1*n2;
    U c_invvar = rsqrt(sigma2 + epsilon);
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL && beta != NULL) {
      // affine path: y = gamma * (x - mu) * invvar + beta
      for (int i = thrx;  i < n2;  i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
      }
    } else {
      // non-affine path: y = (x - mu) * invvar
      for (int i = thrx;  i < n2;  i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = static_cast<V>(c_invvar * (curr - mu));
      }
    }
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      mean[i1] = mu;
      invvar[i1] = c_invvar;
    }
    // buf is reused by the next row's Welford pass; wait for all readers
    __syncthreads();
  }
}

// Kernel entry point: thin __global__ wrapper forwarding to the __device__
// implementation cuApplyLayerNorm_ (kept separate so the body can also be
// reused from other kernels).
template<typename T, typename U, typename V=T> __global__
void cuApplyLayerNorm(
  V* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const V* __restrict__ gamma,
  const V* __restrict__ beta
  )
{
  cuApplyLayerNorm_<T, U, V>(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta);
}


// First tile of the partial gamma/beta-gradient accumulation: WRITE the
// per-thread shared-memory staging buffers (no prior contents assumed).
//   warp_buf1[w] = dout                       (accumulates grad wrt beta)
//   warp_buf2[w] = dout*(x - mean)*invvar     (accumulates grad wrt gamma)
// Rows past i1_end and columns past n2 are zero-filled so the later
// reduction can run without bounds checks.
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar
    )
{
  const int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    const U curr_mean = mean[i1];
    const U curr_invvar = invvar[i1];
    for (int k = 0;  k < blockDim.y;  ++k) {
      const int i2 = i2_off + k;
      const int load_idx = i1*n2+i2;
      const int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        const U curr_input = static_cast<U>(input[load_idx]);
        const U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] = curr_dout;
        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
      } else {
        // column out of range: zero-fill
        warp_buf1[write_idx] = U(0);
        warp_buf2[write_idx] = U(0);
      }
    }
  } else {
    // whole row out of range: zero-fill this thread's staging slots
    for (int k = 0;  k < blockDim.y;  ++k) {
      const int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      warp_buf1[write_idx] = U(0);
      warp_buf2[write_idx] = U(0);
    }
  }
}

// Subsequent tiles of the partial gamma/beta-gradient accumulation: ADD
// into the staging buffers previously initialized by
// cuLoadWriteStridedInputs.  Out-of-range rows/columns contribute nothing
// (the buffers already hold valid partial sums), hence no zero-fill here.
template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar
    )
{
  const int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    const U curr_mean = mean[i1];
    const U curr_invvar = invvar[i1];
    for (int k = 0;  k < blockDim.y;  ++k) {
      const int i2 = i2_off + k;
      const int load_idx = i1*n2+i2;
      const int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        const U curr_input = static_cast<U>(input[load_idx]);
        const U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] += curr_dout;
        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
      }
    }
  }
}

// Stage 1 of the gamma/beta gradients: each block reduces its segment of
// rows into one row of part_grad_gamma/part_grad_beta ([gridDim.y, n2]),
// which cuComputeGradGammaBeta later collapses.  Shared memory holds two
// staging buffers of blockDim.y*blockDim.y rows by (blockDim.x+1) columns
// (the +1 column pad avoids shared-memory bank conflicts).
template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    U* part_grad_gamma,
    U* part_grad_beta)
{
    const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
    const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
    const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
    const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
    const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
    const int row_stride = blockDim.x+1;
    const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
    const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
    const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
    SharedMemory<U> shared;
    // buf has at least blockDim.x * blockDim.y * blockDim.y +
    // (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
    U* buf = shared.getPointer();
    U* warp_buf1 = (U*)buf;
    U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
    // accumulate partial sums from strided inputs: first tile writes,
    // subsequent tiles add (keeps many loads in flight)
    cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
    for (int i1_block = i1_beg+blockDim.y*blockDim.y;  i1_block < i1_end;  i1_block+=blockDim.y*blockDim.y) {
      cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
    }
    __syncthreads();
    // sum within each warp
    U acc1 = U(0);
    U acc2 = U(0);
    for (int k = 0;  k < blockDim.y;  ++k) {
      int row1 = threadIdx.y + k*blockDim.y;
      int idx1 = row1*row_stride + threadIdx.x;
      acc1 += warp_buf1[idx1];
      acc2 += warp_buf2[idx1];
    }
    warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
    warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
    __syncthreads();
    // tree-reduce across warps; the loop stops at offset 1 and the final
    // pair of rows is summed directly in the write below
    for (int offset = blockDim.y/2;  offset > 1;  offset /= 2) {
      if (threadIdx.y < offset) {
        int row1 = threadIdx.y;
        int row2 = threadIdx.y + offset;
        int idx1 = row1*row_stride + threadIdx.x;
        int idx2 = row2*row_stride + threadIdx.x;
        warp_buf1[idx1] += warp_buf1[idx2];
        warp_buf2[idx1] += warp_buf2[idx2];
      }
      __syncthreads();
    }
    int i2 = blockIdx.x * blockDim.x + threadIdx.x;
    if (threadIdx.y == 0 && i2 < n2) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + 1;
      int idx1 = row1*row_stride + threadIdx.x;
      int idx2 = row2*row_stride + threadIdx.x;
      part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
      part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
    }
}

// Stage 2: collapse the part_size partial-gradient rows produced by
// cuComputePartGradGammaBeta into the final grad_gamma/grad_beta ([n2]).
// Shared memory: blockDim.x * blockDim.y elements of U, split into two
// halves (gamma partials, beta partials).
template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
    const U* part_grad_gamma,
    const U* part_grad_beta,
    const int part_size,
    const int n1,
    const int n2,
    V* grad_gamma,
    V* grad_beta)
{
    // sum partial gradients for gamma and beta
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    int i2 = blockIdx.x * blockDim.x + threadIdx.x;
    // NOTE(review): the __syncthreads() calls below sit inside this branch;
    // threads with i2 >= n2 skip them.  Upstream behavior kept — confirm
    // launch configs keep whole blocks in range or rely on this working.
    if (i2 < n2) {
      // each warp first reduces its slice of rows sequentially
      int num_warp_reductions = part_size / blockDim.y;
      U sum_gamma = U(0);
      U sum_beta = U(0);
      const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
      const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
      for (int warp_offset = 0;  warp_offset < num_warp_reductions;  ++warp_offset) {
        sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
        sum_beta += part_grad_beta_ptr[warp_offset*n2];
      }
      // then tree-reduce across warps through shared memory
      const int nbsize3 = blockDim.x * blockDim.y / 2;
      for (int offset = blockDim.y/2;  offset >= 1;  offset /= 2) {
        // top half of warps write their partials to shared memory
        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[write_idx] = sum_gamma;
          buf[write_idx+nbsize3] = sum_beta;
        }
        __syncthreads();
        // bottom half accumulates
        if (threadIdx.y < offset) {
          const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
          sum_gamma += buf[read_idx];
          sum_beta += buf[read_idx+nbsize3];
        }
        __syncthreads();
      }
      // warp 0 writes the fully summed gradients
      if (threadIdx.y == 0) {
        grad_gamma[i2] = sum_gamma;
        grad_beta[i2] = sum_beta;
      }
    }
}

// Backward pass wrt the input.  For each row (grid-strided over gridDim.y):
//   grad_x[l] = (n2*dy[l]*g[l] - sum(dy*g) - (x[l]-mu)*invvar*sum(dy*g*(x-mu)*invvar))
//               * invvar / n2
// where g is gamma (or implicitly 1 when gamma == NULL).
// Shared memory (when blockDim.y > 1): 2 * blockDim.x * (blockDim.y/2)
// elements of U for the inter-warp reduction.
template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    const V* gamma,
    T* grad_input)
{
  for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = invvar[i1];
    const T* k_input = input + i1*n2;
    const V* k_dout = dout + i1*n2;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    // accumulate sum(dy*g) and sum(dy*g*(x-mu)*invvar), 4 elements per step
    if (gamma != NULL) {
      int l = 4*thrx;
      for (;  l+3 < n2;  l+=4*numx) {
        for (int k = 0;  k < 4;  ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss * gamma[l+k];
          sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (;  l < n2;  ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4*thrx;
      for (;  l+3 < n2;  l+=4*numx) {
        for (int k = 0;  k < 4;  ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (;  l < n2;  ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // intra-warp butterfly reduction
    for (int mask = blockDim.x/2;  mask > 0;  mask /= 2) {
      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
    }
    // inter-warp reductions
    if (blockDim.y > 1) {
      SharedMemory<U> shared;
      U* buf = shared.getPointer();
      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {
        // upper half of warps write their partials to shared
        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[2*wrt_i] = sum_loss1;
          buf[2*wrt_i+1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2*read_i];
          sum_loss2 += buf[2*read_i+1];
        }
        __syncthreads();
      }
      // warp 0 publishes; all other warps read the final sums back
      if (threadIdx.y == 0) {
        buf[2*threadIdx.x] = sum_loss1;
        buf[2*threadIdx.x+1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y !=0) {
        sum_loss1 = buf[2*threadIdx.x];
        sum_loss2 = buf[2*threadIdx.x+1];
      }
    }
    // all threads now hold the two row sums; write the input gradient
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T* k_grad_input = grad_input + i1*n2;
    if (gamma != NULL) {
      for (int l = thrx;  l < n2;  l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx;  l < n2;  l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
    // prevent race where buf is written again before reads are done
    __syncthreads();
  }
}

// Host-side launcher for the forward layer-norm kernel.
// output: [n1, n2] normalized result; mean/invvar: [n1] saved statistics.
// gamma/beta may be NULL (non-affine normalization).
template<typename T, typename U, typename V=T>
void HostApplyLayerNorm(
    V* output,
    U* mean,
    U* invvar,
    const T* input,
    int n1,
    int n2,
    double epsilon,
    const V* gamma,
    const V* beta
    )
{
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    // one block of 4 warps per row, grid-strided over rows in y
    const dim3 threads(32,4,1);
    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
    // cuWelfordMuSigma2's inter-warp reduction uses threads.y U's of
    // (mu,sigma2) pairs plus threads.y/2 U's of counts
    int nshared =
        threads.y > 1 ?
        threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
        0;
    cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
      output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
}

// Python-facing entry point for the forward pass: dispatches on the
// input/output dtypes and launches the layer-norm kernel.  gamma/beta
// tensors may be NULL for the non-affine case.
void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon)
{
    using namespace at;
    DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
        input->scalar_type(), output->scalar_type(), "layer_norm_cuda_kernel",
        using accscalar_t = at::acc_type<scalar_t_in, true>;
        HostApplyLayerNorm<scalar_t_in, accscalar_t, scalar_t_out>(
          output->DATA_PTR<scalar_t_out>(),
          mean->DATA_PTR<accscalar_t>(),
          invvar->DATA_PTR<accscalar_t>(),
          input->DATA_PTR<scalar_t_in>(),
          n1,n2,
          epsilon,
          gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
          beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
      )
}

// Host-side launcher for the layer-norm backward pass.
// Always computes grad_input; computes grad_gamma/grad_beta only when both
// gamma and beta are provided, via a two-stage reduction (rows -> part_size
// partial rows -> final [n2] gradients).
template<typename T, typename U=float, typename V=T>
void HostLayerNormGradient(
    const V* dout,
    const U* mean,
    const U* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    const V* gamma,
    const V* beta,
    double epsilon,
    T* grad_input,
    V* grad_gamma,
    V* grad_beta
    )
{
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    if (gamma != NULL && beta != NULL) {
      // stage 1: reduce rows into part_size partial gradient rows
      const int part_size = 16;
      const dim3 threads2(32,4,1);
      const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
      const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
      const int nshared2_b = threads2.x * threads2.y * sizeof(U);
      const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
      // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that
      // the `cuda_layer_norm_gradient` doesn't support double.
      const auto part_grad_dtype =
        (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ?
        at::ScalarType::Float :
        input->scalar_type();
      at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype));
      at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
      cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
          dout,
          input->DATA_PTR<T>(),
          n1,n2,
          mean,
          invvar,
          U(epsilon),
          part_grad_gamma.DATA_PTR<U>(),
          part_grad_beta.DATA_PTR<U>());

      // stage 2: collapse the partials into the final gamma/beta gradients.
      // Fix: blocks3's column count was previously derived from threads2.x;
      // it must follow threads3.x (the same value today, 32, so behavior is
      // unchanged, but the two launch configs are otherwise independent).
      const dim3 threads3(32,8,1);
      const dim3 blocks3((n2+threads3.x-1)/threads3.x,1,1);
      const int nshared3 = threads3.x * threads3.y * sizeof(U);
      cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
          part_grad_gamma.DATA_PTR<U>(),
          part_grad_beta.DATA_PTR<U>(),
          part_size,
          n1,n2,
          grad_gamma,
          grad_beta);
    }

    // compute grad_input, one 4-warp block per row, grid-strided over rows
    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
    const dim3 threads1(32,4,1);
    int nshared =
        threads1.y > 1 ?
        threads1.y*threads1.x*sizeof(U) :
        0;
    cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
        dout,
        input->DATA_PTR<T>(),
        n1,n2,
        mean,
        invvar,
        U(epsilon),
        gamma,
        grad_input);
}

// Python-facing entry point for the backward pass: dispatches on the
// input/gamma dtypes and launches the gradient kernels.
void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta)
{
    using namespace at;
    // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16
    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), gamma == NULL ? input->scalar_type() :  gamma->scalar_type(), "cuComputeGradInput",
      using accscalar_t = at::acc_type<scalar_t_in, true>;
      HostLayerNormGradient(
        dout->DATA_PTR<scalar_t_out>(),
        mean->DATA_PTR<accscalar_t>(),
        invvar->DATA_PTR<accscalar_t>(),
        input,
        n1,n2,
        // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
        // if gamma Tensor is NULL on input.
        // NOTE(review): beta/grad_gamma/grad_beta are dereferenced whenever
        // gamma is non-NULL — assumes callers always pass them together;
        // confirm against the Python bindings.
        gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
        gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
        epsilon,
        grad_input->DATA_PTR<scalar_t_in>(),
        gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
        gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
    )
}