Commit daa15293 authored by panning's avatar panning Committed by wenjh
Browse files

[DAS-RMSNorm] TE 2.3 returns 3 values of rmsnorm



API `rmsnorm_forward` of python returns 3 values rather than 2 from V2.3
Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent 0b0a70a5
...@@ -88,8 +88,9 @@ def apply_normalization( ...@@ -88,8 +88,9 @@ def apply_normalization(
normalization_func = _get_normalization_func(normalization, True) normalization_func = _get_normalization_func(normalization, True)
inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias) inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
if enable_lightop and (ln_bias is None): if enable_lightop and (ln_bias is None) and normalization == "RMSNorm":
return rmsnorm_forward(inputmat, ln_weight,ln_out,eps,True) out, rsigma = rmsnorm_forward(inputmat, ln_weight,ln_out,eps,True)
return out, None, rsigma
else: else:
return normalization_func( return normalization_func(
*inputs, *inputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment