"tests/entrypoints/llm/test_gpu_utilization.py" did not exist on "efbce85f4d375d7851a491a0126a224e25d9f91d"
Unverified Commit 6b7a8118 authored by amirkl94's avatar amirkl94 Committed by GitHub
Browse files

Bugfix: Cutlass FP8 FusedMoE bad scaling factors (#27255)


Signed-off-by: default avatarAmir Klein <203507526+amirkl94@users.noreply.github.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent b57789b6
......@@ -6,7 +6,10 @@ import pytest
import torch
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
......@@ -22,10 +25,10 @@ from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
100
90
):
pytest.skip(
"Requires flashinfer_cutlass_fused_moe and nvfp4 support",
"Supported for sm >= 90",
allow_module_level=True,
)
......@@ -131,6 +134,8 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
topk: int,
monkeypatch,
):
if not current_platform.has_device_capability(100):
pytest.skip("Test is only supported for sm >= 100")
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
......@@ -184,9 +189,6 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
@pytest.mark.skip(
"Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
......@@ -216,9 +218,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=td.w13_weight_scale,
g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
w2_scale=td.w2_weight_scale,
g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
a1_scale=td.a1_scale,
a1_gscale=td.a1_scale,
a2_scale=td.a2_scale,
a2_gscale=1.0 / td.a2_scale,
per_act_token_quant=False,
)
......@@ -238,6 +244,12 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td.layer.dp_size = 1
def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
return quant_config
td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
td.layer.quant_method = td.layer
flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
td.hidden_states,
td.layer,
......
......@@ -463,6 +463,10 @@ def fp8_w8a8_moe_quant_config(
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
block_shape: list[int] | None = None,
a1_gscale: torch.Tensor | None = None,
a2_gscale: torch.Tensor | None = None,
g1_alphas: torch.Tensor | None = None,
g2_alphas: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for fp8 activations and fp8 weights.
......@@ -470,9 +474,13 @@ def fp8_w8a8_moe_quant_config(
return FusedMoEQuantConfig.make(
torch.float8_e4m3fn,
w1_scale=w1_scale,
g1_alphas=g1_alphas,
w2_scale=w2_scale,
g2_alphas=g2_alphas,
a1_scale=a1_scale,
a1_gscale=a1_gscale,
a2_scale=a2_scale,
a2_gscale=a2_gscale,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
......
......@@ -170,7 +170,7 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
self._apply_router_weight_on_input(
a1, topk_weights, topk_ids, apply_router_weight_on_input
)
if not self.use_dp:
if not self.use_dp and quant_config.quant_dtype == "nvfp4":
return a1, None, None, topk_ids, topk_weights
a1q, a1q_scale = moe_kernel_quantize_input(
......@@ -181,6 +181,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
quant_config.block_shape,
is_fp4_scale_swizzled=not self.use_dp,
)
if self.use_dp:
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
[topk_weights, topk_ids, a1q, a1q_scale],
dim=0,
......
......@@ -567,9 +567,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
w2_scale=layer.w2_weight_scale,
g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
a1_scale=layer.w13_input_scale,
a1_gscale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
a2_gscale=1.0 / layer.w2_input_scale,
per_act_token_quant=False,
)
......@@ -1138,8 +1142,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> None:
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support,
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
detect_nvfp4_moe_support, # noqa: E501
)
super().__init__(moe)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment