Commit 612defae authored by Ziminli
Browse files

issue/423: improve the precision of the torch implementation of rms_norm

parent 19d60bf8
@@ -59,12 +59,10 @@ NUM_ITERATIONS = 1000
 def rms_norm(ans, x, w, eps):
-    torch.pow(x, 2, out=ans)
-    mean = torch.mean(ans, dim=-1, keepdim=True)
-    mean.add_(eps)
-    torch.rsqrt(mean, out=mean)
-    torch.mul(x, mean, out=ans)
-    ans.mul_(w)
+    input_dtype = x.dtype
+    hidden_states = x.to(torch.float32)
+    scale = hidden_states.pow(2).mean(-1, keepdim=True).add_(eps).rsqrt_()
+    ans.set_((hidden_states.mul_(scale).mul_(w)).to(input_dtype))
 def test(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment