"vscode:/vscode.git/clone" did not exist on "2faad08362ff50f254de27cb3c54272b9f3af4b8"
Unverified Commit 1af6f78a authored by Wei Zhao's avatar Wei Zhao Committed by GitHub
Browse files

[Perf] Change Trtllm fp8 MoE to use Shuffled Weights and BlockMajorK Layout (#38993)


Signed-off-by: default avatarwzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 228023b3
...@@ -112,6 +112,24 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): ...@@ -112,6 +112,24 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
] ]
return (weight_key, activation_key) in SUPPORTED_W_A return (weight_key, activation_key) in SUPPORTED_W_A
def moe_problem_size(
self,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""Override to handle 4D BlockMajorK weights (E, K/bk, Mn, bk)."""
if w1.dim() == 4:
# BlockMajorK: (E, K/bk, Mn, bk)
E = w1.shape[0]
N = w1.shape[2]
K = a1.size(-1)
M = a1.size(0) if a1.dim() == 2 else a1.size(1)
topk = topk_ids.size(1)
return E, M, N, K, topk
return super().moe_problem_size(a1, w1, w2, topk_ids)
def workspace_shapes( def workspace_shapes(
self, self,
M: int, M: int,
...@@ -152,7 +170,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): ...@@ -152,7 +170,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
): ):
import flashinfer import flashinfer
from flashinfer.fused_moe import Fp8QuantizationType from flashinfer.fused_moe import Fp8QuantizationType, WeightLayout
# Pack topk ids and weights into format expected by the kernel. # Pack topk ids and weights into format expected by the kernel.
packed_topk_ids = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights) packed_topk_ids = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
...@@ -170,10 +188,12 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): ...@@ -170,10 +188,12 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
if is_mxfp8: if is_mxfp8:
fp8_quant_type = Fp8QuantizationType.MxFp8 fp8_quant_type = Fp8QuantizationType.MxFp8
use_shuffled_weight = True use_shuffled_weight = True
weight_layout = WeightLayout.MajorK
hidden_states_scale = a1q_scale hidden_states_scale = a1q_scale
else: else:
fp8_quant_type = Fp8QuantizationType.DeepSeekFp8 fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
use_shuffled_weight = False use_shuffled_weight = True
weight_layout = WeightLayout.BlockMajorK
hidden_states_scale = a1q_scale.t().contiguous() hidden_states_scale = a1q_scale.t().contiguous()
# `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the # `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the
...@@ -199,7 +219,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): ...@@ -199,7 +219,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
routed_scaling_factor=None, routed_scaling_factor=None,
routing_method_type=1, routing_method_type=1,
use_shuffled_weight=use_shuffled_weight, use_shuffled_weight=use_shuffled_weight,
weight_layout=0, weight_layout=weight_layout,
fp8_quantization_type=fp8_quant_type, fp8_quantization_type=fp8_quant_type,
# output=output, # output=output,
) )
...@@ -322,7 +342,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit ...@@ -322,7 +342,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
topk_group: int | None = None, topk_group: int | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
import flashinfer import flashinfer
from flashinfer.fused_moe import Fp8QuantizationType from flashinfer.fused_moe import Fp8QuantizationType, WeightLayout
assert not apply_router_weight_on_input assert not apply_router_weight_on_input
assert activation == MoEActivation.SILU assert activation == MoEActivation.SILU
...@@ -342,10 +362,12 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit ...@@ -342,10 +362,12 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
if is_mxfp8: if is_mxfp8:
fp8_quant_type = Fp8QuantizationType.MxFp8 fp8_quant_type = Fp8QuantizationType.MxFp8
use_shuffled_weight = True use_shuffled_weight = True
weight_layout = WeightLayout.MajorK
hidden_states_scale = a1q_scale hidden_states_scale = a1q_scale
else: else:
fp8_quant_type = Fp8QuantizationType.DeepSeekFp8 fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
use_shuffled_weight = False use_shuffled_weight = True
weight_layout = WeightLayout.BlockMajorK
hidden_states_scale = a1q_scale.t().contiguous() hidden_states_scale = a1q_scale.t().contiguous()
return flashinfer.fused_moe.trtllm_fp8_block_scale_moe( return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
...@@ -367,6 +389,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit ...@@ -367,6 +389,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
routing_method_type=self.routing_method_type, routing_method_type=self.routing_method_type,
use_shuffled_weight=use_shuffled_weight, use_shuffled_weight=use_shuffled_weight,
weight_layout=weight_layout,
fp8_quantization_type=fp8_quant_type, fp8_quantization_type=fp8_quant_type,
) )
......
...@@ -305,6 +305,39 @@ def align_fp8_moe_weights_for_fi( ...@@ -305,6 +305,39 @@ def align_fp8_moe_weights_for_fi(
return padded_w13, padded_w2, padded_intermediate return padded_w13, padded_w2, padded_intermediate
def _shuffle_deepseek_fp8_moe_weights(
w13: torch.Tensor,
w2: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Preprocess DeepSeek FP8 block-scale weights for the FlashInfer TRT-LLM
kernel using the shuffle + BlockMajorK layout variant.
Returns 4D weight tensors in BlockMajorK layout
(E, K/block_k, Mn, block_k)
"""
from flashinfer import shuffle_matrix_a
from flashinfer.fused_moe import convert_to_block_layout
epilogue_tile_m = 64
block_k = 128
num_experts = w13.shape[0]
w13_shuffled: list[torch.Tensor] = []
w2_shuffled: list[torch.Tensor] = []
for i in range(num_experts):
t13 = shuffle_matrix_a(w13[i].view(torch.uint8), epilogue_tile_m)
t13 = convert_to_block_layout(t13, block_k)
w13_shuffled.append(t13)
t2 = shuffle_matrix_a(w2[i].view(torch.uint8), epilogue_tile_m)
t2 = convert_to_block_layout(t2, block_k)
w2_shuffled.append(t2)
w13_out = torch.stack(w13_shuffled).view(torch.float8_e4m3fn)
w2_out = torch.stack(w2_shuffled).view(torch.float8_e4m3fn)
return w13_out, w2_out
def _shuffle_mxfp8_moe_weights( def _shuffle_mxfp8_moe_weights(
w13: torch.Tensor, w13: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -405,6 +438,7 @@ def prepare_fp8_moe_layer_for_fi( ...@@ -405,6 +438,7 @@ def prepare_fp8_moe_layer_for_fi(
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
) )
is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8 is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8
is_deepseek_fp8 = block_quant and not is_mxfp8
is_gated = layer.activation.is_gated is_gated = layer.activation.is_gated
# MXFP8 TRT-LLM requires W31 swap + reorder + shuffle. # MXFP8 TRT-LLM requires W31 swap + reorder + shuffle.
...@@ -447,6 +481,10 @@ def prepare_fp8_moe_layer_for_fi( ...@@ -447,6 +481,10 @@ def prepare_fp8_moe_layer_for_fi(
if block_quant: if block_quant:
w13_scale = swap_w13_to_w31(w13_scale) w13_scale = swap_w13_to_w31(w13_scale)
# DeepSeekFp8 TRT-LLM: shuffle weights into BlockMajorK layout.
if is_deepseek_fp8 and is_trtllm:
w13, w2 = _shuffle_deepseek_fp8_moe_weights(w13, w2)
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle # FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
# and registration of alpha scales. # and registration of alpha scales.
if is_trtllm and not block_quant: if is_trtllm and not block_quant:
......
...@@ -13,7 +13,7 @@ import vllm.envs as envs ...@@ -13,7 +13,7 @@ import vllm.envs as envs
from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
) )
...@@ -168,14 +168,12 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: ...@@ -168,14 +168,12 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
): ):
return False return False
if not isinstance(module.quant_method, FusedMoEModularMethod): moe_kernel = getattr(module.quant_method, "moe_kernel", None)
# modular kernels could invoke deep_gemm_moe_fp8 if moe_kernel is None:
return True return False
# Further check if the ModularKernel implementation uses the DeepGemmExperts fused_experts = moe_kernel.impl.fused_experts
return isinstance( return isinstance(fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts))
module.quant_method.moe_kernel, (DeepGemmExperts, TritonOrDeepGemmExperts)
)
FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set() FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment