Commit f9a026ad authored by maxiao's avatar maxiao
Browse files

fix fused_add_rms_norm bug

parent b80ae5e9
......@@ -175,7 +175,7 @@ class RMSNorm(CustomOp):
self.weight.data,
self.variance_epsilon,
)
return output, residual_out
return x, residual
except TypeError:
fused_add_rms_norm(
output,
......@@ -185,7 +185,8 @@ class RMSNorm(CustomOp):
self.weight.data,
self.variance_epsilon,
)
return x, residual
return output, residual_out
out = torch.empty_like(x)
rms_norm(out, x, self.weight.data, self.variance_epsilon)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment