layer_norm_cuda_kernel.cu 36.8 KB
Newer Older
1
2
3
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
4
#include "ATen/cuda/DeviceUtils.cuh"
5
6
7
8

#include <cuda.h>
#include <cuda_runtime.h>

9
10
#include "type_shim.h"

11

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
// Welford's online update: fold one sample `curr` into the running
// (mu, sigma2, count) state. Note that `sigma2` accumulates the sum of squared
// deviations from the mean (M2), not the variance itself; the caller divides
// by the element count once all samples have been merged.
template<typename U> __device__
void cuWelfordOnlineSum(
  const U curr,
  U& mu,
  U& sigma2,
  U& count)
{
  count = count + U(1);
  const U dev_old = curr - mu;        // deviation from the previous mean
  mu = mu + dev_old / count;          // running mean including `curr`
  const U dev_new = curr - mu;        // deviation from the updated mean
  sigma2 = sigma2 + dev_old * dev_new;
}

// Chan et al. pairwise merge: combine another partition's statistics
// (muB, sigma2B, countB) into the running (mu, sigma2, count). sigma2 values
// are sums of squared deviations (M2). If the merged count is not positive,
// mean and M2 are reset to zero.
template<typename U> __device__
void cuChanOnlineSum(
  const U muB,
  const U sigma2B,
  const U countB,
  U& mu,
  U& sigma2,
  U& count)
{
  const U mean_gap = muB - mu;
  const U countA = count;
  count = countA + countB;
  const U total = count;
  if (!(total > U(0))) {
    mu = U(0);
    sigma2 = U(0);
    return;
  }
  const U fracA = countA / total;
  const U fracB = countB / total;
  mu = fracA*mu + fracB*muB;
  sigma2 = sigma2 + sigma2B + mean_gap * mean_gap * fracA * fracB * total;
}

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// RMS-only accumulation: add the square of a single sample to the running
// sum of squares held in `sigma2` (no mean is tracked).
template<typename U> __device__
void cuRMSOnlineSum(
  const U curr,
  U& sigma2)
{
  const U squared = curr * curr;
  sigma2 = sigma2 + squared;
}

// RMS-only merge: fold another lane's/warp's partial sum of squares into
// the running total.
template<typename U> __device__
void cuChanRMSOnlineSum(
  const U sigma2B,
  U& sigma2)
{
  sigma2 += sigma2B;
}


69
70
71
72
73
// Computes the statistics of row `i1` of a contiguous [n1, n2] input.
// On return (when i1 < n1) every thread of the block holds:
//   - layer norm (rms_only == false): mu = row mean, sigma2 = biased variance
//   - RMS norm   (rms_only == true):  mu = 0,        sigma2 = mean of squares
// When i1 >= n1, mu and sigma2 are left at 0.
//
// Assumptions:
// 1) blockDim.x == GPU_WARP_SIZE (one warp per threadIdx.y); the intra-warp
//    reduction shuffles down starting from GPU_WARP_SIZE/2.
// 2) `vals` is contiguous with row length n2.
// 3) If blockDim.y > 1: blockDim.y is a power of two and `buf` points to shared
//    memory with room for blockDim.y + blockDim.y/2 values of type U
//    (interleaved mu/sigma2 pairs in the first blockDim.y slots, counts after).
// 4) i1 is the same for every thread of the block, so all threads reach the
//    __syncthreads() calls below.
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  U& mu,
  U& sigma2,
  U* buf,
  const int GPU_WARP_SIZE,
  bool rms_only)
{
  U count = U(0);
  mu = U(0);
  sigma2 = U(0);
  if (i1 < n1) {
    // Serial phase: each thread folds a strided subset of the row into its own
    // running statistics, four consecutive elements at a time.
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const T* lvals = vals + i1*n2;
    int l = 4*thrx;
    for (;  l+3 < n2;  l+=4*numx) {
      for (int k = 0;  k < 4;  ++k) {
        U curr = static_cast<U>(lvals[l+k]);
        if (!rms_only) {
          cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
        } else {
          cuRMSOnlineSum<U>(curr, sigma2);
        }
      }
    }
    // Tail: the single thread owning the last, partial 4-chunk finishes it.
    for (;  l < n2;  ++l) {
      U curr = static_cast<U>(lvals[l]);
      if (!rms_only) {
        cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
      } else {
        cuRMSOnlineSum<U>(curr, sigma2);
      }
    }
    // Intra-warp reduction; afterwards lane threadIdx.x == 0 of each warp
    // holds that warp's combined statistics.
    #pragma unroll
    for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
      U sigma2B = WARP_SHFL_DOWN(sigma2, stride);
      if (!rms_only) {
        U muB = WARP_SHFL_DOWN(mu, stride);
        U countB = WARP_SHFL_DOWN(count, stride);
        cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
      } else {
        cuChanRMSOnlineSum<U>(sigma2B, sigma2);
      }
    }
    if (blockDim.y > 1) {
      // Inter-warp reduction through shared memory: ubuf holds mu/sigma2
      // pairs, ibuf (starting right after the first blockDim.y slots) holds
      // the element counts.
      U* ubuf = (U*)buf;
      U* ibuf = (U*)(ubuf + blockDim.y);
      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          if (!rms_only) {
            ubuf[2*wrt_y] = mu;
            ibuf[wrt_y] = count;
          }
          ubuf[2*wrt_y+1] = sigma2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          U sigma2B = ubuf[2*threadIdx.y+1];
          if (!rms_only) {
            U muB = ubuf[2*threadIdx.y];
            U countB = ibuf[threadIdx.y];
            cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
          } else {
            cuChanRMSOnlineSum<U>(sigma2B,sigma2);
          }
        }
        __syncthreads();
      }
      // Thread (0, 0) now holds the block result; broadcast it via shared memory.
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        if (!rms_only) {
          ubuf[0] = mu;
        }
        ubuf[1] = sigma2;
      }
      __syncthreads();
      if (!rms_only) {
        mu = ubuf[0];
      }
      // sigma2 is still a sum (of squared deviations, or of squares); the
      // final count is known to be n2, so normalize by it directly.
      sigma2 = ubuf[1]/U(n2);
    } else {
      // Single warp: broadcast lane 0's result to every lane.
      if (!rms_only) {
        mu = WARP_SHFL(mu, 0);
      }
      sigma2 = WARP_SHFL(sigma2/U(n2), 0);
    }
  }
}

// Specialization of cuWelfordMuSigma2 for at::Half input with float
// accumulation. It computes the statistics of row `i1` of a contiguous
// [n1, n2] input. On return (when i1 < n1) every thread of the block holds:
//   - layer norm (rms_only == false): mu = row mean, sigma2 = biased variance
//   - RMS norm   (rms_only == true):  mu = 0,        sigma2 = mean of squares
// Once the row pointer is 4-byte aligned, elements are read in pairs through
// __half2 loads.
//
// Assumptions:
// 1) blockDim.x == GPU_WARP_SIZE (one warp per threadIdx.y).
// 2) `vals` is contiguous with row length n2.
// 3) If blockDim.y > 1: blockDim.y is a power of two and `buf` points to shared
//    memory with room for blockDim.y + blockDim.y/2 floats.
// 4) i1 is the same for every thread of the block (uniform __syncthreads()).
template<> __device__
void cuWelfordMuSigma2(
  const at::Half* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  float& mu,
  float& sigma2,
  float* buf,
  const int GPU_WARP_SIZE,
  bool rms_only)
{
  float count = 0.0f;
  mu = float(0);
  sigma2 = float(0);
  if (i1 < n1) {
    // Serial phase: each thread folds a strided subset of the row into its
    // own running statistics.
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const at::Half* lvals = vals + i1*n2;
    int l = 8*thrx;
    if ((((size_t)lvals)&3) != 0) {
      // Row starts on a 2-byte (not 4-byte) boundary: thread 0 consumes the
      // first element so that lvals + l is 4-byte aligned for __half2 loads.
      if (thrx == 0) {
        float curr = static_cast<float>(lvals[0]);
        if (!rms_only) {
          cuWelfordOnlineSum(curr,mu,sigma2,count);
        } else {
          cuRMSOnlineSum(curr, sigma2);
        }
      }
      ++l;
    }
    // at this point, lvals[l] are 32 bit aligned for all threads.
    for (;  l+7 < n2;  l+=8*numx) {
      for (int k = 0;  k < 8;  k+=2) {
        float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
        if (!rms_only) {
          cuWelfordOnlineSum(curr.x,mu,sigma2,count);
          cuWelfordOnlineSum(curr.y,mu,sigma2,count);
        } else {
          cuRMSOnlineSum(curr.x, sigma2);
          cuRMSOnlineSum(curr.y, sigma2);
        }
      }
    }
    // Tail: the thread owning the last, partial 8-chunk finishes it serially.
    for (;  l < n2;  ++l) {
      float curr = static_cast<float>(lvals[l]);
      if (!rms_only) {
        cuWelfordOnlineSum(curr,mu,sigma2,count);
      } else {
        cuRMSOnlineSum(curr, sigma2);
      }
    }
    // Intra-warp reduction; afterwards lane threadIdx.x == 0 of each warp
    // holds that warp's combined statistics.
    #pragma unroll
    for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
      float sigma2B = WARP_SHFL_DOWN(sigma2, stride);
      if (!rms_only) {
        float muB = WARP_SHFL_DOWN(mu, stride);
        float countB = WARP_SHFL_DOWN(count, stride);
        cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
      } else {
        cuChanRMSOnlineSum(sigma2B, sigma2);
      }
    }
    if (blockDim.y > 1) {
      // Inter-warp reduction through shared memory: ubuf holds mu/sigma2
      // pairs, ibuf (starting right after the first blockDim.y slots) holds
      // the element counts.
      float* ubuf = (float*)buf;
      float* ibuf = (float*)(ubuf + blockDim.y);
      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y+1] = sigma2;
          if (!rms_only) {
            ubuf[2*wrt_y] = mu;
            ibuf[wrt_y] = count;
          }
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          float sigma2B = ubuf[2*threadIdx.y+1];
          if (!rms_only) {
            float muB = ubuf[2*threadIdx.y];
            float countB = ibuf[threadIdx.y];
            cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
          } else {
            cuChanRMSOnlineSum(sigma2B, sigma2);
          }
        }
        __syncthreads();
      }
      // Thread (0, 0) now holds the block result; broadcast it via shared memory.
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        if (!rms_only) {
          ubuf[0] = mu;
        }
        ubuf[1] = sigma2;
      }
      __syncthreads();
      if (!rms_only) {
        mu = ubuf[0];
      }
      // Final count is known to be n2, so normalize the accumulated sum by it.
      sigma2 = ubuf[1]/float(n2);
    } else {
      // Single warp: broadcast lane 0's result to every lane.
      if (!rms_only) {
        mu = WARP_SHFL(mu, 0);
      }
      sigma2 = WARP_SHFL(sigma2/float(n2), 0);
    }
  }
}

// Reciprocal square root. cuApplyLayerNorm_ uses it to turn (sigma2 + epsilon)
// into the stored inverse standard deviation.
// Generic fallback: 1 / sqrt(v).
template<typename U> U rsqrt(U v) {
  return U(1) / sqrt(v);
}
#if defined __HIP_PLATFORM_HCC__
// ROCm: a plain __device__ overload for float that forwards to rsqrtf.
__device__ float rsqrt(float v) {
  return rsqrtf(v);
}
#else
// float: forward to the single-precision rsqrtf math function.
template<> float rsqrt(float v) {
  return rsqrtf(v);
}
#endif
// NOTE(review): this body relies on overload resolution choosing a
// non-template math-library rsqrt(double) over this specialization. If no such
// overload is visible in a compilation pass, the call would recurse into
// itself. On the CUDA path these functions also carry no __device__ qualifier
// even though they are called from device code; confirm the toolchain in use
// accepts this.
template<> double rsqrt(double v) {
  return rsqrt(v);
}

namespace {
// Typed accessor for a kernel's dynamically sized `extern __shared__` buffer.
// The primary template is declared but never defined, so using
// SharedMemory<T> for any T without a specialization below fails to compile.
// An earlier version achieved this by calling an undefined device function
// from an un-specialized getPointer(); see
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;

template <>
struct SharedMemory <float>
{
    __device__ float *getPointer()
    {
        extern __shared__ float s_float[];
        return s_float;
    }
};

template <>
struct SharedMemory <double>
{
    __device__ double *getPointer()
    {
        extern __shared__ double s_double[];
        return s_double;
    }
};
}

Masaki Kozuki's avatar
Masaki Kozuki committed
361
362
363
// Row-wise LayerNorm / RMSNorm forward for a contiguous [n1, n2] input.
// Rows are distributed over blockIdx.y with a stride of gridDim.y; all threads
// of a block cooperate on one row at a time. With invvar = rsqrt(sigma2 + eps):
//   layer norm: out = gamma * ((x - mu) * invvar) + beta
//   rms norm:   out = gamma * (x * invvar)
// The affine form is used when gamma is non-NULL and (beta is non-NULL or
// rms_only); otherwise gamma/beta are skipped. Thread (0, 0) stores the
// per-row invvar and, for layer norm only, the mean. In rms_only mode `mean`
// and `beta` are never dereferenced.
//
// Assumptions:
// 1) blockDim.x == GPU_WARP_SIZE
// 2) tensors are contiguous
// 3) dynamic shared memory of blockDim.y + blockDim.y/2 values of U when
//    blockDim.y > 1 (used by cuWelfordMuSigma2)
template<typename T, typename U, typename V> __device__
void cuApplyLayerNorm_(
  V* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const V* __restrict__ gamma,
  const V* __restrict__ beta,
  const int GPU_WARP_SIZE,
  bool rms_only
  )
{
  for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    U mu,sigma2;
    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf, GPU_WARP_SIZE, rms_only);
    const T* lvals = vals + i1*n2;
    V* ovals = output_vals + i1*n2;
    U c_invvar = rsqrt(sigma2 + epsilon);
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL && (beta != NULL || rms_only)) {
      for (int i = thrx;  i < n2;  i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        if (!rms_only) {
          ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
        } else {
          ovals[i] = gamma[i] * static_cast<V>(c_invvar * curr);
        }
      }
    } else {
      for (int i = thrx;  i < n2;  i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        if (!rms_only) {
          ovals[i] = static_cast<V>(c_invvar * (curr - mu));
        } else {
          ovals[i] = static_cast<V>(c_invvar * curr);
        }
      }
    }
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      if (!rms_only) {
        mean[i1] = mu;
      }
      invvar[i1] = c_invvar;
    }
    // The next row's reduction reuses `buf`; wait until every thread is done
    // reading it for this row.
    __syncthreads();
  }
}

Masaki Kozuki's avatar
Masaki Kozuki committed
420
421
422
423
424
425
426
427
428
429
// LayerNorm forward kernel: a thin __global__ wrapper over cuApplyLayerNorm_
// with rms_only = false. `warp_size` is expected to equal blockDim.x.
template<typename T, typename U, typename V=T> __global__
void cuApplyLayerNorm(
  V* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const V* __restrict__ gamma,
  const V* __restrict__ beta,
  const int warp_size)
{
  cuApplyLayerNorm_<T, U, V>(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, warp_size, false);
}

// Backward helper (first tile of cuComputePartGradGammaBeta): loads the tile of
// rows starting at i1_block and *overwrites* the staging buffers:
//   warp_buf1 <- dout                           (layer norm only)
//   warp_buf2 <- dout * (x - mean) * invvar     (layer norm)
//   warp_buf2 <- dout * x * invvar              (rms_only)
// Each thread handles blockDim.y consecutive columns starting at i2_off. Rows
// at or past i1_end and columns at or past n2 are written as zero.
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    bool rms_only
    )
{
  int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    // mean is only read for layer norm; keep the local defined otherwise.
    U curr_mean = U(0);
    if (!rms_only) {
      curr_mean = mean[i1];
    }
    U curr_invvar = invvar[i1];
    for (int k = 0;  k < blockDim.y;  ++k) {
      int i2 = i2_off + k;
      int load_idx = i1*n2+i2;
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        if (!rms_only) {
          warp_buf1[write_idx] = curr_dout;
          warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
        } else {
          warp_buf2[write_idx] = curr_dout * (curr_input) * curr_invvar;
        }
      } else {
        if (!rms_only) {
          warp_buf1[write_idx] = U(0);
        }
        warp_buf2[write_idx] = U(0);
      }
    }
  } else {
    for (int k = 0;  k < blockDim.y;  ++k) {
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (!rms_only) {
        warp_buf1[write_idx] = U(0);
      }
      warp_buf2[write_idx] = U(0);
    }
  }
}
Masaki Kozuki's avatar
Masaki Kozuki committed
491
// Backward helper (subsequent tiles of cuComputePartGradGammaBeta): loads the
// tile of rows starting at i1_block and *accumulates* into the staging buffers:
//   warp_buf1 += dout                           (layer norm only)
//   warp_buf2 += dout * (x - mean) * invvar     (layer norm)
//   warp_buf2 += dout * x * invvar              (rms_only)
// Rows at or past i1_end and columns at or past n2 leave the buffers unchanged.
template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    bool rms_only
    )
{
  int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    // mean is only read for layer norm; keep the local defined otherwise.
    U curr_mean = U(0);
    if (!rms_only) {
      curr_mean = mean[i1];
    }
    U curr_invvar = invvar[i1];
    for (int k = 0;  k < blockDim.y;  ++k) {
      int i2 = i2_off + k;
      int load_idx = i1*n2+i2;
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        if (!rms_only) {
          warp_buf1[write_idx] += curr_dout;
          warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
        } else {
          warp_buf2[write_idx] += curr_dout * (curr_input) * curr_invvar;
        }
      }
    }
  }
}

534

Masaki Kozuki's avatar
Masaki Kozuki committed
535
// Backward pass, stage 1: per-column partial sums of the gamma/beta gradients.
// blockIdx.x selects a tile of blockDim.x columns and blockIdx.y a contiguous
// range of rows. For each column i2 < n2 the block writes
//   part_grad_gamma[blockIdx.y*n2 + i2] = sum over its rows of dout * x_hat
//   part_grad_beta [blockIdx.y*n2 + i2] = sum over its rows of dout  (layer norm only)
// where x_hat = (x - mean) * invvar (layer norm) or x * invvar (rms_only).
//
// Assumptions (from the indexing below): blockDim.x is a power of two,
// blockDim.y >= 2, and dynamic shared memory holds both staging buffers
// (warp_buf2 starts blockDim.y*blockDim.y*(blockDim.x+1) elements after
// warp_buf1). `epsilon` is not used here.
template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    U* part_grad_gamma,
    U* part_grad_beta,
    bool rms_only)
{
    const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
    const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
    const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
    const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
    const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
    // +1 padding on each staging row to avoid shared-memory bank conflicts
    // when reading down a column.
    const int row_stride = blockDim.x+1;
    const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
    const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
    const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
    SharedMemory<U> shared;
    U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
    U* warp_buf1 = (U*)buf;
    U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
    // compute partial sums from strided inputs
    // do this to increase number of loads in flight
    cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only);
    for (int i1_block = i1_beg+blockDim.y*blockDim.y;  i1_block < i1_end;  i1_block+=blockDim.y*blockDim.y) {
      cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only);
    }
    __syncthreads();
    // inter-warp reductions
    // sum within each warp
    U acc1 = U(0);
    U acc2 = U(0);
    for (int k = 0;  k < blockDim.y;  ++k) {
      int row1 = threadIdx.y + k*blockDim.y;
      int idx1 = row1*row_stride + threadIdx.x;
      if (!rms_only) {
        acc1 += warp_buf1[idx1];
      }
      acc2 += warp_buf2[idx1];
    }
    if (!rms_only) {
      warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
    }
    warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
    __syncthreads();
    // sum all warps (down to two rows; the last pair is added at the store)
    for (int offset = blockDim.y/2;  offset > 1;  offset /= 2) {
      if (threadIdx.y < offset) {
        int row1 = threadIdx.y;
        int row2 = threadIdx.y + offset;
        int idx1 = row1*row_stride + threadIdx.x;
        int idx2 = row2*row_stride + threadIdx.x;
        if (!rms_only) {
          warp_buf1[idx1] += warp_buf1[idx2];
        }
        warp_buf2[idx1] += warp_buf2[idx2];
      }
      __syncthreads();
    }
    int i2 = blockIdx.x * blockDim.x + threadIdx.x;
    if (threadIdx.y == 0 && i2 < n2) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + 1;
      int idx1 = row1*row_stride + threadIdx.x;
      int idx2 = row2*row_stride + threadIdx.x;
      if (!rms_only) {
        part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
      }
      part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
    }
}

Masaki Kozuki's avatar
Masaki Kozuki committed
612
// Backward pass, stage 2: reduce the part_size partial rows produced by
// cuComputePartGradGammaBeta into the final per-column gradients:
//   grad_gamma[i2] = sum_p part_grad_gamma[p*n2 + i2]
//   grad_beta [i2] = sum_p part_grad_beta [p*n2 + i2]   (layer norm only)
// Each threadIdx.y sums part_size/blockDim.y consecutive partial rows (assumes
// part_size is a multiple of blockDim.y). The per-warp sums are then combined
// through shared memory, which holds blockDim.x*blockDim.y values of U when
// !rms_only (gamma in the first half, beta in the second).
//
// Every thread of the block takes part in the shared-memory reduction, even
// if its column i2 is past n2. Such threads contribute zeros. Only global
// loads and stores are guarded, so each __syncthreads() is reached by the
// whole block; a barrier inside `if (i2 < n2)` would be undefined behavior
// when n2 is not a multiple of blockDim.x.
template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
    const U* part_grad_gamma,
    const U* part_grad_beta,
    const int part_size,
    const int n1,
    const int n2,
    V* grad_gamma,
    V* grad_beta,
    bool rms_only)
{
    // sum partial gradients for gamma and beta
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    const int i2 = blockIdx.x * blockDim.x + threadIdx.x;
    const bool valid_col = i2 < n2;
    U sum_gamma = U(0);
    U sum_beta = U(0);
    if (valid_col) {
      // each warp does sequential reductions until reduced part_size is num_warps
      int num_warp_reductions = part_size / blockDim.y;
      const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
      const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
      for (int warp_offset = 0;  warp_offset < num_warp_reductions;  ++warp_offset) {
        sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
        if (!rms_only) {
          sum_beta += part_grad_beta_ptr[warp_offset*n2];
        }
      }
    }
    // inter-warp reductions (executed by every thread of the block)
    const int nbsize3 = blockDim.x * blockDim.y / 2;
    for (int offset = blockDim.y/2;  offset >= 1;  offset /= 2) {
      // top half write to shared memory
      if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
        buf[write_idx] = sum_gamma;
        if (!rms_only) {
          buf[write_idx+nbsize3] = sum_beta;
        }
      }
      __syncthreads();
      // bottom half sums
      if (threadIdx.y < offset) {
        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
        sum_gamma += buf[read_idx];
        if (!rms_only) {
          sum_beta += buf[read_idx+nbsize3];
        }
      }
      __syncthreads();
    }
    // write out fully summed gradients
    if (valid_col && threadIdx.y == 0) {
      grad_gamma[i2] = sum_gamma;
      if (!rms_only) {
        grad_beta[i2] = sum_beta;
      }
    }
}

672

Masaki Kozuki's avatar
Masaki Kozuki committed
673
// Backward pass: gradient w.r.t. the input. Rows are processed one per block
// iteration over blockIdx.y with a stride of gridDim.y. With
//   x_hat = (x - mean) * invvar     (layer norm)   or   x * invvar (rms_only)
//   g     = dout * gamma            (dout when gamma is NULL)
// the row gradient is
//   grad_input = invvar / n2 * (n2*g - sum(g) - x_hat * sum(g * x_hat))
// and the sum(g) term is omitted for rms_only.
//
// Assumptions: blockDim.x is the warp width (the XOR-shuffle reduction spans
// blockDim.x lanes); when blockDim.y > 1, blockDim.y is a power of two and
// dynamic shared memory holds at least blockDim.x*blockDim.y values of U.
// `epsilon` is not used here.
template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    const V* gamma,
    T* grad_input,
    bool rms_only)
{
  for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    // mean is only read for layer norm; keep the local defined otherwise.
    U c_mean = U(0);
    if (!rms_only) {
      c_mean = mean[i1];
    }
    const U c_invvar = invvar[i1];
    const T* k_input = input + i1*n2;
    const V* k_dout = dout + i1*n2;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL) {
      #ifndef __HIP_PLATFORM_HCC__
      int l = 4*thrx;
      for (;  l+3 < n2;  l+=4*numx) {
        for (int k = 0;  k < 4;  ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          if (!rms_only) {
            sum_loss1 += c_loss * gamma[l+k];
            sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
          } else {
            sum_loss2 += c_loss * gamma[l+k] * (c_h) * c_invvar;
          }
        }
      }
      for (;  l < n2;  ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        if (!rms_only) {
          sum_loss1 += c_loss * gamma[l];
          sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
        } else {
          sum_loss2 += c_loss * gamma[l] * (c_h) * c_invvar;
        }
      }
      #else
      // Optimization for ROCm MI100: one element per thread per iteration;
      // out-of-range lanes contribute zero.
      for( int l = 0; l < n2 ; l += numx) {
        int idx = l + thrx;
        const U gamma_idx = static_cast<U>((idx<n2) ? gamma[idx] : V(0));
        const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0));
        const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0));
        if (!rms_only) {
          sum_loss1 += c_loss * gamma_idx;
          sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_invvar;
        } else {
          sum_loss2 += c_loss * gamma_idx * (c_h) * c_invvar;
        }
      }
      #endif
    } else {
      #ifndef __HIP_PLATFORM_HCC__
      int l = 4*thrx;
      for (;  l+3 < n2;  l+=4*numx) {
        for (int k = 0;  k < 4;  ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          if (!rms_only) {
            sum_loss1 += c_loss;
            sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
          } else {
            sum_loss2 += c_loss * (c_h) * c_invvar;
          }
        }
      }
      for (;  l < n2;  ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        if (!rms_only) {
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        } else {
          sum_loss2 += c_loss * (c_h) * c_invvar;
        }
      }
      #else
      for( int l = 0; l < n2 ; l += numx) {
        int idx = l + thrx;
        const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0));
        const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0));
        if (!rms_only) {
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        } else {
          sum_loss2 += c_loss * (c_h) * c_invvar;
        }
      }
      #endif
    }
    // intra-warp reductions (XOR butterfly: every lane ends with the warp total)
    for (int mask = blockDim.x/2;  mask > 0;  mask /= 2) {
      if (!rms_only) {
        sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
      }
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
    }
    // inter-warp reductions
    if (blockDim.y > 1) {
      SharedMemory<U> shared;
      U* buf = shared.getPointer();
      for (int offset = blockDim.y/2;  offset > 0;  offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          if (!rms_only) {
            buf[2*wrt_i] = sum_loss1;
          }
          buf[2*wrt_i+1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          if (!rms_only) {
            sum_loss1 += buf[2*read_i];
          }
          sum_loss2 += buf[2*read_i+1];
        }
        __syncthreads();
      }
      // Warp 0 holds the block totals; broadcast them to the other warps.
      if (threadIdx.y == 0) {
        if (!rms_only) {
          buf[2*threadIdx.x] = sum_loss1;
        }
        buf[2*threadIdx.x+1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y !=0) {
        if (!rms_only) {
          sum_loss1 = buf[2*threadIdx.x];
        }
        sum_loss2 = buf[2*threadIdx.x+1];
      }
    }
    // all threads now have the two sums over l
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T* k_grad_input = grad_input + i1*n2;
    if (gamma != NULL) {
      for (int l = thrx;  l < n2;  l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        if (!rms_only) {
          f_grad_input -= sum_loss1;
          f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        } else {
          f_grad_input -= (c_h) * c_invvar * sum_loss2;
        }
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx;  l < n2;  l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        if (!rms_only) {
          f_grad_input -= sum_loss1;
          f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        } else {
          f_grad_input -= (c_h) * c_invvar * sum_loss2;
        }
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
    // prevent race where buf is written again before reads are done
    __syncthreads();
  }
}

853

Masaki Kozuki's avatar
Masaki Kozuki committed
854
// Launches the LayerNorm forward kernel (cuApplyLayerNorm) on the current
// PyTorch CUDA stream.
// Block: (warpSize, 4) threads, or (warpSize, 1) on ROCm. Grid: one block per
// row, capped at the device's maxGridSize[1]; the kernel loops over any
// remaining rows. Dynamic shared memory: threads.y + threads.y/2 values of U
// for the cross-warp reduction, none for a single warp.
// An empty batch (n1 <= 0) launches nothing; a zero-sized grid would be an
// invalid launch configuration.
template<typename T, typename U, typename V=T>
void HostApplyLayerNorm(
    V* output,
    U* mean,
    U* invvar,
    const T* input,
    int n1,
    int n2,
    double epsilon,
    const V* gamma,
    const V* beta
    )
{
    if (n1 <= 0) {
      return;
    }
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
    dim3 threads(warp_size ,4, 1);  // MI100 wavefront/warp = 64
    #ifdef __HIP_PLATFORM_HCC__
    // Optimization for ROCm MI100
    threads.y = 1;
    #endif

    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
    int nshared =
        threads.y > 1 ?
            threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
            0;
    cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
      output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta, warp_size);
}

885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
// Launches the RMSNorm forward kernel (cuApplyRMSNorm) on the current PyTorch
// CUDA stream. Block: 32x4 threads. Grid: one block per row, capped at the
// device's maxGridSize[1]. Dynamic shared memory: threads.y + threads.y/2
// values of U, since threads.y > 1.
// NOTE(review): unlike HostApplyLayerNorm, the block width is fixed at 32
// rather than the device warp size. Confirm that cuApplyRMSNorm (not visible
// here) assumes a 32-lane warp; if not, this is wrong on 64-wide ROCm
// wavefronts.
template<typename T, typename U, typename V=T>
void HostApplyRMSNorm(
    V* output,
    U* invvar,
    const T* input,
    int n1,
    int n2,
    double epsilon,
    const V* gamma)
{
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    const dim3 threads(32, 4, 1);

    const auto* props = at::cuda::getCurrentDeviceProperties();
    const uint64_t grid_rows_limit = props->maxGridSize[1];
    const uint64_t grid_rows = std::min(static_cast<uint64_t>(n1), grid_rows_limit);
    const dim3 blocks(1, grid_rows, 1);

    int nshared = 0;
    if (threads.y > 1) {
        nshared = (threads.y + threads.y / 2) * sizeof(U);
    }

    cuApplyRMSNorm<<<blocks, threads, nshared, stream>>>(
      output, invvar, input, n1, n2, U(epsilon), gamma);
}

907
908
909
910
911
912
913
// Forward LayerNorm entry point.
//
// Dispatches on the (input dtype, output dtype) pair and forwards raw data
// pointers to HostApplyLayerNorm. mean and invvar are accessed with the
// accumulation type at::acc_type<scalar_t_in, true>. gamma and beta are
// optional: a NULL tensor pointer is forwarded as a null data pointer.
// normalized_shape is part of the interface but is not read here.
void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon)
{
    using namespace at;
    DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
        input->scalar_type(), output->scalar_type(), "layer_norm_cuda_kernel",
        using accscalar_t = at::acc_type<scalar_t_in, true>;
        HostApplyLayerNorm<scalar_t_in, accscalar_t, scalar_t_out>(
            output->DATA_PTR<scalar_t_out>(),
            mean->DATA_PTR<accscalar_t>(),
            invvar->DATA_PTR<accscalar_t>(),
            input->DATA_PTR<scalar_t_in>(),
            n1,
            n2,
            epsilon,
            gamma ? gamma->DATA_PTR<scalar_t_out>() : nullptr,
            beta ? beta->DATA_PTR<scalar_t_out>() : nullptr);
    )
}

939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
// Forward RMSNorm entry point.
//
// Dispatches on the (input dtype, output dtype) pair and hands raw data
// pointers to HostApplyRMSNorm. invvar is accessed with the accumulation type
// at::acc_type<scalar_t_in, true>. gamma is optional: a NULL tensor pointer
// becomes a null data pointer. normalized_shape is accepted but not read here.
void cuda_rms_norm(
    at::Tensor* output,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor* gamma,
    double epsilon)
{
    using namespace at;
    DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
        input->scalar_type(), output->scalar_type(), "rms_norm_cuda_kernel",
        using accscalar_t = at::acc_type<scalar_t_in, true>;
        HostApplyRMSNorm<scalar_t_in, accscalar_t, scalar_t_out>(
            output->DATA_PTR<scalar_t_out>(),
            invvar->DATA_PTR<accscalar_t>(),
            input->DATA_PTR<scalar_t_in>(),
            n1,
            n2,
            epsilon,
            gamma ? gamma->DATA_PTR<scalar_t_out>() : nullptr);
    )
}


Masaki Kozuki's avatar
Masaki Kozuki committed
968
// Host-side launcher for the LayerNorm backward pass on the current stream.
//
// Affine case (gamma and beta both non-NULL): grad_gamma / grad_beta are built
// in two stages. cuComputePartGradGammaBeta writes part_size partial rows into
// temporary buffers, then cuComputeGradGammaBeta reduces them into grad_gamma
// and grad_beta. grad_input is always produced by cuComputeGradInput, which
// receives gamma (possibly NULL).
//
// The trailing `false` passed to every kernel is presumably the kernels'
// RMS-only flag (the RMSNorm launcher passes `true`).
template<typename T, typename U=float, typename V=T>
void HostLayerNormGradient(
    const V* dout,
    const U* mean,
    const U* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    const V* gamma,
    const V* beta,
    double epsilon,
    T* grad_input,
    V* grad_gamma,
    V* grad_beta
    )
{
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;

    // compute grad_gamma(j) and grad_beta(j)
    // Skipped when n2 == 0: both grids below would have gridDim.x == 0 (an
    // invalid launch configuration) and there are no columns to reduce.
    // n1 == 0 is still launched, since these grids do not depend on n1.
    if (gamma != NULL && beta != NULL && n2 > 0) {
      // Optimize layer normalization for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
      const int part_size = warp_size;
      const dim3 threads2(warp_size, 4, 1);
      const dim3 blocks2((n2+threads2.x-1) / threads2.x, part_size, 1);
      const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
      const int nshared2_b = threads2.x * threads2.y * sizeof(U);
      const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
      // Partial-gradient buffers are Float for Half/BFloat16 inputs and use the
      // input dtype otherwise. They are accessed through DATA_PTR<U>(), so this
      // is expected to agree with the accumulation type U. (The LayerNorm
      // backward dispatch does not include Double.)
      const auto part_grad_dtype =
        (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ?
        at::ScalarType::Float :
        input->scalar_type();
      at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype));
      at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
      cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
                      dout,
                      input->DATA_PTR<T>(),
                      n1,n2,
                      mean,
                      invvar,
                      U(epsilon),
                      part_grad_gamma.DATA_PTR<U>(),
                      part_grad_beta.DATA_PTR<U>(),
                      false);

      // Reduce the part_size partial rows. The grid covers n2 in chunks of
      // threads3.x, the block width this kernel is actually launched with.
      const dim3 threads3(warp_size, 8, 1);
      const dim3 blocks3((n2+threads3.x-1)/threads3.x,1,1);
      const int nshared3 = threads3.x * threads3.y * sizeof(U);
      cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
                      part_grad_gamma.DATA_PTR<U>(),
                      part_grad_beta.DATA_PTR<U>(),
                      part_size,
                      n1,n2,
                      grad_gamma,
                      grad_beta,
                      false);
    }

    // compute grad_input
    // https://github.com/microsoft/onnxruntime/pull/7682/files#diff-f9eace25e62b646410b067f96cd930c7fe843326dca1e8d383631ca27f1a8d00R540
    // https://github.com/amathews-amd/onnxruntime/blob/80c0555c2bc17fb109190e2082cd3fda0a37984c/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu#L541

    // No rows: blocks1.y would be 0, an invalid launch configuration whose
    // unchecked error would surface at a later, unrelated cudaGetLastError().
    if (n1 <= 0) {
      return;
    }

    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
    dim3 threads1(warp_size,4,1);  // MI100 wavefront/warp = 64
    #ifdef __HIP_PLATFORM_HCC__
    // Optimization for ROCm MI100
    threads1.y = 2;
    #endif
    int nshared =
            threads1.y > 1 ?
            threads1.y*threads1.x*sizeof(U) :
            0;
    cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
            dout,
            input->DATA_PTR<T>(),
            n1,n2,
            mean,
            invvar,
            U(epsilon),
            gamma,
            grad_input,
            false);
}
// Host-side launcher for the RMSNorm backward pass on the current stream.
//
// When gamma is non-NULL, grad_gamma is built in two stages:
// cuComputePartGradGammaBeta writes part_size partial rows, then
// cuComputeGradGammaBeta reduces them. The LayerNorm kernels are reused with a
// trailing `true` flag (presumably RMS-only mode), so the slots they expect
// for `mean` / beta buffers are filled with placeholders marked "unused".
// grad_input is always produced by cuComputeGradInput.
//
// TODO: Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
// NOTE(review): block widths are hard-coded to 32 rather than the device warp
// size; confirm the kernels' blockDim.x assumptions on 64-wide (ROCm) devices.
template<typename T, typename U=float, typename V=T>
void HostRMSNormGradient(
    const V* dout,
    const U* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    const V* gamma,
    double epsilon,
    T* grad_input,
    V* grad_gamma)
{
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // Skipped when n2 == 0: both grids below would have gridDim.x == 0 (an
    // invalid launch configuration) and there are no columns to reduce.
    if (gamma != NULL && n2 > 0) {
      const int part_size = 16;
      const dim3 threads2(32,4,1);
      const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
      const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
      const int nshared2_b = threads2.x * threads2.y * sizeof(U);
      const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
      // Partial-gradient buffer dtype: Float for Half/BFloat16 inputs, the input
      // dtype otherwise. Note the RMSNorm backward dispatch does include Double,
      // in which case this buffer is Double as well. It is accessed through
      // DATA_PTR<U>(), so this is expected to agree with the accumulation type U.
      const auto part_grad_dtype =
        (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ?
        at::ScalarType::Float :
        input->scalar_type();
      at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype));
      cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
                      dout,
                      input->DATA_PTR<T>(),
                      n1,n2,
                      invvar, // unused
                      invvar,
                      U(epsilon),
                      part_grad_gamma.DATA_PTR<U>(),
                      part_grad_gamma.DATA_PTR<U>(), /* unused */
                      true);

      // The grid covers n2 in chunks of threads3.x, the block width this
      // kernel is actually launched with.
      const dim3 threads3(32,8,1);
      const dim3 blocks3((n2+threads3.x-1)/threads3.x,1,1);
      const int nshared3 = threads3.x * threads3.y * sizeof(U);
      cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
                      part_grad_gamma.DATA_PTR<U>(),
                      part_grad_gamma.DATA_PTR<U>(), /* unused */
                      part_size,
                      n1,n2,
                      grad_gamma,
                      grad_gamma, /* unused */
                      true);
    }

    // compute grad_input
    // No rows: blocks1.y would be 0, an invalid launch configuration whose
    // unchecked error would surface at a later, unrelated cudaGetLastError().
    if (n1 <= 0) {
      return;
    }
    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
    const dim3 threads1(32,4,1);
    int nshared =
            threads1.y > 1 ?
            threads1.y*threads1.x*sizeof(U) :
            0;
    cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
            dout,
            input->DATA_PTR<T>(),
            n1,n2,
            invvar, /* unused */
            invvar,
            U(epsilon),
            gamma,
            grad_input,
            true);
}

