Unverified commit 7eed38aa, authored by Ashish Farmer, committed by GitHub
Browse files

Fix LayerNorm op on ROCm (#36)

* fix warp size in WARP_SHFL* in layernorm

* enable fused_layer_norm tests on ROCm
parent e9c43d67
...@@ -88,9 +88,9 @@ void cuWelfordMuSigma2( ...@@ -88,9 +88,9 @@ void cuWelfordMuSigma2(
// intra-warp reductions // intra-warp reductions
for (int l = 0; l <= 4; ++l) { for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31; int srcLaneB = (threadIdx.x+(1<<l))&31;
U muB = WARP_SHFL(mu, srcLaneB); U muB = WARP_SHFL(mu, srcLaneB, 32);
U countB = WARP_SHFL(count, srcLaneB); U countB = WARP_SHFL(count, srcLaneB, 32);
U sigma2B = WARP_SHFL(sigma2, srcLaneB); U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count); cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
} }
// threadIdx.x == 0 has correct values for each warp // threadIdx.x == 0 has correct values for each warp
...@@ -126,8 +126,8 @@ void cuWelfordMuSigma2( ...@@ -126,8 +126,8 @@ void cuWelfordMuSigma2(
sigma2 = ubuf[1]/U(n2); sigma2 = ubuf[1]/U(n2);
// don't care about final value of count, we know count == n2 // don't care about final value of count, we know count == n2
} else { } else {
mu = WARP_SHFL(mu, 0); mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2/U(n2), 0); sigma2 = WARP_SHFL(sigma2/U(n2), 0, 32);
} }
} }
} }
...@@ -183,9 +183,9 @@ void cuWelfordMuSigma2( ...@@ -183,9 +183,9 @@ void cuWelfordMuSigma2(
// intra-warp reductions // intra-warp reductions
for (int l = 0; l <= 4; ++l) { for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31; int srcLaneB = (threadIdx.x+(1<<l))&31;
float muB = WARP_SHFL(mu, srcLaneB); float muB = WARP_SHFL(mu, srcLaneB, 32);
float countB = WARP_SHFL(count, srcLaneB); float countB = WARP_SHFL(count, srcLaneB, 32);
float sigma2B = WARP_SHFL(sigma2, srcLaneB); float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
} }
// threadIdx.x == 0 has correct values for each warp // threadIdx.x == 0 has correct values for each warp
...@@ -221,8 +221,8 @@ void cuWelfordMuSigma2( ...@@ -221,8 +221,8 @@ void cuWelfordMuSigma2(
sigma2 = ubuf[1]/float(n2); sigma2 = ubuf[1]/float(n2);
// don't care about final value of count, we know count == n2 // don't care about final value of count, we know count == n2
} else { } else {
mu = WARP_SHFL(mu, 0); mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2/float(n2), 0); sigma2 = WARP_SHFL(sigma2/float(n2), 0, 32);
} }
} }
} }
...@@ -581,8 +581,8 @@ void cuComputeGradInput( ...@@ -581,8 +581,8 @@ void cuComputeGradInput(
} }
// intra-warp reductions // intra-warp reductions
for (int mask = blockDim.x/2; mask > 0; mask /= 2) { for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
} }
// inter-warp reductions // inter-warp reductions
if (blockDim.y > 1) { if (blockDim.y > 1) {
......
...@@ -6,7 +6,6 @@ from apex.testing.common_utils import TEST_WITH_ROCM, skipIfRocm ...@@ -6,7 +6,6 @@ from apex.testing.common_utils import TEST_WITH_ROCM, skipIfRocm
test_dirs = ["run_amp", "run_fp16util", "run_optimizers", "run_fused_layer_norm", "run_pyprof_nvtx", "run_pyprof_data", "run_mlp"] test_dirs = ["run_amp", "run_fp16util", "run_optimizers", "run_fused_layer_norm", "run_pyprof_nvtx", "run_pyprof_data", "run_mlp"]
ROCM_BLACKLIST = [ ROCM_BLACKLIST = [
'run_fused_layer_norm',
'run_pyprof_nvtx', 'run_pyprof_nvtx',
'run_pyprof_data', 'run_pyprof_data',
] ]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment