Commit 7669bf3d (unverified), authored by Tim Moon, committed by GitHub

[Core] Fix bug when selecting tuned RMSNorm kernels (#983)



Fix typo when selecting tuned RMSNorm kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 3a9a4c83
...@@ -89,7 +89,7 @@ BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype
   if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) &&
       is_aligned(params.gamma) && is_aligned(params.dz) && is_aligned(params.dx) &&
       is_aligned(params.dgamma) && is_aligned(params.dgamma_part) &&
-      layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) {
+      BWD_TUNED_FUNCS.count(tuned_key) > 0) {
     return BWD_TUNED_FUNCS.at(tuned_key);
   }
......
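
The hunk above checks membership in one registry (`layer_norm::BWD_TUNED_FUNCS`) but reads from another (the unqualified `BWD_TUNED_FUNCS`). The following is a minimal sketch of why that mismatch matters. It is not the library's code: the namespaces `layer_norm_sketch` and `rmsnorm_sketch`, the `Launcher` type, the integer keys, and the `general_launcher` fallback are all hypothetical stand-ins, and the assumption that the two registries can hold different keys is made only for illustration.

```cpp
#include <cstdint>
#include <functional>
#include <stdexcept>
#include <unordered_map>

// Hypothetical launcher type: returns an int so the chosen kernel is observable.
using Launcher = std::function<int()>;

namespace layer_norm_sketch {
// Hypothetical tuned-kernel registry for a different normalization family.
std::unordered_map<std::uint64_t, Launcher> BWD_TUNED_FUNCS = {
    {1, [] { return 100; }},
    {2, [] { return 200; }},
};
}  // namespace layer_norm_sketch

namespace rmsnorm_sketch {
// Hypothetical tuned-kernel registry for this family; its keys differ on purpose.
std::unordered_map<std::uint64_t, Launcher> BWD_TUNED_FUNCS = {
    {1, [] { return 10; }},
    {3, [] { return 30; }},
};

// Hypothetical fallback used when no tuned kernel is registered.
Launcher general_launcher = [] { return -1; };

// Mismatched lookup: membership is tested in one map, the value is read from another.
Launcher &get_bwd_launcher_buggy(std::uint64_t key) {
  if (layer_norm_sketch::BWD_TUNED_FUNCS.count(key) > 0) {
    return BWD_TUNED_FUNCS.at(key);  // throws std::out_of_range if key is absent here
  }
  return general_launcher;
}

// Consistent lookup: the same map is used for the test and the read.
Launcher &get_bwd_launcher_fixed(std::uint64_t key) {
  if (BWD_TUNED_FUNCS.count(key) > 0) {
    return BWD_TUNED_FUNCS.at(key);
  }
  return general_launcher;
}
}  // namespace rmsnorm_sketch
```

Under these assumptions the mismatched version fails in two ways: key 2 passes the check but makes `.at()` throw `std::out_of_range`, and key 3 has a tuned kernel that is never selected because the check looks in the wrong map. The consistent version falls back for key 2 and selects the tuned kernel for key 3.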