Commit 4a8e1a87 authored by Natalia Gimelshein's avatar Natalia Gimelshein
Browse files

fix fused layer norm for >65535 batch

parent 3d01e4a0
......@@ -53,6 +53,7 @@ void cuWelfordMuSigma2(
const T* __restrict__ vals,
const int n1,
const int n2,
const int i1,
U& mu,
U& sigma2,
U* buf)
......@@ -66,7 +67,6 @@ void cuWelfordMuSigma2(
U count = U(0);
mu= U(0);
sigma2 = U(0);
int i1 = blockIdx.y;
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
......@@ -137,6 +137,7 @@ void cuWelfordMuSigma2(
const at::Half* __restrict__ vals,
const int n1,
const int n2,
const int i1,
float& mu,
float& sigma2,
float* buf)
......@@ -150,7 +151,6 @@ void cuWelfordMuSigma2(
float count = 0.0f;
mu= float(0);
sigma2 = float(0);
int i1 = blockIdx.y;
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
......@@ -293,12 +293,11 @@ void cuApplyLayerNorm(
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
int i1 = blockIdx.y;
if (i1 < n1) {
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu,sigma2;
cuWelfordMuSigma2(vals,n1,n2,mu,sigma2,buf);
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
const T* lvals = vals + i1*n2;
T* ovals = output_vals + i1*n2;
U c_invvar = rsqrt(sigma2 + epsilon);
......@@ -532,8 +531,7 @@ void cuComputeGradInput(
const T* gamma,
T* grad_input)
{
int i1 = blockIdx.y;
if (i1 < n1) {
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
......@@ -653,7 +651,8 @@ void HostApplyLayerNorm(
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const dim3 threads(32,4,1);
const dim3 blocks(1,n1,1);
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared =
threads.y > 1 ?
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
......@@ -750,8 +749,9 @@ void HostLayerNormGradient(
}
// compute grad_input
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32,4,1);
const dim3 blocks1(1,n1,1);
int nshared =
threads1.y > 1 ?
threads1.y*threads1.x*sizeof(U) :
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment