Unverified Commit 6d0646da authored by Kaixi Hou, committed by GitHub

[NVIDIA] Fix breakage of using trtllm-gen fp8 moe (#8773)

parent 02bc1c7d
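
For context, the hunks below fold the trtllm-gen fp8 path back into `FlashInferFusedMoE` and delete the dedicated `FlashInferEPMoE` class. The resulting backend selection in `get_moe_impl_class` can be summarized by the condensed sketch below; it only restates the branches visible in this diff (the collapsed branches are elided), so treat it as illustrative rather than the full function.

```python
# Condensed sketch of the dispatch order after this commit; only branches that
# appear in the diff below are shown, collapsed ones are elided.
def get_moe_impl_class():
    if global_server_args_dict["moe_a2a_backend"].is_deepep():
        return DeepEPMoE
    # ... elided branches ...
    if should_use_flashinfer_trtllm_moe():
        # The trtllm-gen fp8 moe is now handled by FlashInferFusedMoE for both
        # the expert-parallel and the plain fused path.
        return FlashInferFusedMoE
    if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
        return FusedMoE
    if get_moe_expert_parallel_world_size() > 1:
        return EPMoE
    return FusedMoE
```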
@@ -673,66 +673,6 @@ class DeepEPMoE(EPMoE):
         return down_output
 
 
-class FlashInferEPMoE(EPMoE):
-    def __init__(self, *args, **kwargs):
-        renormalize = kwargs.pop("renormalize", True)
-        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
-        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
-        num_expert_group = kwargs.pop("num_expert_group", None)
-        topk_group = kwargs.pop("topk_group", None)
-        correction_bias = kwargs.pop("correction_bias", None)
-        super().__init__(*args, **kwargs)
-        self.renormalize = renormalize
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
-
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert self.use_flashinfer_trtllm_moe
-        assert (
-            self.activation == "silu"
-        ), "Only silu is supported for flashinfer blockscale fp8 moe"
-        assert (
-            self.renormalize
-        ), "Renormalize is required for flashinfer blockscale fp8 moe"
-        assert (
-            self.num_fused_shared_experts == 0
-        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
-        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
-        # NOTE: scales of hidden states have to be transposed!
-        a_sf_t = a_sf.t().contiguous()
-        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
-
-        return trtllm_fp8_block_scale_moe(
-            routing_logits=router_logits.to(torch.float32),
-            routing_bias=self.correction_bias.to(hidden_states.dtype),
-            hidden_states=a_q,
-            hidden_states_scale=a_sf_t,
-            gemm1_weights=self.w13_weight,
-            gemm1_weights_scale=self.w13_weight_scale_inv,
-            gemm2_weights=self.w2_weight,
-            gemm2_weights_scale=self.w2_weight_scale_inv,
-            num_experts=self.num_experts,
-            top_k=self.top_k,
-            n_group=self.num_expert_group,
-            topk_group=self.topk_group,
-            intermediate_size=self.w2_weight.shape[2],
-            local_expert_offset=self.start_expert_id,
-            local_num_experts=self.num_local_experts,
-            routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=get_tile_tokens_dim(
-                hidden_states.shape[0], self.top_k, self.num_experts
-            ),
-            routing_method_type=2,  # DeepSeek-styled routing method
-            use_shuffled_weight=False,
-        )
-
-
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
@@ -752,8 +692,10 @@ def get_moe_impl_class():
     except:
         pass
+    if should_use_flashinfer_trtllm_moe():
+        return FlashInferFusedMoE
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         return FusedMoE
     if get_moe_expert_parallel_world_size() > 1:
-        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
-    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
+        return EPMoE
+    return FusedMoE
@@ -763,8 +763,13 @@ class FlashInferFusedMoE(FusedMoE):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
 
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
+        assert self.use_flashinfer_trtllm_moe
+        assert (
+            self.activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
         assert self.quant_method is not None
         assert (
             self.renormalize
@@ -772,6 +777,14 @@ class FlashInferFusedMoE(FusedMoE):
         assert (
             self.num_fused_shared_experts == 0
         ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        # TRTLLM mode expects (TopK_config, router_logits) tuple
+        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+            raise ValueError(
+                f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+            )
+        _, router_logits = topk_output
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply_with_router_logits(
             layer=self,
......
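
Note on the interface change: `FlashInferFusedMoE.forward` now receives a `(TopK_config, router_logits)` tuple instead of raw router logits and unpacks the logits itself. A minimal caller-side sketch is shown below; the `run_moe` helper and `topk_config` object are hypothetical, only the 2-tuple shape comes from the hunk above.

```python
import torch

def run_moe(moe_layer, hidden_states: torch.Tensor,
            router_logits: torch.Tensor, topk_config):
    # Old call style (pre-commit): moe_layer(hidden_states, router_logits)
    # New call style: pass a 2-tuple; forward() validates it and extracts the
    # router logits before invoking apply_with_router_logits.
    return moe_layer(hidden_states, (topk_config, router_logits))
```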