jerrrrry / infinicore
Commit 612defae
Authored Sep 03, 2025 by Ziminli
issue/423: improve the precision of the torch implementation of rms_norm
Parent: 19d60bf8
Showing 1 changed file with 4 additions and 6 deletions
test/infiniop/rms_norm.py (+4 -6)
...
@@ -59,12 +59,10 @@ NUM_ITERATIONS = 1000
 def rms_norm(ans, x, w, eps):
-    torch.pow(x, 2, out=ans)
-    mean = torch.mean(ans, dim=-1, keepdim=True)
-    mean.add_(eps)
-    torch.rsqrt(mean, out=mean)
-    torch.mul(x, mean, out=ans)
-    ans.mul_(w)
+    input_dtype = x.dtype
+    hidden_states = x.to(torch.float32)
+    scale = hidden_states.pow(2).mean(-1, keepdim=True).add_(eps).rsqrt_()
+    ans.set_((hidden_states.mul_(scale).mul_(w)).to(input_dtype))
 def test(
...
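The change above moves the reference computation from the input dtype to float32 and casts back only at the end. The sketch below is one way to see why that can matter. It is not the project's test code: it uses only the Python standard library, emulates float16 storage with the `struct` module's `e` format, and uses Python doubles as a stand-in for float32. The names `to_half`, `rms_norm_low_precision`, and `rms_norm_upcast` are hypothetical and chosen for illustration.

```python
import math
import struct


def to_half(value):
    """Hypothetical helper: round a Python float to the nearest float16 value."""
    try:
        return struct.unpack("e", struct.pack("e", value))[0]
    except OverflowError:
        # struct rejects finite values beyond the float16 range (max 65504);
        # float16 arithmetic would produce a signed infinity instead.
        return math.copysign(math.inf, value)


def rms_norm_low_precision(x, w, eps):
    """Emulates the removed path: every intermediate is stored as float16."""
    squares = [to_half(v * v) for v in x]          # like torch.pow(x, 2, out=ans)
    mean = to_half(sum(squares) / len(squares))    # like torch.mean(ans, dim=-1)
    mean = to_half(mean + eps)                     # like mean.add_(eps)
    scale = to_half(1.0 / math.sqrt(mean))         # like torch.rsqrt(mean, out=mean)
    ans = [to_half(v * scale) for v in x]          # like torch.mul(x, mean, out=ans)
    return [to_half(a * wi) for a, wi in zip(ans, w)]  # like ans.mul_(w)


def rms_norm_upcast(x, w, eps):
    """Emulates the added path: compute in higher precision, round once at the end."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [to_half(v * scale * wi) for v, wi in zip(x, w)]


if __name__ == "__main__":
    # 300 is exactly representable in float16, but 300 ** 2 = 90000 is not.
    x = [300.0, -300.0, 300.0, -300.0]
    w = [1.0, 1.0, 1.0, 1.0]
    print("float16 intermediates:", rms_norm_low_precision(x, w, 1e-5))
    print("upcast, single rounding:", rms_norm_upcast(x, w, 1e-5))
```

With these inputs the squares overflow the float16 range, so the emulated low-precision path collapses to zeros, while the upcast path returns values of magnitude 1. Overflow is only the most visible failure mode; for in-range inputs the low-precision path also rounds after every step, whereas the upcast path rounds once.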