Unverified Commit 2e693f48 authored by Wei Zhao's avatar Wei Zhao Committed by GitHub
Browse files

[Perf] Add TRTLLM FP8 MoE Modular Kernel (#36307)


Signed-off-by: default avatarwzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 7f1f36bf
...@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
TrtLlmFp8Experts, TrtLlmFp8ExpertsMonolithic,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, FlashInferExperts,
...@@ -247,7 +247,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( ...@@ -247,7 +247,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
allow_new_interface=True, allow_new_interface=True,
use_monolithic=True, use_monolithic=True,
), ),
TrtLlmFp8Experts( TrtLlmFp8ExpertsMonolithic(
moe_config=td.layer.moe, moe_config=td.layer.moe,
quant_config=quant_config, quant_config=quant_config,
), ),
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
...@@ -11,6 +12,9 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -11,6 +12,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType, RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int, activation_to_flashinfer_int,
) )
...@@ -22,10 +26,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -22,10 +26,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
logger = init_logger(__name__)
class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic): class TrtLlmFp8ExpertsBase:
""" """
Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface. Fp8 TRTLLM-Gen MoE kernels. Shared base for modular and monolithic
interfaces.
""" """
def __init__( def __init__(
...@@ -33,8 +40,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic): ...@@ -33,8 +40,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
): ):
super().__init__(moe_config, quant_config)
self.routing_method_type = moe_config.routing_method self.routing_method_type = moe_config.routing_method
self.topk = moe_config.experts_per_token self.topk = moe_config.experts_per_token
self.intermediate_size_per_partition = ( self.intermediate_size_per_partition = (
...@@ -44,6 +49,173 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic): ...@@ -44,6 +49,173 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
self.local_num_experts = moe_config.num_local_experts self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank self.ep_rank = moe_config.moe_parallel_config.ep_rank
self.quant_config = quant_config
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
# Add check flashinfer trtllm is available
return p.is_cuda() and p.is_device_capability_family(100)
@staticmethod
def _supports_no_act_and_mul() -> bool:
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
return True
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
"""Supports only SiLU and RELU^2 non-gated activation."""
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""Monolithic kernel so only use with naive DP/EP and TP."""
return (
not moe_parallel_config.use_all2all_kernels
or moe_parallel_config.use_naive_all2all_kernels
) and not moe_parallel_config.enable_eplb
@staticmethod
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
"""
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
"""
if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3
return True
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
"""
Fp8 TRTLLM-Gen MoE kernels. Supports modular interface.
"""
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 block."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# The workspaces for this implementation are managed by flashinfer.
workspace1 = (0,)
workspace2 = (0,)
output = (M, K)
return (workspace1, workspace2, output)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
import flashinfer
# Pack topk_ids and topk_weights into single tensor
# Format: (expert_id << 16) | (weight_bf16.view(int16))
packed_topk_ids = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view(
torch.int16
)
# trtllm_fp8_block_scale_routed_moe does not support autotuning
# so skip this kernel during dummy run for autotuning.
import vllm.utils.flashinfer as fi_utils
if fi_utils._is_fi_autotuning:
return
assert a1q_scale is not None
# `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the
# output tensor in-place so we need to manually copy the result to the
# output tensor
# https://github.com/flashinfer-ai/flashinfer/issues/2703
result = flashinfer.fused_moe.trtllm_fp8_block_scale_routed_moe(
topk_ids=packed_topk_ids,
routing_bias=None,
hidden_states=hidden_states,
hidden_states_scale=a1q_scale.t().contiguous(), # type: ignore[union-attr]
gemm1_weights=w1,
gemm1_weights_scale=self.quant_config.w1_scale,
gemm2_weights=w2,
gemm2_weights_scale=self.quant_config.w2_scale,
num_experts=global_num_experts,
top_k=self.topk,
n_group=None,
topk_group=None,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=None,
routing_method_type=1,
use_shuffled_weight=False,
weight_layout=0,
# output=output,
)
output.copy_(result)
class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolithic):
"""
Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface.
"""
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
super().__init__(moe_config, quant_config)
# Make additional scales for per-tensor interface. # Make additional scales for per-tensor interface.
if self.quant_config.is_per_tensor: if self.quant_config.is_per_tensor:
w1_scale = self.quant_config.w1_scale w1_scale = self.quant_config.w1_scale
...@@ -63,22 +235,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic): ...@@ -63,22 +235,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale
) )
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
# Add check flashinfer trtllm is available
return p.is_cuda() and p.is_device_capability_family(100)
@staticmethod
def _supports_no_act_and_mul() -> bool:
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
return True
@staticmethod @staticmethod
def _supports_quant_scheme( def _supports_quant_scheme(
weight_key: QuantKey | None, weight_key: QuantKey | None,
...@@ -91,11 +247,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic): ...@@ -91,11 +247,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
] ]
return (weight_key, activation_key) in SUPPORTED_W_A return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
"""Supports only SiLU and RELU^2 non-gated activation."""
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
@staticmethod @staticmethod
def _supports_routing_method( def _supports_routing_method(
routing_method: RoutingMethodType, routing_method: RoutingMethodType,
...@@ -123,36 +274,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic): ...@@ -123,36 +274,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
else: else:
raise ValueError("Unsupported quantization scheme.") raise ValueError("Unsupported quantization scheme.")
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""Monolithic kernel so only use with naive DP/EP and TP."""
return (
not moe_parallel_config.use_all2all_kernels
or moe_parallel_config.use_naive_all2all_kernels
) and not moe_parallel_config.enable_eplb
@staticmethod
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
"""
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
"""
if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3
return True
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
def _apply_per_block( def _apply_per_block(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -104,83 +104,84 @@ def _get_priority_backends( ...@@ -104,83 +104,84 @@ def _get_priority_backends(
def backend_to_kernel_cls( def backend_to_kernel_cls(
backend: Fp8MoeBackend, backend: Fp8MoeBackend,
) -> type[mk.FusedMoEExperts]: ) -> list[type[mk.FusedMoEExperts]]:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( # noqa: E501
TrtLlmFp8Experts, TrtLlmFp8ExpertsModular,
TrtLlmFp8ExpertsMonolithic,
) )
return TrtLlmFp8Experts return [TrtLlmFp8ExpertsMonolithic, TrtLlmFp8ExpertsModular]
elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS: elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, FlashInferExperts,
) )
return FlashInferExperts return [FlashInferExperts]
elif backend == Fp8MoeBackend.DEEPGEMM: elif backend == Fp8MoeBackend.DEEPGEMM:
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
) )
return TritonOrDeepGemmExperts return [TritonOrDeepGemmExperts]
elif backend == Fp8MoeBackend.BATCHED_DEEPGEMM: elif backend == Fp8MoeBackend.BATCHED_DEEPGEMM:
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts, BatchedDeepGemmExperts,
) )
return BatchedDeepGemmExperts return [BatchedDeepGemmExperts]
elif backend == Fp8MoeBackend.MARLIN: elif backend == Fp8MoeBackend.MARLIN:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts, MarlinExperts,
) )
return MarlinExperts return [MarlinExperts]
elif backend == Fp8MoeBackend.TRITON: elif backend == Fp8MoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts, TritonExperts,
) )
return TritonExperts return [TritonExperts]
elif backend == Fp8MoeBackend.BATCHED_TRITON: elif backend == Fp8MoeBackend.BATCHED_TRITON:
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts, BatchedTritonExperts,
) )
return BatchedTritonExperts return [BatchedTritonExperts]
elif backend == Fp8MoeBackend.AITER: elif backend == Fp8MoeBackend.AITER:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts, AiterExperts,
) )
return AiterExperts return [AiterExperts]
elif backend == Fp8MoeBackend.VLLM_CUTLASS: elif backend == Fp8MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import ( from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import (
TritonOrCutlassExperts, TritonOrCutlassExperts,
) )
return TritonOrCutlassExperts return [TritonOrCutlassExperts]
elif backend == Fp8MoeBackend.BATCHED_VLLM_CUTLASS: elif backend == Fp8MoeBackend.BATCHED_VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8, CutlassBatchedExpertsFp8,
) )
return CutlassBatchedExpertsFp8 return [CutlassBatchedExpertsFp8]
elif backend == Fp8MoeBackend.XPU: elif backend == Fp8MoeBackend.XPU:
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import ( from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
XPUExpertsFp8, XPUExpertsFp8,
) )
return XPUExpertsFp8 return [XPUExpertsFp8]
else: else:
raise ValueError(f"Unknown FP8 MoE backend: {backend.value}") raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
...@@ -215,8 +216,9 @@ def select_fp8_moe_backend( ...@@ -215,8 +216,9 @@ def select_fp8_moe_backend(
Select the primary FP8 MoE backend Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime. Note: Shape-specific fallbacks may still occur at runtime.
""" """
if config.is_lora_enabled: if config.is_lora_enabled:
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON) return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)[0]
# NOTE: the kernels are selected in the following order. # NOTE: the kernels are selected in the following order.
AVAILABLE_BACKENDS = _get_priority_backends(config, weight_key, activation_key) AVAILABLE_BACKENDS = _get_priority_backends(config, weight_key, activation_key)
...@@ -256,13 +258,13 @@ def select_fp8_moe_backend( ...@@ -256,13 +258,13 @@ def select_fp8_moe_backend(
activation_key: QuantKey | None, activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat, activation_format: mk.FusedMoEActivationFormat,
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]: ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
k_cls = backend_to_kernel_cls(backend) for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config( supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format k_cls, config, weight_key, activation_key, activation_format
) )
if supported: if supported:
logger.info_once(_make_log_backend(backend), scope="local") logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason)) raise ValueError(_make_log_unsupported(backend, reason))
# Handle explicit moe_backend from user. # Handle explicit moe_backend from user.
...@@ -312,7 +314,7 @@ def select_fp8_moe_backend( ...@@ -312,7 +314,7 @@ def select_fp8_moe_backend(
raise ValueError( raise ValueError(
f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE." f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE."
) )
k_cls = backend_to_kernel_cls(backend) k_cls = backend_to_kernel_cls(backend)[0]
return _return_or_raise( return _return_or_raise(
backend, config, weight_key, activation_key, activation_format backend, config, weight_key, activation_key, activation_format
) )
...@@ -322,23 +324,23 @@ def select_fp8_moe_backend( ...@@ -322,23 +324,23 @@ def select_fp8_moe_backend(
Fp8MoeBackend.FLASHINFER_TRTLLM, Fp8MoeBackend.FLASHINFER_TRTLLM,
Fp8MoeBackend.FLASHINFER_CUTLASS, Fp8MoeBackend.FLASHINFER_CUTLASS,
]: ]:
k_cls = backend_to_kernel_cls(backend) for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config( supported, reason = k_cls.is_supported_config(
k_cls, k_cls,
config, config,
weight_key, weight_key,
activation_key, activation_key,
activation_format, activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(
_make_log_unsupported(backend, reason), scope="local"
) )
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(
_make_log_unsupported(backend, reason), scope="local"
)
raise NotImplementedError( raise NotImplementedError(
"Found VLLM_USE_FLASHINFER_MOE_FP8=1, but no " "Found VLLM_USE_FLASHINFER_MOE_FP8=1, but no "
"FlashInfer FP8 MoE backend supports the configuration." "FlashInfer FP8 MoE backend supports the configuration."
...@@ -382,20 +384,19 @@ def select_fp8_moe_backend( ...@@ -382,20 +384,19 @@ def select_fp8_moe_backend(
# Select kernels in order of backend. # Select kernels in order of backend.
for backend in AVAILABLE_BACKENDS: for backend in AVAILABLE_BACKENDS:
k_cls = backend_to_kernel_cls(backend) for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config( supported, reason = k_cls.is_supported_config(
k_cls, k_cls,
config, config,
weight_key, weight_key,
activation_key, activation_key,
activation_format, activation_format,
) )
if supported:
if supported: logger.info_once(_make_log_backend(backend), scope="local")
logger.info_once(_make_log_backend(backend), scope="local") return backend, k_cls
return backend, k_cls else:
else: logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
# TODO(rob): per discussion with TPU team, we need a way to register # TODO(rob): per discussion with TPU team, we need a way to register
# MoE backends by OOT plugins, rather than having an explicit list # MoE backends by OOT plugins, rather than having an explicit list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment