Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
62904708
Unverified
Commit
62904708
authored
Mar 02, 2026
by
haosdent
Committed by
GitHub
Mar 01, 2026
Browse files
[Bugfix] Fix dtype mismatch in RMSNormGated.forward_native() during torch.compile (#35256)
Signed-off-by:
haosdent
<
haosdent@gmail.com
>
parent
72f4d162
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
71 additions
and
4 deletions
+71
-4
tests/kernels/test_fla_layernorm_guard.py
tests/kernels/test_fla_layernorm_guard.py
+63
-1
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+8
-3
No files found.
tests/kernels/test_fla_layernorm_guard.py
View file @
62904708
...
@@ -74,7 +74,7 @@ def layer_norm_ref(
...
@@ -74,7 +74,7 @@ def layer_norm_ref(
return
out
.
to
(
dtype
)
return
out
.
to
(
dtype
)
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float32
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
# Test various M sizes to ensure rows_per_block logic works correctly
# Test various M sizes to ensure rows_per_block logic works correctly
NUM_TOKENS
=
[
NUM_TOKENS
=
[
1
,
1
,
...
@@ -380,6 +380,68 @@ def test_multidimensional_input(
...
@@ -380,6 +380,68 @@ def test_multidimensional_input(
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
1
,
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
64
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"has_gate"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
None
,
64
])
@
pytest
.
mark
.
parametrize
(
"norm_before_gate"
,
[
True
,
False
])
@
torch
.
inference_mode
()
def
test_rmsnorm_gated_forward_native_dtype
(
default_vllm_config
,
num_tokens
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
has_gate
:
bool
,
group_size
:
int
|
None
,
norm_before_gate
:
bool
,
):
"""Test that RMSNormGated.forward_native preserves input dtype."""
if
group_size
is
not
None
and
hidden_size
%
group_size
!=
0
:
pytest
.
skip
(
f
"hidden_size
{
hidden_size
}
not divisible by group_size
{
group_size
}
"
)
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
device
=
torch
.
device
(
"cuda:0"
)
set_random_seed
(
42
)
layer
=
RMSNormGated
(
hidden_size
,
eps
=
1e-5
,
group_size
=
group_size
,
norm_before_gate
=
norm_before_gate
,
device
=
device
,
dtype
=
dtype
,
)
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
device
)
z
=
(
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
device
)
if
has_gate
else
None
)
out
=
layer
.
forward_native
(
x
,
z
)
# Verify dtype preservation
assert
out
.
dtype
==
dtype
,
f
"Expected
{
dtype
}
, got
{
out
.
dtype
}
"
# Verify numerical correctness against reference
ref_out
=
rms_norm_ref
(
x
,
layer
.
weight
,
layer
.
bias
,
z
=
z
,
eps
=
1e-5
,
group_size
=
group_size
,
norm_before_gate
=
norm_before_gate
,
upcast
=
True
,
)
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
# Run a quick smoke test
# Run a quick smoke test
test_layer_norm_fwd_basic
(
128
,
1024
,
torch
.
float16
,
42
,
False
)
test_layer_norm_fwd_basic
(
128
,
1024
,
torch
.
float16
,
42
,
False
)
...
...
vllm/model_executor/layers/layernorm.py
View file @
62904708
...
@@ -557,6 +557,11 @@ class RMSNormGated(CustomOp):
...
@@ -557,6 +557,11 @@ class RMSNormGated(CustomOp):
- norm_before_gate=True: out = norm(x) * silu(z)
- norm_before_gate=True: out = norm(x) * silu(z)
- norm_before_gate=False: out = norm(x * silu(z))
- norm_before_gate=False: out = norm(x * silu(z))
"""
"""
orig_dtype
=
x
.
dtype
x
=
x
.
float
()
weight
=
self
.
weight
.
float
()
z
=
z
.
float
()
if
z
is
not
None
else
None
# Apply gating before normalization if needed
# Apply gating before normalization if needed
if
z
is
not
None
and
not
self
.
norm_before_gate
:
if
z
is
not
None
and
not
self
.
norm_before_gate
:
x
=
x
*
F
.
silu
(
z
)
x
=
x
*
F
.
silu
(
z
)
...
@@ -566,7 +571,7 @@ class RMSNormGated(CustomOp):
...
@@ -566,7 +571,7 @@ class RMSNormGated(CustomOp):
# Standard RMS norm across the last dimension
# Standard RMS norm across the last dimension
variance
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
variance
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x_normed
=
x
*
torch
.
rsqrt
(
variance
+
self
.
eps
)
x_normed
=
x
*
torch
.
rsqrt
(
variance
+
self
.
eps
)
out
=
x_normed
*
self
.
weight
out
=
x_normed
*
weight
else
:
else
:
# Group RMS norm
# Group RMS norm
from
einops
import
rearrange
from
einops
import
rearrange
...
@@ -574,13 +579,13 @@ class RMSNormGated(CustomOp):
...
@@ -574,13 +579,13 @@ class RMSNormGated(CustomOp):
x_group
=
rearrange
(
x
,
"... (g d) -> ... g d"
,
d
=
self
.
group_size
)
x_group
=
rearrange
(
x
,
"... (g d) -> ... g d"
,
d
=
self
.
group_size
)
variance
=
x_group
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
variance
=
x_group
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x_normed
=
x_group
*
torch
.
rsqrt
(
variance
+
self
.
eps
)
x_normed
=
x_group
*
torch
.
rsqrt
(
variance
+
self
.
eps
)
out
=
rearrange
(
x_normed
,
"... g d -> ... (g d)"
)
*
self
.
weight
out
=
rearrange
(
x_normed
,
"... g d -> ... (g d)"
)
*
weight
# Apply gating after normalization if needed
# Apply gating after normalization if needed
if
z
is
not
None
and
self
.
norm_before_gate
:
if
z
is
not
None
and
self
.
norm_before_gate
:
out
=
out
*
F
.
silu
(
z
)
out
=
out
*
F
.
silu
(
z
)
return
out
.
to
(
x
.
dtype
)
return
out
.
to
(
orig_
dtype
)
def
forward_cuda
(
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
,
z
:
torch
.
Tensor
|
None
=
None
self
,
x
:
torch
.
Tensor
,
z
:
torch
.
Tensor
|
None
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment