Unverified Commit a8b91f6b authored by Haichao Zhu's avatar Haichao Zhu Committed by GitHub
Browse files

improve mimax-m2 rmsnorm precision (#12186)

parent 959d1ab8
......@@ -122,7 +122,7 @@ class MiniMaxM2RMSNormTP(nn.Module):
# Normalize and apply local weight shard
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
x = (x * self.weight).to(orig_dtype)
return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment