Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b8401a9b
Unverified
Commit
b8401a9b
authored
Apr 22, 2026
by
Lucas Kabela
Committed by
GitHub
Apr 22, 2026
Browse files
[Bugfix] Fix RMS norm + quant fusion on DeepGEMM UE8M0 path for B200 (#40552)
Signed-off-by:
Lucas Kabela
<
lucaskabela@meta.com
>
parent
9c271f94
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
0 deletions
+22
-0
tests/compile/passes/test_fusion.py
tests/compile/passes/test_fusion.py
+21
-0
tests/utils.py
tests/utils.py
+1
-0
No files found.
tests/compile/passes/test_fusion.py
View file @
b8401a9b
...
@@ -51,6 +51,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...
@@ -51,6 +51,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
(
from
vllm.utils.deep_gemm
import
(
is_deep_gemm_e8m0_used
,
is_deep_gemm_supported
,
is_deep_gemm_supported
,
)
)
...
@@ -317,6 +318,26 @@ def test_fusion_rmsnorm_quant(
...
@@ -317,6 +318,26 @@ def test_fusion_rmsnorm_quant(
):
):
pytest
.
skip
(
"Unsupported group shape 64 for CUTLASS/DeepGemm"
)
pytest
.
skip
(
"Unsupported group shape 64 for CUTLASS/DeepGemm"
)
# TODO(quant-rms-fusion): DeepGEMM UE8M0 activation quant on B200 lowers
# to a packed int32-scale op (per_token_group_quant_fp8_packed_for_deepgemm),
# but the rms+quant fusion pattern only matches the fp32-scale variant, so
# the fused output gets a mismatched scale layout and produces NaN. Only
# reproduces on bf16 (DeepGEMM UE8M0 on B200 is bf16-only).
# To re-enable: make rms_norm_per_block_quant emit packed UE8M0 scales
# and extend the fusion pattern to rewrite the packed activation quant.
deepgemm_kernels
=
(
DeepGemmFp8BlockScaledMMKernel
,
FlashInferFp8DeepGEMMDynamicBlockScaledKernel
,
)
if
(
dtype
==
torch
.
bfloat16
and
force_kernel
in
deepgemm_kernels
and
is_deep_gemm_e8m0_used
()
):
pytest
.
skip
(
"rms+quant fusion does not yet match the packed UE8M0 DeepGEMM path"
)
custom_ops
=
[]
custom_ops
=
[]
if
enable_rms_norm_custom_op
:
if
enable_rms_norm_custom_op
:
custom_ops
.
append
(
"+rms_norm"
)
custom_ops
.
append
(
"+rms_norm"
)
...
...
tests/utils.py
View file @
b8401a9b
...
@@ -1826,6 +1826,7 @@ class TestFP8Layer(torch.nn.Module):
...
@@ -1826,6 +1826,7 @@ class TestFP8Layer(torch.nn.Module):
self
.
weight
=
torch
.
rand
(
weight_shape
).
to
(
dtype
=
FP8_DTYPE
)
self
.
weight
=
torch
.
rand
(
weight_shape
).
to
(
dtype
=
FP8_DTYPE
)
self
.
input_scale
=
None
self
.
input_scale
=
None
self
.
weight_scale
=
None
self
.
weight_scale
=
None
self
.
weight_block_size
=
[
block_size
,
block_size
]
if
transpose_weights
:
if
transpose_weights
:
self
.
weight
=
self
.
weight
.
t
()
self
.
weight
=
self
.
weight
.
t
()
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment