Commit 15ebd0ce authored by René Honig's avatar René Honig Committed by khluu
Browse files

fix: Add SM120 (RTX Blackwell) support for FlashInfer CUTLASS NVFP4 MoE kernels (#33417)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
(cherry picked from commit 07978117)
parent 29152683
...@@ -649,7 +649,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -649,7 +649,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
return current_platform.has_device_capability((10, 0)) p = current_platform
return p.is_cuda() and (
p.is_device_capability_family(100)
or p.is_device_capability_family(110)
or p.is_device_capability_family(120)
)
@staticmethod @staticmethod
def _supports_no_act_and_mul() -> bool: def _supports_no_act_and_mul() -> bool:
......
...@@ -54,7 +54,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -54,7 +54,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
return current_platform.is_device_capability_family(100) p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
@staticmethod @staticmethod
def _supports_no_act_and_mul() -> bool: def _supports_no_act_and_mul() -> bool:
......
...@@ -91,11 +91,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -91,11 +91,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
p = current_platform
return ( return (
current_platform.is_cuda() p.is_cuda()
and ( and (
current_platform.is_device_capability((9, 0)) p.is_device_capability(90)
or current_platform.is_device_capability_family(100) or p.is_device_capability_family(100)
or p.is_device_capability_family(110)
or p.is_device_capability_family(120)
) )
and has_flashinfer_cutlass_fused_moe() and has_flashinfer_cutlass_fused_moe()
) )
...@@ -109,29 +112,27 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -109,29 +112,27 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
weight_key: QuantKey | None, weight_key: QuantKey | None,
activation_key: QuantKey | None, activation_key: QuantKey | None,
) -> bool: ) -> bool:
# The following are supported by FlashInferExperts:
# * unquantized
# * fp8 static per-tensor on 9.0+
# * fp8 block on 9.0
# * nvfp4 on 10.0+
p = current_platform p = current_platform
scheme = (weight_key, activation_key) scheme = (weight_key, activation_key)
# The following are supported by FlashInferExperts:
return ( return (
# unquantized and fp8 static per-tensor on 9.0+
( (
scheme scheme
in [ in [
(None, None), (None, None),
(kFp8StaticTensorSym, kFp8StaticTensorSym), (kFp8StaticTensorSym, kFp8StaticTensorSym),
] ]
and p.has_device_capability(90)
) )
# fp8 block-scale on 9.0
or ( or (
(scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym)) scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym)
and (p.is_device_capability((9, 0))) and p.is_device_capability(90)
) )
# nvfp4 on 10.0+
or ( or (
(scheme == (kNvfp4Static, kNvfp4Dynamic)) scheme == (kNvfp4Static, kNvfp4Dynamic) and p.has_device_capability(100)
and (p.is_device_capability_family(100))
) )
) )
......
...@@ -30,7 +30,6 @@ from vllm.utils.torch_utils import direct_register_custom_op ...@@ -30,7 +30,6 @@ from vllm.utils.torch_utils import direct_register_custom_op
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs.""" """Supports only Blackwell-family GPUs."""
p = current_platform p = current_platform
# Add check flashinfer trtllm is available
return p.is_cuda() and p.is_device_capability_family(100) return p.is_cuda() and p.is_device_capability_family(100)
......
...@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING ...@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING
import torch import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -25,10 +24,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -25,10 +24,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutlass_fused_moe,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
...@@ -39,8 +34,6 @@ logger = init_logger(__name__) ...@@ -39,8 +34,6 @@ logger = init_logger(__name__)
__all__ = [ __all__ = [
"is_flashinfer_fp4_cutlass_moe_available",
"is_flashinfer_fp4_cutedsl_moe_available",
"reorder_w1w3_to_w3w1", "reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_prepare_finalize", "build_flashinfer_fp4_cutlass_moe_prepare_finalize",
] ]
...@@ -126,26 +119,6 @@ def is_supported_config_trtllm( ...@@ -126,26 +119,6 @@ def is_supported_config_trtllm(
return True, None return True, None
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and current_platform.is_cuda()
and current_platform.has_device_capability(100)
)
def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
"""Return ``True`` when FlashInfer CUTEDSL NV-FP4 kernels can be used."""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
and current_platform.is_cuda()
and current_platform.is_device_capability_family(100)
)
def reorder_w1w3_to_w3w1( def reorder_w1w3_to_w3w1(
weight: torch.Tensor, scale: torch.Tensor, dim: int = -2 weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment