Commit 8de7a1ce authored by yuguo's avatar yuguo
Browse files
parents caf2fbf2 daa15293
......@@ -88,8 +88,9 @@ def apply_normalization(
normalization_func = _get_normalization_func(normalization, True)
inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
if enable_lightop and (ln_bias is None):
return rmsnorm_forward(inputmat, ln_weight,ln_out,eps,True)
if enable_lightop and (ln_bias is None) and normalization == "RMSNorm":
out, rsigma = rmsnorm_forward(inputmat, ln_weight,ln_out,eps,True)
return out, None, rsigma
else:
return normalization_func(
*inputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment