Commit 8de7a1ce authored by yuguo's avatar yuguo
Browse files
parents caf2fbf2 daa15293
...@@ -88,8 +88,9 @@ def apply_normalization( ...@@ -88,8 +88,9 @@ def apply_normalization(
normalization_func = _get_normalization_func(normalization, True) normalization_func = _get_normalization_func(normalization, True)
inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias) inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
if enable_lightop and (ln_bias is None): if enable_lightop and (ln_bias is None) and normalization == "RMSNorm":
return rmsnorm_forward(inputmat, ln_weight,ln_out,eps,True) out, rsigma = rmsnorm_forward(inputmat, ln_weight,ln_out,eps,True)
return out, None, rsigma
else: else:
return normalization_func( return normalization_func(
*inputs, *inputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment