jerrrrry / infinicore
Commit f38ca71c (Unverified)
Authored Sep 04, 2025 by PanZezhong1725
Committed by GitHub, Sep 04, 2025

issue/423: improve the precision of the torch implementation of rms_norm

Parents: c7373fee, 612defae

Showing 1 changed file with 4 additions and 6 deletions

test/infiniop/rms_norm.py (+4, -6)
@@ -59,12 +59,10 @@ NUM_ITERATIONS = 1000
 
 
 def rms_norm(ans, x, w, eps):
-    torch.pow(x, 2, out=ans)
-    mean = torch.mean(ans, dim=-1, keepdim=True)
-    mean.add_(eps)
-    torch.rsqrt(mean, out=mean)
-    torch.mul(x, mean, out=ans)
-    ans.mul_(w)
+    input_dtype = x.dtype
+    hidden_states = x.to(torch.float32)
+    scale = hidden_states.pow(2).mean(-1, keepdim=True).add_(eps).rsqrt_()
+    ans.set_((hidden_states.mul_(scale).mul_(w)).to(input_dtype))
 
 
 def test(