// Backward LayerNorm entry point.
//
// Dispatches over Float/Half/BFloat16 only (no Double). The input-side type
// is input's dtype. The output-side type scalar_t_out is gamma's dtype, or
// input's dtype when gamma is NULL. dout, gamma, beta, grad_gamma and
// grad_beta are all accessed as scalar_t_out. mean/invvar use
// at::acc_type<scalar_t_in, true>.
//
// A NULL gamma selects the non-affine path: beta, grad_gamma and grad_beta are
// then ignored and forwarded as null pointers. When gamma is non-NULL, all
// three must be valid tensors. normalized_shape is accepted but not read here.
void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta)
{
    using namespace at;
    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), gamma ? gamma->scalar_type() : input->scalar_type(), "cuComputeGradInput",
      using accscalar_t = at::acc_type<scalar_t_in, true>;
      HostLayerNormGradient<scalar_t_in, accscalar_t, scalar_t_out>(
        dout->DATA_PTR<scalar_t_out>(),
        mean->DATA_PTR<accscalar_t>(),
        invvar->DATA_PTR<accscalar_t>(),
        input,
        n1, n2,
        // gamma alone decides whether the affine pointers are materialized.
        gamma ? gamma->DATA_PTR<scalar_t_out>() : nullptr,
        gamma ? beta->DATA_PTR<scalar_t_out>() : nullptr,
        epsilon,
        grad_input->DATA_PTR<scalar_t_in>(),
        gamma ? grad_gamma->DATA_PTR<scalar_t_out>() : nullptr,
        gamma ? grad_beta->DATA_PTR<scalar_t_out>() : nullptr);
    )
}
1167

1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
// Backward RMSNorm entry point.
//
// Dispatches over Double/Float/Half/BFloat16 (unlike the LayerNorm backward
// entry point, Double is included). The output-side type scalar_t_out is
// gamma's dtype, or input's dtype when gamma is NULL. dout, gamma and
// grad_gamma are accessed as scalar_t_out. invvar uses
// at::acc_type<scalar_t_in, true>.
//
// A NULL gamma makes both gamma and grad_gamma null pointers for
// HostRMSNormGradient. normalized_shape is accepted but not read here.
void cuda_rms_norm_gradient(
    at::Tensor* dout,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor* gamma,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma)
{
    using namespace at;
    DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), gamma ? gamma->scalar_type() : input->scalar_type(), "cuComputeGradInputRMS",
      using accscalar_t = at::acc_type<scalar_t_in, true>;
      HostRMSNormGradient<scalar_t_in, accscalar_t, scalar_t_out>(
        dout->DATA_PTR<scalar_t_out>(),
        invvar->DATA_PTR<accscalar_t>(),
        input,
        n1, n2,
        // gamma alone decides whether grad_gamma is materialized.
        gamma ? gamma->DATA_PTR<scalar_t_out>() : nullptr,
        epsilon,
        grad_input->DATA_PTR<scalar_t_in>(),
        gamma ? grad_gamma->DATA_PTR<scalar_t_out>() : nullptr);
    )
}