Unverified Commit b8401a9b authored by Lucas Kabela's avatar Lucas Kabela Committed by GitHub
Browse files

[Bugfix] Fix RMS norm + quant fusion on DeepGEMM UE8M0 path for B200 (#40552)


Signed-off-by: default avatarLucas Kabela <lucaskabela@meta.com>
parent 9c271f94
......@@ -51,6 +51,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
......@@ -317,6 +318,26 @@ def test_fusion_rmsnorm_quant(
):
pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm")
# TODO(quant-rms-fusion): DeepGEMM UE8M0 activation quant on B200 lowers
# to a packed int32-scale op (per_token_group_quant_fp8_packed_for_deepgemm),
# but the rms+quant fusion pattern only matches the fp32-scale variant, so
# the fused output gets a mismatched scale layout and produces NaN. Only
# reproduces on bf16 (DeepGEMM UE8M0 on B200 is bf16-only).
# To re-enable: make rms_norm_per_block_quant emit packed UE8M0 scales
# and extend the fusion pattern to rewrite the packed activation quant.
deepgemm_kernels = (
DeepGemmFp8BlockScaledMMKernel,
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
)
if (
dtype == torch.bfloat16
and force_kernel in deepgemm_kernels
and is_deep_gemm_e8m0_used()
):
pytest.skip(
"rms+quant fusion does not yet match the packed UE8M0 DeepGEMM path"
)
custom_ops = []
if enable_rms_norm_custom_op:
custom_ops.append("+rms_norm")
......
......@@ -1826,6 +1826,7 @@ class TestFP8Layer(torch.nn.Module):
self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
self.input_scale = None
self.weight_scale = None
self.weight_block_size = [block_size, block_size]
if transpose_weights:
self.weight = self.weight.t()
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment