"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "1a6a647e06592ba1157f620ec28efaf3c8b4509e"
Commit f9a026ad authored by maxiao's avatar maxiao
Browse files

fix fused_add_rms_norm bug

parent b80ae5e9
...@@ -175,7 +175,7 @@ class RMSNorm(CustomOp): ...@@ -175,7 +175,7 @@ class RMSNorm(CustomOp):
self.weight.data, self.weight.data,
self.variance_epsilon, self.variance_epsilon,
) )
return output, residual_out return x, residual
except TypeError: except TypeError:
fused_add_rms_norm( fused_add_rms_norm(
output, output,
...@@ -185,7 +185,8 @@ class RMSNorm(CustomOp): ...@@ -185,7 +185,8 @@ class RMSNorm(CustomOp):
self.weight.data, self.weight.data,
self.variance_epsilon, self.variance_epsilon,
) )
return x, residual return output, residual_out
out = torch.empty_like(x) out = torch.empty_like(x)
rms_norm(out, x, self.weight.data, self.variance_epsilon) rms_norm(out, x, self.weight.data, self.variance_epsilon)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment