Unverified Commit 7eed38aa authored by Ashish Farmer's avatar Ashish Farmer Committed by GitHub
Browse files

Fix LayerNorm op on ROCm (#36)

* fix warp size in WARP_SHFL* in layernorm

* enable fused_layer_norm tests on ROCm
parent e9c43d67
......@@ -88,9 +88,9 @@ void cuWelfordMuSigma2(
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
U muB = WARP_SHFL(mu, srcLaneB);
U countB = WARP_SHFL(count, srcLaneB);
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
U muB = WARP_SHFL(mu, srcLaneB, 32);
U countB = WARP_SHFL(count, srcLaneB, 32);
U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
}
// threadIdx.x == 0 has correct values for each warp
......@@ -126,8 +126,8 @@ void cuWelfordMuSigma2(
sigma2 = ubuf[1]/U(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2/U(n2), 0);
mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2/U(n2), 0, 32);
}
}
}
......@@ -183,9 +183,9 @@ void cuWelfordMuSigma2(
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
float muB = WARP_SHFL(mu, srcLaneB);
float countB = WARP_SHFL(count, srcLaneB);
float sigma2B = WARP_SHFL(sigma2, srcLaneB);
float muB = WARP_SHFL(mu, srcLaneB, 32);
float countB = WARP_SHFL(count, srcLaneB, 32);
float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
}
// threadIdx.x == 0 has correct values for each warp
......@@ -221,8 +221,8 @@ void cuWelfordMuSigma2(
sigma2 = ubuf[1]/float(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2/float(n2), 0);
mu = WARP_SHFL(mu, 0, 32);
sigma2 = WARP_SHFL(sigma2/float(n2), 0, 32);
}
}
}
......@@ -581,8 +581,8 @@ void cuComputeGradInput(
}
// intra-warp reductions
for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
}
// inter-warp reductions
if (blockDim.y > 1) {
......
......@@ -6,7 +6,6 @@ from apex.testing.common_utils import TEST_WITH_ROCM, skipIfRocm
test_dirs = ["run_amp", "run_fp16util", "run_optimizers", "run_fused_layer_norm", "run_pyprof_nvtx", "run_pyprof_data", "run_mlp"]
ROCM_BLACKLIST = [
'run_fused_layer_norm',
'run_pyprof_nvtx',
'run_pyprof_data',
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment