Unverified Commit 2e693f48 authored by Wei Zhao's avatar Wei Zhao Committed by GitHub
Browse files

[Perf] Add TRTLLM FP8 MoE Modular Kernel (#36307)


Signed-off-by: default avatarwzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 7f1f36bf
......@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
TrtLlmFp8Experts,
TrtLlmFp8ExpertsMonolithic,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
......@@ -247,7 +247,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
allow_new_interface=True,
use_monolithic=True,
),
TrtLlmFp8Experts(
TrtLlmFp8ExpertsMonolithic(
moe_config=td.layer.moe,
quant_config=quant_config,
),
......
......@@ -4,6 +4,7 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
......@@ -11,6 +12,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
)
......@@ -22,10 +26,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
logger = init_logger(__name__)
class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
class TrtLlmFp8ExpertsBase:
"""
Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface.
Fp8 TRTLLM-Gen MoE kernels. Shared base for modular and monolithic
interfaces.
"""
def __init__(
......@@ -33,8 +40,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
super().__init__(moe_config, quant_config)
self.routing_method_type = moe_config.routing_method
self.topk = moe_config.experts_per_token
self.intermediate_size_per_partition = (
......@@ -44,6 +49,173 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
self.quant_config = quant_config
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
# Add check flashinfer trtllm is available
return p.is_cuda() and p.is_device_capability_family(100)
@staticmethod
def _supports_no_act_and_mul() -> bool:
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
return True
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
"""Supports only SiLU and RELU^2 non-gated activation."""
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""Monolithic kernel so only use with naive DP/EP and TP."""
return (
not moe_parallel_config.use_all2all_kernels
or moe_parallel_config.use_naive_all2all_kernels
) and not moe_parallel_config.enable_eplb
@staticmethod
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
"""
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
"""
if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3
return True
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
"""
Fp8 TRTLLM-Gen MoE kernels. Supports modular interface.
"""
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 block."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# The workspaces for this implementation are managed by flashinfer.
workspace1 = (0,)
workspace2 = (0,)
output = (M, K)
return (workspace1, workspace2, output)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
import flashinfer
# Pack topk_ids and topk_weights into single tensor
# Format: (expert_id << 16) | (weight_bf16.view(int16))
packed_topk_ids = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view(
torch.int16
)
# trtllm_fp8_block_scale_routed_moe does not support autotuning
# so skip this kernel during dummy run for autotuning.
import vllm.utils.flashinfer as fi_utils
if fi_utils._is_fi_autotuning:
return
assert a1q_scale is not None
# `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the
# output tensor in-place so we need to manually copy the result to the
# output tensor
# https://github.com/flashinfer-ai/flashinfer/issues/2703
result = flashinfer.fused_moe.trtllm_fp8_block_scale_routed_moe(
topk_ids=packed_topk_ids,
routing_bias=None,
hidden_states=hidden_states,
hidden_states_scale=a1q_scale.t().contiguous(), # type: ignore[union-attr]
gemm1_weights=w1,
gemm1_weights_scale=self.quant_config.w1_scale,
gemm2_weights=w2,
gemm2_weights_scale=self.quant_config.w2_scale,
num_experts=global_num_experts,
top_k=self.topk,
n_group=None,
topk_group=None,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=None,
routing_method_type=1,
use_shuffled_weight=False,
weight_layout=0,
# output=output,
)
output.copy_(result)
class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolithic):
"""
Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface.
"""
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
super().__init__(moe_config, quant_config)
# Make additional scales for per-tensor interface.
if self.quant_config.is_per_tensor:
w1_scale = self.quant_config.w1_scale
......@@ -63,22 +235,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale
)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
# Add check flashinfer trtllm is available
return p.is_cuda() and p.is_device_capability_family(100)
@staticmethod
def _supports_no_act_and_mul() -> bool:
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
return True
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
......@@ -91,11 +247,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
"""Supports only SiLU and RELU^2 non-gated activation."""
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
......@@ -123,36 +274,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
else:
raise ValueError("Unsupported quantization scheme.")
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""Monolithic kernel so only use with naive DP/EP and TP."""
return (
not moe_parallel_config.use_all2all_kernels
or moe_parallel_config.use_naive_all2all_kernels
) and not moe_parallel_config.enable_eplb
@staticmethod
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
"""
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
"""
if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3
return True
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
def _apply_per_block(
self,
hidden_states: torch.Tensor,
......
......@@ -104,83 +104,84 @@ def _get_priority_backends(
def backend_to_kernel_cls(
backend: Fp8MoeBackend,
) -> type[mk.FusedMoEExperts]:
) -> list[type[mk.FusedMoEExperts]]:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( # noqa: E501
TrtLlmFp8Experts,
TrtLlmFp8ExpertsModular,
TrtLlmFp8ExpertsMonolithic,
)
return TrtLlmFp8Experts
return [TrtLlmFp8ExpertsMonolithic, TrtLlmFp8ExpertsModular]
elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
return FlashInferExperts
return [FlashInferExperts]
elif backend == Fp8MoeBackend.DEEPGEMM:
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
return TritonOrDeepGemmExperts
return [TritonOrDeepGemmExperts]
elif backend == Fp8MoeBackend.BATCHED_DEEPGEMM:
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
return BatchedDeepGemmExperts
return [BatchedDeepGemmExperts]
elif backend == Fp8MoeBackend.MARLIN:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
return MarlinExperts
return [MarlinExperts]
elif backend == Fp8MoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts,
)
return TritonExperts
return [TritonExperts]
elif backend == Fp8MoeBackend.BATCHED_TRITON:
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts,
)
return BatchedTritonExperts
return [BatchedTritonExperts]
elif backend == Fp8MoeBackend.AITER:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
return AiterExperts
return [AiterExperts]
elif backend == Fp8MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import (
TritonOrCutlassExperts,
)
return TritonOrCutlassExperts
return [TritonOrCutlassExperts]
elif backend == Fp8MoeBackend.BATCHED_VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8,
)
return CutlassBatchedExpertsFp8
return [CutlassBatchedExpertsFp8]
elif backend == Fp8MoeBackend.XPU:
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
XPUExpertsFp8,
)
return XPUExpertsFp8
return [XPUExpertsFp8]
else:
raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
......@@ -215,8 +216,9 @@ def select_fp8_moe_backend(
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
if config.is_lora_enabled:
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)[0]
# NOTE: the kernels are selected in the following order.
AVAILABLE_BACKENDS = _get_priority_backends(config, weight_key, activation_key)
......@@ -256,7 +258,7 @@ def select_fp8_moe_backend(
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
k_cls = backend_to_kernel_cls(backend)
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
......@@ -312,7 +314,7 @@ def select_fp8_moe_backend(
raise ValueError(
f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE."
)
k_cls = backend_to_kernel_cls(backend)
k_cls = backend_to_kernel_cls(backend)[0]
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
......@@ -322,7 +324,7 @@ def select_fp8_moe_backend(
Fp8MoeBackend.FLASHINFER_TRTLLM,
Fp8MoeBackend.FLASHINFER_CUTLASS,
]:
k_cls = backend_to_kernel_cls(backend)
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls,
config,
......@@ -382,7 +384,7 @@ def select_fp8_moe_backend(
# Select kernels in order of backend.
for backend in AVAILABLE_BACKENDS:
k_cls = backend_to_kernel_cls(backend)
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls,
config,
......@@ -390,7 +392,6 @@ def select_fp8_moe_backend(
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment