Unverified commit e85f8f09, authored by yzong-rh and committed by GitHub
Browse files

[Bug][MoE] Strengthen _supports_current_device() checks in the TRTLLM FP8,...


[Bug][MoE] Strengthen _supports_current_device() checks in the TRTLLM FP8, NVFP4, and FlashInfer CuteDSL MoE experts (#36728)
Signed-off-by: Yifan Zong <yzong@redhat.com>
parent 5bf3c42d
...@@ -23,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -23,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import ( from vllm.utils.flashinfer import (
flashinfer_cutedsl_grouped_gemm_nt_masked, flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
scaled_fp4_grouped_quantize, scaled_fp4_grouped_quantize,
silu_and_mul_scaled_nvfp4_experts_quantize, silu_and_mul_scaled_nvfp4_experts_quantize,
) )
...@@ -60,7 +61,11 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular): ...@@ -60,7 +61,11 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
p = current_platform p = current_platform
return p.is_cuda() and p.is_device_capability_family(100) return (
p.is_cuda()
and p.is_device_capability_family(100)
and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
)
@staticmethod @staticmethod
def _supports_no_act_and_mul() -> bool: def _supports_no_act_and_mul() -> bool:
......
...@@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kMxfp8Static, kMxfp8Static,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -61,8 +62,11 @@ class TrtLlmFp8ExpertsBase: ...@@ -61,8 +62,11 @@ class TrtLlmFp8ExpertsBase:
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs.""" """Supports only Blackwell-family GPUs."""
p = current_platform p = current_platform
# Add check flashinfer trtllm is available return (
return p.is_cuda() and p.is_device_capability_family(100) p.is_cuda()
and p.is_device_capability_family(100)
and has_flashinfer_trtllm_fused_moe()
)
@staticmethod @staticmethod
def _supports_no_act_and_mul() -> bool: def _supports_no_act_and_mul() -> bool:
......
...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Static, kNvfp4Static,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
class TrtLlmNvFp4ExpertsBase: class TrtLlmNvFp4ExpertsBase:
...@@ -80,7 +81,11 @@ class TrtLlmNvFp4ExpertsBase: ...@@ -80,7 +81,11 @@ class TrtLlmNvFp4ExpertsBase:
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs.""" """Supports only Blackwell-family GPUs."""
p = current_platform p = current_platform
return p.is_cuda() and p.is_device_capability_family(100) return (
p.is_cuda()
and p.is_device_capability_family(100)
and has_flashinfer_trtllm_fused_moe()
)
@staticmethod @staticmethod
def _supports_no_act_and_mul() -> bool: def _supports_no_act_and_mul() -> bool:
......
...@@ -244,7 +244,7 @@ def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool: ...@@ -244,7 +244,7 @@ def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool:
required_functions = [ required_functions = [
("flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"), ("flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"),
("flashinfer", "scaled_fp4_grouped_quantize"), ("flashinfer", "scaled_fp4_grouped_quantize"),
("flashinfer", "silu_and_scaled_nvfp4_experts_quantize"), ("flashinfer", "silu_and_mul_scaled_nvfp4_experts_quantize"),
] ]
for module_name, attr_name in required_functions: for module_name, attr_name in required_functions:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment