Unverified Commit 1af6f78a authored by Wei Zhao's avatar Wei Zhao Committed by GitHub
Browse files

[Perf] Change Trtllm fp8 MoE to use Shuffled Weights and BlockMajorK Layout (#38993)


Signed-off-by: default avatarwzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 228023b3
......@@ -112,6 +112,24 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
]
return (weight_key, activation_key) in SUPPORTED_W_A
def moe_problem_size(
self,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""Override to handle 4D BlockMajorK weights (E, K/bk, Mn, bk)."""
if w1.dim() == 4:
# BlockMajorK: (E, K/bk, Mn, bk)
E = w1.shape[0]
N = w1.shape[2]
K = a1.size(-1)
M = a1.size(0) if a1.dim() == 2 else a1.size(1)
topk = topk_ids.size(1)
return E, M, N, K, topk
return super().moe_problem_size(a1, w1, w2, topk_ids)
def workspace_shapes(
self,
M: int,
......@@ -152,7 +170,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
apply_router_weight_on_input: bool,
):
import flashinfer
from flashinfer.fused_moe import Fp8QuantizationType
from flashinfer.fused_moe import Fp8QuantizationType, WeightLayout
# Pack topk ids and weights into format expected by the kernel.
packed_topk_ids = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
......@@ -170,10 +188,12 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
if is_mxfp8:
fp8_quant_type = Fp8QuantizationType.MxFp8
use_shuffled_weight = True
weight_layout = WeightLayout.MajorK
hidden_states_scale = a1q_scale
else:
fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
use_shuffled_weight = False
use_shuffled_weight = True
weight_layout = WeightLayout.BlockMajorK
hidden_states_scale = a1q_scale.t().contiguous()
# `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the
......@@ -199,7 +219,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
routed_scaling_factor=None,
routing_method_type=1,
use_shuffled_weight=use_shuffled_weight,
weight_layout=0,
weight_layout=weight_layout,
fp8_quantization_type=fp8_quant_type,
# output=output,
)
......@@ -322,7 +342,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
topk_group: int | None = None,
) -> torch.Tensor:
import flashinfer
from flashinfer.fused_moe import Fp8QuantizationType
from flashinfer.fused_moe import Fp8QuantizationType, WeightLayout
assert not apply_router_weight_on_input
assert activation == MoEActivation.SILU
......@@ -342,10 +362,12 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
if is_mxfp8:
fp8_quant_type = Fp8QuantizationType.MxFp8
use_shuffled_weight = True
weight_layout = WeightLayout.MajorK
hidden_states_scale = a1q_scale
else:
fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
use_shuffled_weight = False
use_shuffled_weight = True
weight_layout = WeightLayout.BlockMajorK
hidden_states_scale = a1q_scale.t().contiguous()
return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
......@@ -367,6 +389,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
routed_scaling_factor=routed_scaling_factor,
routing_method_type=self.routing_method_type,
use_shuffled_weight=use_shuffled_weight,
weight_layout=weight_layout,
fp8_quantization_type=fp8_quant_type,
)
......
......@@ -305,6 +305,39 @@ def align_fp8_moe_weights_for_fi(
return padded_w13, padded_w2, padded_intermediate
def _shuffle_deepseek_fp8_moe_weights(
w13: torch.Tensor,
w2: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Preprocess DeepSeek FP8 block-scale weights for the FlashInfer TRT-LLM
kernel using the shuffle + BlockMajorK layout variant.
Returns 4D weight tensors in BlockMajorK layout
(E, K/block_k, Mn, block_k)
"""
from flashinfer import shuffle_matrix_a
from flashinfer.fused_moe import convert_to_block_layout
epilogue_tile_m = 64
block_k = 128
num_experts = w13.shape[0]
w13_shuffled: list[torch.Tensor] = []
w2_shuffled: list[torch.Tensor] = []
for i in range(num_experts):
t13 = shuffle_matrix_a(w13[i].view(torch.uint8), epilogue_tile_m)
t13 = convert_to_block_layout(t13, block_k)
w13_shuffled.append(t13)
t2 = shuffle_matrix_a(w2[i].view(torch.uint8), epilogue_tile_m)
t2 = convert_to_block_layout(t2, block_k)
w2_shuffled.append(t2)
w13_out = torch.stack(w13_shuffled).view(torch.float8_e4m3fn)
w2_out = torch.stack(w2_shuffled).view(torch.float8_e4m3fn)
return w13_out, w2_out
def _shuffle_mxfp8_moe_weights(
w13: torch.Tensor,
w2: torch.Tensor,
......@@ -405,6 +438,7 @@ def prepare_fp8_moe_layer_for_fi(
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
)
is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8
is_deepseek_fp8 = block_quant and not is_mxfp8
is_gated = layer.activation.is_gated
# MXFP8 TRT-LLM requires W31 swap + reorder + shuffle.
......@@ -447,6 +481,10 @@ def prepare_fp8_moe_layer_for_fi(
if block_quant:
w13_scale = swap_w13_to_w31(w13_scale)
# DeepSeekFp8 TRT-LLM: shuffle weights into BlockMajorK layout.
if is_deepseek_fp8 and is_trtllm:
w13, w2 = _shuffle_deepseek_fp8_moe_weights(w13, w2)
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
# and registration of alpha scales.
if is_trtllm and not block_quant:
......
......@@ -13,7 +13,7 @@ import vllm.envs as envs
from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
......@@ -168,14 +168,12 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
):
return False
if not isinstance(module.quant_method, FusedMoEModularMethod):
# modular kernels could invoke deep_gemm_moe_fp8
return True
moe_kernel = getattr(module.quant_method, "moe_kernel", None)
if moe_kernel is None:
return False
# Further check if the ModularKernel implementation uses the DeepGemmExperts
return isinstance(
module.quant_method.moe_kernel, (DeepGemmExperts, TritonOrDeepGemmExperts)
)
fused_experts = moe_kernel.impl.fused_experts
return isinstance(fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts))
FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment