Commit 15ebd0ce authored by René Honig's avatar René Honig Committed by khluu
Browse files

fix: Add SM120 (RTX Blackwell) support for FlashInfer CUTLASS NVFP4 MoE kernels (#33417)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
(cherry picked from commit 07978117)
parent 29152683
......@@ -649,7 +649,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_current_device() -> bool:
return current_platform.has_device_capability((10, 0))
p = current_platform
return p.is_cuda() and (
p.is_device_capability_family(100)
or p.is_device_capability_family(110)
or p.is_device_capability_family(120)
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
......
......@@ -54,7 +54,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_current_device() -> bool:
return current_platform.is_device_capability_family(100)
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
@staticmethod
def _supports_no_act_and_mul() -> bool:
......
......@@ -91,11 +91,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_current_device() -> bool:
p = current_platform
return (
current_platform.is_cuda()
p.is_cuda()
and (
current_platform.is_device_capability((9, 0))
or current_platform.is_device_capability_family(100)
p.is_device_capability(90)
or p.is_device_capability_family(100)
or p.is_device_capability_family(110)
or p.is_device_capability_family(120)
)
and has_flashinfer_cutlass_fused_moe()
)
......@@ -109,29 +112,27 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# The following are supported by FlashInferExperts:
# * unquantized
# * fp8 static per-tensor on 9.0+
# * fp8 block on 9.0
# * nvfp4 on 10.0+
p = current_platform
scheme = (weight_key, activation_key)
# The following are supported by FlashInferExperts:
return (
# unquantized and fp8 static per-tensor on 9.0+
(
scheme
in [
(None, None),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
and p.has_device_capability(90)
)
# fp8 block-scale on 9.0
or (
(scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym))
and (p.is_device_capability((9, 0)))
scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym)
and p.is_device_capability(90)
)
# nvfp4 on 10.0+
or (
(scheme == (kNvfp4Static, kNvfp4Dynamic))
and (p.is_device_capability_family(100))
scheme == (kNvfp4Static, kNvfp4Dynamic) and p.has_device_capability(100)
)
)
......
......@@ -30,7 +30,6 @@ from vllm.utils.torch_utils import direct_register_custom_op
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
# Add check flashinfer trtllm is available
return p.is_cuda() and p.is_device_capability_family(100)
......
......@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
......@@ -25,10 +24,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
swizzle_blockscale,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutlass_fused_moe,
)
if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
......@@ -39,8 +34,6 @@ logger = init_logger(__name__)
__all__ = [
"is_flashinfer_fp4_cutlass_moe_available",
"is_flashinfer_fp4_cutedsl_moe_available",
"reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
]
......@@ -126,26 +119,6 @@ def is_supported_config_trtllm(
return True, None
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and current_platform.is_cuda()
and current_platform.has_device_capability(100)
)
def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
"""Return ``True`` when FlashInfer CUTEDSL NV-FP4 kernels can be used."""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
and current_platform.is_cuda()
and current_platform.is_device_capability_family(100)
)
def reorder_w1w3_to_w3w1(
weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
) -> tuple[torch.Tensor, torch.Tensor]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment