Commit daa15293 authored by panning's avatar panning Committed by wenjh
Browse files

[DAS-RMSNorm] TE 2.3 returns 3 values of rmsnorm



API `rmsnorm_forward` of python returns 3 values rather than 2 from V2.3
Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent 0b0a70a5
......@@ -88,8 +88,9 @@ def apply_normalization(
normalization_func = _get_normalization_func(normalization, True)
inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
if enable_lightop and (ln_bias is None):
return rmsnorm_forward(inputmat, ln_weight,ln_out,eps,True)
if enable_lightop and (ln_bias is None) and normalization == "RMSNorm":
out, rsigma = rmsnorm_forward(inputmat, ln_weight,ln_out,eps,True)
return out, None, rsigma
else:
return normalization_func(
*inputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment