layer_norm_cuda_kernel.cu 27.6 KB
Newer Older
1
2
3
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
4
#include "ATen/cuda/DeviceUtils.cuh"
5
6
7
8

#include <cuda.h>
#include <cuda_runtime.h>

9
10
#include "type_shim.h"

11

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// One step of Welford's online algorithm: fold the sample `curr` into the
// running mean `mu`, running sum of squared deviations `sigma2`, and sample
// count `count` (all updated in place).
template<typename U> __device__
void cuWelfordOnlineSum(
  const U curr,
  U& mu,
  U& sigma2,
  U& count)
{
  count += U(1);
  const U diff_old = curr - mu;          // deviation from the old mean
  mu += diff_old / count;                // incremental mean update
  const U diff_new = curr - mu;          // deviation from the new mean
  sigma2 += diff_old * diff_new;         // Welford M2 accumulation
}

// Chan et al. pairwise merge of two Welford partial results: folds partial B
// (muB, sigma2B, countB) into the running partial A held in (mu, sigma2,
// count). If the combined count is zero the accumulators are reset to zero.
template<typename U> __device__
void cuChanOnlineSum(
  const U muB,
  const U sigma2B,
  const U countB,
  U& mu,
  U& sigma2,
  U& count)
{
  const U delta = muB - mu;
  const U countA = count;
  count = countA + countB;
  if (count > U(0)) {
    // relative weights of the two partials
    const U wA = countA / count;
    const U wB = countB / count;
    mu = wA*mu + wB*muB;
    sigma2 = sigma2 + sigma2B + delta * delta * wA * wB * count;
  } else {
    mu = U(0);
    sigma2 = U(0);
  }
}

// Compute mean (mu) and biased variance (sigma2) of row i1 of a contiguous
// n1 x n2 tensor, cooperatively across the whole thread block.
// Each thread accumulates a Welford partial over a strided slice of the row,
// partials are merged within each warp via shuffles, then across warps via
// shared memory (`buf`). On exit every thread in the block holds the final
// mu and sigma2 (sigma2 already divided by n2).
//
// NOTE(review): this block was reconstructed after stripping non-code scrape
// residue; code statements are unchanged from the visible original.
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  U& mu,
  U& sigma2,
  U* buf,
  const int GPU_WARP_SIZE)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  U count = U(0);
  mu= U(0);
  sigma2 = U(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const T* lvals = vals + i1*n2;
    // 4-wide unrolled strided sweep over the row, then a scalar tail loop.
    int l = 4*thrx;
    for (;  l+3 < n2;  l+=4*numx) {
      for (int k = 0;  k < 4;  ++k) {
        U curr = static_cast<U>(lvals[l+k]);
        cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
      }
    }
    for (;  l < n2;  ++l) {
      U curr = static_cast<U>(lvals[l]);
      cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
    }
    // intra-warp reductions: shuffle partials down and merge pairwise.
    #pragma unroll
    for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
      U muB = WARP_SHFL_DOWN(mu, stride);
      U countB = WARP_SHFL_DOWN(count, stride);
      U sigma2B = WARP_SHFL_DOWN(sigma2, stride);
      cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      U* ubuf = (U*)buf;                      // per-warp (mu, sigma2) pairs
      U* ibuf = (U*)(ubuf + blockDim.y);      // per-warp counts
      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          U muB = ubuf[2*threadIdx.y];
          U sigma2B = ubuf[2*threadIdx.y+1];
          U countB = ibuf[threadIdx.y];
          cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
        }
        __syncthreads();
      }
      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      // broadcast final result to every thread in the block
      mu = ubuf[0];
      sigma2 = ubuf[1]/U(n2);
      // don't care about final value of count, we know count == n2
    } else {
      // single warp: lane 0 holds the result; broadcast it across the warp
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2 / U(n2), 0);
    }
  }
}

// at::Half specialization of cuWelfordMuSigma2: accumulates in float and uses
// vectorized __half2 loads (8 elements per thread per iteration). If the row
// base pointer is not 32-bit aligned, thread 0 consumes the first element so
// the remaining accesses are __half2-aligned.
//
// NOTE(review): this block was reconstructed after stripping non-code scrape
// residue; code statements are unchanged from the visible original.
template<> __device__
void cuWelfordMuSigma2(
  const at::Half* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  float& mu,
  float& sigma2,
  float* buf,
  const int GPU_WARP_SIZE)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  float count = 0.0f;
  mu= float(0);
  sigma2 = float(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const at::Half* lvals = vals + i1*n2;
    int l = 8*thrx;
    if ((((size_t)lvals)&3) != 0) {
      // 16 bit alignment
      // first thread consumes first point
      if (thrx == 0) {
        float curr = static_cast<float>(lvals[0]);
        cuWelfordOnlineSum(curr,mu,sigma2,count);
      }
      ++l;
    }
    // at this point, lvals[l] are 32 bit aligned for all threads.
    for (;  l+7 < n2;  l+=8*numx) {
      for (int k = 0;  k < 8;  k+=2) {
        float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
        cuWelfordOnlineSum(curr.x,mu,sigma2,count);
        cuWelfordOnlineSum(curr.y,mu,sigma2,count);
      }
    }
    // scalar tail
    for (;  l < n2;  ++l) {
      float curr = static_cast<float>(lvals[l]);
      cuWelfordOnlineSum(curr,mu,sigma2,count);
    }
    // intra-warp reductions: shuffle partials down and merge pairwise.
    #pragma unroll
    for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
      float muB = WARP_SHFL_DOWN(mu, stride);
      float countB = WARP_SHFL_DOWN(count, stride);
      float sigma2B = WARP_SHFL_DOWN(sigma2, stride);
      cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      float* ubuf = (float*)buf;                  // per-warp (mu, sigma2) pairs
      float* ibuf = (float*)(ubuf + blockDim.y);  // per-warp counts
      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          float muB = ubuf[2*threadIdx.y];
          float sigma2B = ubuf[2*threadIdx.y+1];
          float countB = ibuf[threadIdx.y];
          cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
        }
        __syncthreads();
      }
      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      // broadcast final result to every thread in the block
      mu = ubuf[0];
      sigma2 = ubuf[1]/float(n2);
      // don't care about final value of count, we know count == n2
    } else {
      // single warp: lane 0 holds the result; broadcast it across the warp
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2 / float(n2), 0);
    }
  }
}

// Reciprocal square root, specialized so float/double use the fast CUDA
// intrinsics instead of 1/sqrt. Under HIP the float version is a plain
// __device__ overload rather than a template specialization.
template<typename U> U rsqrt(U v) {
  return U(1) / sqrt(v);
}
#if defined __HIP_PLATFORM_HCC__
__device__ float rsqrt(float v) {
  return rsqrtf(v);
}
#else
template<> float rsqrt(float v) {
  return rsqrtf(v);
}
#endif
template<> double rsqrt(double v) {
  // Resolves to CUDA's non-template ::rsqrt(double) builtin, which is
  // preferred over this specialization in overload resolution.
  return rsqrt(v);
}

namespace {
// Typed accessor for dynamic shared memory. Each specialization binds to a
// distinctly-named `extern __shared__` array so different element types do
// not alias the same symbol. The primary template is declared but left
// undefined so that instantiating an unsupported type fails at compile time.
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;

template <>
struct SharedMemory <float>
{
    __device__ float *getPointer()
    {
        extern __shared__ float s_float[];
        return s_float;
    }
};

template <>
struct SharedMemory <double>
{
    __device__ double *getPointer()
    {
        extern __shared__ double s_double[];
        return s_double;
    }
};
}

Masaki Kozuki's avatar
Masaki Kozuki committed
288
289
290
// Forward layer norm over rows of a contiguous n1 x n2 tensor.
// Each blockIdx.y iterates over rows with a grid-stride loop in y; for each
// row the block computes mean/variance via cuWelfordMuSigma2, then writes
// the normalized (and optionally affine-transformed) outputs, plus the
// per-row mean and inverse standard deviation for the backward pass.
// T = input dtype, U = accumulation dtype, V = output/weight dtype.
//
// NOTE(review): this block was reconstructed after stripping non-code scrape
// residue; code statements are unchanged from the visible original.
template<typename T, typename U, typename V> __device__
void cuApplyLayerNorm_(
  V* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const V* __restrict__ gamma,
  const V* __restrict__ beta,
  const int GPU_WARP_SIZE
  )
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensors are contiguous
  //
  for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    U mu,sigma2;
    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf, GPU_WARP_SIZE);
    const T* lvals = vals + i1*n2;
    V* ovals = output_vals + i1*n2;
    U c_invvar = rsqrt(sigma2 + epsilon);
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL && beta != NULL) {
      // affine path: y = gamma * (x - mu) * invvar + beta
      for (int i = thrx;  i < n2;  i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
      }
    } else {
      // plain normalization path
      for (int i = thrx;  i < n2;  i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = static_cast<V>(c_invvar * (curr - mu));
      }
    }
    // save per-row statistics for the backward pass
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      mean[i1] = mu;
      invvar[i1] = c_invvar;
    }
    // prevent the next iteration from overwriting `buf` before all reads finish
    __syncthreads();
  }
}

Masaki Kozuki's avatar
Masaki Kozuki committed
335
336
337
338
339
340
341
342
343
344
// __global__ entry point for the forward pass; thin wrapper that forwards all
// arguments to the __device__ implementation cuApplyLayerNorm_.
template<typename T, typename U, typename V=T> __global__
void cuApplyLayerNorm(
  V* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const V* __restrict__ gamma,
  const V* __restrict__ beta,
  const int warp_size)
{
  cuApplyLayerNorm_<T, U, V>(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, warp_size);
}

// First pass of the partial gamma/beta gradient accumulation: loads a
// blockDim.y x blockDim.y tile of strided inputs and WRITES (initializes)
// the two shared-memory staging buffers:
//   warp_buf1[w] = dout                                  (-> grad_beta)
//   warp_buf2[w] = dout * (x - mean) * invvar            (-> grad_gamma)
// Out-of-range rows/columns are zero-filled so later reductions need no
// bounds checks.
//
// NOTE(review): this block was reconstructed after stripping non-code scrape
// residue; code statements are unchanged from the visible original.
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar
    )
{
  int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0;  k < blockDim.y;  ++k) {
      int i2 = i2_off + k;
      int load_idx = i1*n2+i2;
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] = curr_dout;
        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
      } else {
        // column past the end of the row: zero-init
        warp_buf1[write_idx] = U(0);
        warp_buf2[write_idx] = U(0);
      }
    }
  } else {
    // row past the end of this block's segment: zero-init the whole stripe
    for (int k = 0;  k < blockDim.y;  ++k) {
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      warp_buf1[write_idx] = U(0);
      warp_buf2[write_idx] = U(0);
    }
  }
}
Masaki Kozuki's avatar
Masaki Kozuki committed
394
// Subsequent passes of the partial gamma/beta gradient accumulation: same
// tile layout as cuLoadWriteStridedInputs, but ACCUMULATES (+=) into the
// already-initialized staging buffers. Out-of-range elements are simply
// skipped (the buffers already hold valid partial sums).
//
// NOTE(review): this block was reconstructed after stripping non-code scrape
// residue; code statements are unchanged from the visible original.
template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar
    )
{
  int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0;  k < blockDim.y;  ++k) {
      int i2 = i2_off + k;
      int load_idx = i1*n2+i2;
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] += curr_dout;
        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
      }
    }
  }
}

Masaki Kozuki's avatar
Masaki Kozuki committed
429
// Stage 1 of the gamma/beta gradient: each block reduces its assigned segment
// of rows into per-(blockIdx.y, column) partial sums:
//   part_grad_beta  += dout
//   part_grad_gamma += dout * (x - mean) * invvar
// Partials are staged in shared memory (row_stride = blockDim.x+1 padding to
// avoid bank conflicts), reduced across warps, and written to the
// part_grad_* matrices for stage 2 (cuComputeGradGammaBeta).
//
// NOTE(review): this block was reconstructed after stripping non-code scrape
// residue; code statements are unchanged from the visible original.
template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    U* part_grad_gamma,
    U* part_grad_beta)
{
    // split the n1 rows into segments of blockDim.y^2 rows, distributed
    // round-robin over gridDim.y blocks
    const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
    const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
    const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
    const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
    const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
    const int row_stride = blockDim.x+1;  // +1 padding vs. bank conflicts
    const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
    const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
    const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
    SharedMemory<U> shared;
    U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
    U* warp_buf1 = (U*)buf;
    U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
    // compute partial sums from strided inputs
    // do this to increase number of loads in flight
    cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
    for (int i1_block = i1_beg+blockDim.y*blockDim.y;  i1_block < i1_end;  i1_block+=blockDim.y*blockDim.y) {
      cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
    }
    __syncthreads();
    // inter-warp reductions
    // sum within each warp
    U acc1 = U(0);
    U acc2 = U(0);
    for (int k = 0;  k < blockDim.y;  ++k) {
      int row1 = threadIdx.y + k*blockDim.y;
      int idx1 = row1*row_stride + threadIdx.x;
      acc1 += warp_buf1[idx1];
      acc2 += warp_buf2[idx1];
    }
    warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
    warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
    __syncthreads();
    // sum all warps (tree reduction stops at 2 rows; final pair summed below)
    for (int offset = blockDim.y/2;  offset > 1;  offset /= 2) {
      if (threadIdx.y < offset) {
        int row1 = threadIdx.y;
        int row2 = threadIdx.y + offset;
        int idx1 = row1*row_stride + threadIdx.x;
        int idx2 = row2*row_stride + threadIdx.x;
        warp_buf1[idx1] += warp_buf1[idx2];
        warp_buf2[idx1] += warp_buf2[idx2];
      }
      __syncthreads();
    }
    int i2 = blockIdx.x * blockDim.x + threadIdx.x;
    if (threadIdx.y == 0 && i2 < n2) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + 1;
      int idx1 = row1*row_stride + threadIdx.x;
      int idx2 = row2*row_stride + threadIdx.x;
      part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
      part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
    }
}

Masaki Kozuki's avatar
Masaki Kozuki committed
497
// Stage 2 of the gamma/beta gradient: reduces the part_size x n2 partial-sum
// matrices (from cuComputePartGradGammaBeta) down to the final grad_gamma /
// grad_beta vectors of length n2. Each warp first reduces its share of the
// part_size rows sequentially, then warps are combined via a shared-memory
// tree reduction. Shared memory: blockDim.x * blockDim.y * sizeof(U) bytes.
//
// NOTE(review): this block was reconstructed after stripping non-code scrape
// residue; code statements are unchanged from the visible original.
template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
    const U* part_grad_gamma,
    const U* part_grad_beta,
    const int part_size,
    const int n1,
    const int n2,
    V* grad_gamma,
    V* grad_beta)
{
    // sum partial gradients for gamma and beta
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    int i2 = blockIdx.x * blockDim.x + threadIdx.x;
    if (i2 < n2) {
      // each warp does sequential reductions until reduced part_size is num_warps
      int num_warp_reductions = part_size / blockDim.y;
      U sum_gamma = U(0);
      U sum_beta = U(0);
      const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
      const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
      for (int warp_offset = 0;  warp_offset < num_warp_reductions;  ++warp_offset) {
        sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
        sum_beta += part_grad_beta_ptr[warp_offset*n2];
      }
      // inter-warp reductions
      const int nbsize3 = blockDim.x * blockDim.y / 2;  // gamma/beta halves of buf
      for (int offset = blockDim.y/2;  offset >= 1;  offset /= 2) {
        // top half write to shared memory
        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[write_idx] = sum_gamma;
          buf[write_idx+nbsize3] = sum_beta;
        }
        __syncthreads();
        // bottom half sums
        if (threadIdx.y < offset) {
          const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
          sum_gamma += buf[read_idx];
          sum_beta += buf[read_idx+nbsize3];
        }
        __syncthreads();
      }
      // write out fully summed gradients
      if (threadIdx.y == 0) {
        grad_gamma[i2] = sum_gamma;
        grad_beta[i2] = sum_beta;
      }
    }
}

Masaki Kozuki's avatar
Masaki Kozuki committed
548
// Backward pass for the input: for each row i1 (grid-stride over blockIdx.y),
// first reduce two row sums across the block,
//   sum_loss1 = sum(dout * gamma)            (or sum(dout) without gamma)
//   sum_loss2 = sum(dout * gamma * xhat)     where xhat = (x - mean) * invvar
// then compute
//   grad_input = (n2*dout*gamma - sum_loss1 - xhat*sum_loss2) * invvar / n2.
// The accumulation loops have two variants: a 4-wide unrolled loop for CUDA
// and a predicated strided loop tuned for ROCm MI100 wavefronts.
//
// NOTE(review): this block was reconstructed after stripping non-code scrape
// residue and normalizing tab/space indentation; code statements are
// unchanged from the visible original.
template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    const V* gamma,
    T* grad_input)
{
  for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = invvar[i1];
    const T* k_input = input + i1*n2;
    const V* k_dout = dout + i1*n2;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL) {
      #ifndef __HIP_PLATFORM_HCC__
      // 4-wide unrolled accumulation plus scalar tail
      int l = 4*thrx;
      for (;  l+3 < n2;  l+=4*numx) {
        for (int k = 0;  k < 4;  ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss * gamma[l+k];
          sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (;  l < n2;  ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
      #else
      // Optimization for ROCm MI100
      for( int l = 0; l < n2 ; l += numx) {
        int idx = l + thrx;
        const U gamma_idx = static_cast<U>((idx<n2) ? gamma[idx] : V(0));
        const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0));
        const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0));
        sum_loss1 += c_loss * gamma_idx;
        sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_invvar;
      }
      #endif
    } else {
      #ifndef __HIP_PLATFORM_HCC__
      int l = 4*thrx;
      for (;  l+3 < n2;  l+=4*numx) {
        for (int k = 0;  k < 4;  ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (;  l < n2;  ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
      #else
      for( int l = 0; l < n2 ; l += numx) {
        int idx = l + thrx;
        const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0));
        const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0));
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
      #endif
    }
    // intra-warp reductions
    for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
    }
    // inter-warp reductions
    if (blockDim.y > 1) {
      SharedMemory<U> shared;
      U* buf = shared.getPointer();
      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[2*wrt_i] = sum_loss1;
          buf[2*wrt_i+1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2*read_i];
          sum_loss2 += buf[2*read_i+1];
        }
        __syncthreads();
      }
      // warp 0 holds the final sums; broadcast them to the other warps
      if (threadIdx.y == 0) {
        buf[2*threadIdx.x] = sum_loss1;
        buf[2*threadIdx.x+1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y !=0) {
        sum_loss1 = buf[2*threadIdx.x];
        sum_loss2 = buf[2*threadIdx.x+1];
      }
    }
    // all threads now have the two sums over l
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T* k_grad_input = grad_input + i1*n2;
    if (gamma != NULL) {
      for (int l = thrx;  l < n2;  l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx;  l < n2;  l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
    // prevent race where buf is written again before reads are done
    __syncthreads();
  }
}

Masaki Kozuki's avatar
Masaki Kozuki committed
689
// Host-side launcher for the forward pass: configures block/grid dimensions
// (one warp per x, 4 warps per block on CUDA, 1 on ROCm) and the dynamic
// shared memory needed by cuWelfordMuSigma2's inter-warp reduction, then
// launches cuApplyLayerNorm on the current ATen CUDA stream.
//
// NOTE(review): reconstructed after stripping non-code scrape residue and
// normalizing tab indentation; code statements are unchanged.
template<typename T, typename U, typename V=T>
void HostApplyLayerNorm(
    V* output,
    U* mean,
    U* invvar,
    const T* input,
    int n1,
    int n2,
    double epsilon,
    const V* gamma,
    const V* beta
    )
{
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
    dim3 threads(warp_size ,4, 1);  // MI100 wavefront/warp = 64
    #ifdef __HIP_PLATFORM_HCC__
    // Optimization for ROCm MI100
    threads.y = 1;
    #endif
    // one row per blockIdx.y, grid-stride if n1 exceeds the y-grid limit
    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
    // shared memory for the inter-warp Welford merge:
    // threads.y (mu,sigma2) slots + threads.y/2 count slots (see cuWelfordMuSigma2)
    int nshared =
        threads.y > 1 ?
        threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
        0;
    cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
      output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta, warp_size);
}

// Python-facing entry point for the forward pass: dispatches on the
// input/output scalar types (double/float/half/bfloat16) and calls
// HostApplyLayerNorm with the matching accumulation type.
// `normalized_shape` is accepted for interface parity but not used here —
// n1/n2 already encode the flattened problem size.
void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon)
{
    using namespace at;
    DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
        input->scalar_type(), output->scalar_type(), "layer_norm_cuda_kernel",
        using accscalar_t = at::acc_type<scalar_t_in, true>;
        HostApplyLayerNorm<scalar_t_in, accscalar_t, scalar_t_out>(
          output->DATA_PTR<scalar_t_out>(),
          mean->DATA_PTR<accscalar_t>(),
          invvar->DATA_PTR<accscalar_t>(),
          input->DATA_PTR<scalar_t_in>(),
          n1,n2,
          epsilon,
          // gamma/beta are optional; pass NULL for the non-affine path
          gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
          beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
      )
}

Masaki Kozuki's avatar
Masaki Kozuki committed
752
// Host-side launcher for the backward pass. When gamma/beta are present it
// runs the two-stage weight-gradient reduction (cuComputePartGradGammaBeta
// then cuComputeGradGammaBeta); it always launches cuComputeGradInput for
// the input gradient. All kernels run on the current ATen CUDA stream.
//
// NOTE(review): reconstructed after stripping non-code scrape residue and
// normalizing tab indentation; code statements are unchanged.
template<typename T, typename U=float, typename V=T>
void HostLayerNormGradient(
    const V* dout,
    const U* mean,
    const U* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    const V* gamma,
    const V* beta,
    double epsilon,
    T* grad_input,
    V* grad_gamma,
    V* grad_beta
    )
{
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;

    if (gamma != NULL && beta != NULL) {
      // compute grad_gamma(j) and grad_beta(j)
      const int part_size = warp_size;
      const dim3 threads2(warp_size, 4, 1);
      const dim3 blocks2((n2+threads2.x-1) / threads2.x,part_size, 1);
      // shared memory: max of stage-1 staging tiles and a plain warp buffer
      const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
      const int nshared2_b = threads2.x * threads2.y * sizeof(U);
      const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
      // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that
      // the `cuda_layer_norm_gradient` doesn't support double.
      const auto part_grad_dtype =
        (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ?
        at::ScalarType::Float :
        input->scalar_type();
      at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype));
      at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
      cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
          dout,
          input->DATA_PTR<T>(),
          n1,n2,
          mean,
          invvar,
          U(epsilon),
          part_grad_gamma.DATA_PTR<U>(),
          part_grad_beta.DATA_PTR<U>());

      const dim3 threads3(warp_size, 8, 1);
      // threads2.x == threads3.x == warp_size, so either works for the ceil-div
      const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
      const int nshared3 = threads3.x * threads3.y * sizeof(U);
      cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
          part_grad_gamma.DATA_PTR<U>(),
          part_grad_beta.DATA_PTR<U>(),
          part_size,
          n1,n2,
          grad_gamma,
          grad_beta);
    }

    // compute grad_input
    // https://github.com/microsoft/onnxruntime/pull/7682/files#diff-f9eace25e62b646410b067f96cd930c7fe843326dca1e8d383631ca27f1a8d00R540
    // https://github.com/amathews-amd/onnxruntime/blob/80c0555c2bc17fb109190e2082cd3fda0a37984c/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu#L541
    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
    dim3 threads1(warp_size,4,1);  // MI100 wavefront/warp = 64
    #ifdef __HIP_PLATFORM_HCC__
    // Optimization for ROCm MI100
    threads1.y = 2;
    #endif
    // shared memory for cuComputeGradInput's inter-warp (sum_loss1,sum_loss2) merge
    int nshared =
        threads1.y > 1 ?
        threads1.y*threads1.x*sizeof(U) :
        0;
    cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
            dout,
            input->DATA_PTR<T>(),
            n1,n2,
            mean,
            invvar,
            U(epsilon),
            gamma,
            grad_input);
}

// Python-facing entry point for the backward pass: dispatches on the input
// and (optional) gamma scalar types and calls HostLayerNormGradient with the
// matching accumulation type. `normalized_shape` is accepted for interface
// parity but not used here — n1/n2 already encode the flattened problem size.
void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta)
{
    using namespace at;
    // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16
    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), gamma == NULL ? input->scalar_type() :  gamma->scalar_type(), "cuComputeGradInput",
      using accscalar_t = at::acc_type<scalar_t_in, true>;
      HostLayerNormGradient(
        dout->DATA_PTR<scalar_t_out>(),
        mean->DATA_PTR<accscalar_t>(),
        invvar->DATA_PTR<accscalar_t>(),
        input,
        n1,n2,
        // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
        // if gamma Tensor is NULL on input.
        gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
        gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
        epsilon,
        grad_input->DATA_PTR<scalar_t_in>(),
        gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
        gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
    )
}
875