Commit 612defae authored by Ziminli's avatar Ziminli
Browse files

issue/423: improve the precision of the torch implementation of rms_norm

parent 19d60bf8
......@@ -59,12 +59,10 @@ NUM_ITERATIONS = 1000
def rms_norm(ans, x, w, eps):
torch.pow(x, 2, out=ans)
mean = torch.mean(ans, dim=-1, keepdim=True)
mean.add_(eps)
torch.rsqrt(mean, out=mean)
torch.mul(x, mean, out=ans)
ans.mul_(w)
input_dtype = x.dtype
hidden_states = x.to(torch.float32)
scale = hidden_states.pow(2).mean(-1, keepdim=True).add_(eps).rsqrt_()
ans.set_((hidden_states.mul_(scale).mul_(w)).to(input_dtype))
def test(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment