Unverified commit f6c0009a, authored by Robert Shaw, committed by GitHub
Browse files

[Bugfix] Fix Broken ModelOpt NVFP4 MoE (#31742)


Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
parent 776ca1e1
...@@ -15,9 +15,6 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -15,9 +15,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_pplx from vllm.utils.import_utils import has_deep_ep, has_pplx
...@@ -80,17 +77,12 @@ def maybe_make_prepare_finalize( ...@@ -80,17 +77,12 @@ def maybe_make_prepare_finalize(
prepare_finalize: FusedMoEPrepareAndFinalize | None = None prepare_finalize: FusedMoEPrepareAndFinalize | None = None
if moe.use_flashinfer_cutlass_kernels: # TODO(rob): update this as part of the MoE refactor.
assert quant_config is not None assert not moe.use_flashinfer_cutlass_kernels, (
use_deepseek_fp8_block_scale = ( "Must be created in modelopt.py or fp8.py"
quant_config is not None and quant_config.is_block_quantized
)
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe=moe,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
) )
elif moe.use_pplx_kernels: if moe.use_pplx_kernels:
assert quant_config is not None assert quant_config is not None
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
......
...@@ -241,7 +241,9 @@ def flashinfer_cutlass_moe_fp4( ...@@ -241,7 +241,9 @@ def flashinfer_cutlass_moe_fp4(
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
fused_experts = mk.FusedMoEModularKernel( fused_experts = mk.FusedMoEModularKernel(
create_flashinfer_prepare_finalize(use_dp=False), create_flashinfer_prepare_finalize(
use_dp=False, use_nvfp4=True, enable_alltoallv=False
),
FlashInferExperts( FlashInferExperts(
out_dtype=hidden_states.dtype, out_dtype=hidden_states.dtype,
quant_config=quant_config, quant_config=quant_config,
......
...@@ -48,6 +48,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod ...@@ -48,6 +48,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend, FlashinferMoeBackend,
apply_flashinfer_per_tensor_scale_fp8, apply_flashinfer_per_tensor_scale_fp8,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
get_flashinfer_moe_backend, get_flashinfer_moe_backend,
register_moe_scaling_factors, register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights, rotate_flashinfer_fp8_moe_weights,
...@@ -149,7 +150,7 @@ def get_fp8_moe_backend( ...@@ -149,7 +150,7 @@ def get_fp8_moe_backend(
if block_quant and current_platform.is_device_capability_family(100): if block_quant and current_platform.is_device_capability_family(100):
raise ValueError( raise ValueError(
"FlashInfer FP8 MoE throughput backend does not " "FlashInfer FP8 MoE throughput backend does not "
"support block quantization. Please use " "support block quantization on SM100. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency " "VLLM_FLASHINFER_MOE_BACKEND=latency "
"instead." "instead."
) )
...@@ -1102,6 +1103,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1102,6 +1103,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
): ):
return None return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables) return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
......
...@@ -46,6 +46,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( ...@@ -46,6 +46,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend, FlashinferMoeBackend,
apply_flashinfer_per_tensor_scale_fp8, apply_flashinfer_per_tensor_scale_fp8,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
flashinfer_cutlass_moe_fp8, flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend, get_flashinfer_moe_backend,
is_flashinfer_supporting_global_sf, is_flashinfer_supporting_global_sf,
...@@ -750,6 +751,17 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ...@@ -750,6 +751,17 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# TRT LLM not supported with all2all yet. # TRT LLM not supported with all2all yet.
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
return None return None
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
# TP case: avoid converting to ModularKernelMethod - to be refactored.
if self.moe.dp_size == 1:
return None
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe,
use_deepseek_fp8_block_scale=False,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
return super().maybe_make_prepare_finalize(routing_tables) return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl( def select_gemm_impl(
...@@ -1444,6 +1456,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1444,6 +1456,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.allow_flashinfer self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
): ):
# TP case: avoid converting to ModularKernelMethod - to be refactored.
if self.moe.dp_size == 1:
return None
# For now, fp4 moe only works with the flashinfer dispatcher. # For now, fp4 moe only works with the flashinfer dispatcher.
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
self.moe self.moe
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment