Unverified commit 5221832e, authored by Ying Xiong, committed by GitHub
Browse files

Fix wrong idx bug in invertible LayerNormBackward1 (#692)



* fix wrong idx bug in invertible LayerNormBackward1

This index bug caused a wrong scale (gamma/beta) gradient.

* fix unexpected deletion

* fix idx for LayerNormBackward1_fused_add

* move pos definition in LayerNormBackward1 kernels

* fix format error
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
parent 852c524a
@@ -624,9 +624,8 @@ __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
     int offset = threadIdx.y * width + idx;
     int y_stride = width * TILE_DIM;
-    int pos = blockIdx.x * TILE_DIM + threadIdx.y;
-    float betta_reg = (invertible ? (float)betta[pos] : 0.0f);
-    float gamma_reg = (float)gamma[pos];
+    float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
+    float gamma_reg = (float)gamma[idx];
     // Loop across matrix height
     float betta_tmp = 0;
@@ -660,6 +659,7 @@ __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
     }
     if (threadIdx.x == 0) {
+        int pos = blockIdx.x * TILE_DIM + threadIdx.y;
         betta_grad[pos] = s1;
         gamma_grad[pos] = s2;
     }
@@ -1368,9 +1368,8 @@ __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
     int offset = threadIdx.y * width + idx;
     int y_stride = width * TILE_DIM;
-    int pos = blockIdx.x * TILE_DIM + threadIdx.y;
-    float betta_reg = (invertible ? (float)betta[pos] : 0.0f);
-    float gamma_reg = (float)gamma[pos];
+    float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
+    float gamma_reg = (float)gamma[idx];
     // Loop across matrix height
     float betta_tmp = 0;
@@ -1404,6 +1403,7 @@ __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
     }
     if (threadIdx.x == 0) {
+        int pos = blockIdx.x * TILE_DIM + threadIdx.y;
         betta_grad[pos] = s1;
         gamma_grad[pos] = s2;
     }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment