Commit 4a8e1a87 authored by Natalia Gimelshein

Fix fused layer norm for batch sizes > 65535

parent 3d01e4a0
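
On current NVIDIA GPUs the grid's y-dimension is limited to 65535 blocks, so launching one block per batch row (blocks(1, n1, 1)) breaks once n1 exceeds that limit. The change below caps the launch at maxGridSize[1] and turns the per-row "if (i1 < n1)" guard into a grid-stride loop over blockIdx.y, so each block handles every row congruent to its starting index modulo gridDim.y. The following is a minimal, self-contained sketch of that pattern under those assumptions; the rowScale kernel, its launcher, and the per-row work are hypothetical stand-ins for illustration only, not the actual fused layer-norm kernels.

// Hypothetical sketch of a grid-stride loop over blockIdx.y with a capped launch.
// rowScale and launchRowScale are illustrative names, not part of this repository.
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

__global__ void rowScale(const float* in, float* out, int n1, int n2, float s) {
  // Each block starts at its own row and strides by gridDim.y, so all n1 rows
  // are covered even when the launch is capped below n1.
  for (unsigned int i1 = blockIdx.y; i1 < (unsigned int)n1; i1 += gridDim.y) {
    // Threads of the block cooperate across the n2 elements of the row.
    for (int i2 = threadIdx.x; i2 < n2; i2 += blockDim.x) {
      out[i1 * (size_t)n2 + i2] = s * in[i1 * (size_t)n2 + i2];
    }
  }
}

void launchRowScale(const float* in, float* out, int n1, int n2, float s) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  // Cap the y-dimension at the device limit (65535 on current GPUs);
  // the loop inside the kernel picks up any rows beyond the cap.
  uint64_t maxGridY = (uint64_t)prop.maxGridSize[1];
  dim3 threads(32, 1, 1);
  dim3 blocks(1, (unsigned int)std::min((uint64_t)n1, maxGridY), 1);
  rowScale<<<blocks, threads>>>(in, out, n1, n2, s);
}
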
@@ -53,6 +53,7 @@ void cuWelfordMuSigma2(
   const T* __restrict__ vals,
   const int n1,
   const int n2,
+  const int i1,
   U& mu,
   U& sigma2,
   U* buf)
@@ -66,7 +67,6 @@ void cuWelfordMuSigma2(
   U count = U(0);
   mu= U(0);
   sigma2 = U(0);
-  int i1 = blockIdx.y;
   if (i1 < n1) {
     // one warp normalizes one n1 index,
     // synchronization is implicit
@@ -137,6 +137,7 @@ void cuWelfordMuSigma2(
   const at::Half* __restrict__ vals,
   const int n1,
   const int n2,
+  const int i1,
   float& mu,
   float& sigma2,
   float* buf)
@@ -150,7 +151,6 @@ void cuWelfordMuSigma2(
   float count = 0.0f;
   mu= float(0);
   sigma2 = float(0);
-  int i1 = blockIdx.y;
   if (i1 < n1) {
     // one warp normalizes one n1 index,
     // synchronization is implicit
@@ -293,12 +293,11 @@ void cuApplyLayerNorm(
   // 1) blockDim.x == warpSize
   // 2) Tensors are contiguous
   //
-  int i1 = blockIdx.y;
-  if (i1 < n1) {
+  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
     SharedMemory<U> shared;
     U* buf = shared.getPointer();
     U mu,sigma2;
-    cuWelfordMuSigma2(vals,n1,n2,mu,sigma2,buf);
+    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
     const T* lvals = vals + i1*n2;
     T* ovals = output_vals + i1*n2;
     U c_invvar = rsqrt(sigma2 + epsilon);
@@ -532,8 +531,7 @@ void cuComputeGradInput(
   const T* gamma,
   T* grad_input)
 {
-  int i1 = blockIdx.y;
-  if (i1 < n1) {
+  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
     U sum_loss1 = U(0);
     U sum_loss2 = U(0);
     const U c_mean = mean[i1];
@@ -653,7 +651,8 @@ void HostApplyLayerNorm(
 {
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   const dim3 threads(32,4,1);
-  const dim3 blocks(1,n1,1);
+  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
+  const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
   int nshared =
     threads.y > 1 ?
       threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
@@ -750,8 +749,9 @@ void HostLayerNormGradient(
   }
   // compute grad_input
+  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
+  const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
   const dim3 threads1(32,4,1);
-  const dim3 blocks1(1,n1,1);
   int nshared =
     threads1.y > 1 ?
       threads1.y*threads1.x*sizeof(U) :
...