Unverified Commit 3c2c9f6c authored by Jiaqi Gu, committed by GitHub

[Bug] Fix input arguments of flashinfer_trtllm_moe (#9317)

parent a31ea448
@@ -932,11 +932,11 @@ class FlashInferFusedMoE(FusedMoE):
     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         assert self.use_flashinfer_trtllm_moe
         assert (
-            self.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
         assert self.quant_method is not None
         assert (
-            self.renormalize
+            topk_output.topk_config.renormalize
         ), "Renormalize is required for flashinfer blockscale fp8 moe"
         assert (
             self.num_fused_shared_experts == 0
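For context on the hunk above: the guards in forward() now read the activation from the runner config and the renormalize flag from the TopKOutput argument, instead of from attributes stored on the layer itself. A minimal stand-alone sketch of that pattern, using hypothetical stand-in dataclasses rather than the real sglang types:

from dataclasses import dataclass


@dataclass
class _RunnerConfig:  # stand-in for MoeRunnerConfig (illustrative only)
    activation: str = "silu"


@dataclass
class _TopKConfig:  # stand-in for TopKConfig (illustrative only)
    renormalize: bool = True


@dataclass
class _TopKOutput:  # stand-in for TopKOutput (illustrative only)
    topk_config: _TopKConfig


def check_inputs(runner_config: _RunnerConfig, topk_output: _TopKOutput) -> None:
    # Same conditions the updated forward() asserts on.
    assert runner_config.activation == "silu", "Only silu is supported"
    assert topk_output.topk_config.renormalize, "Renormalize is required"


check_inputs(_RunnerConfig(), _TopKOutput(_TopKConfig()))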
@@ -85,8 +85,8 @@ if _is_npu:
 class TopKConfig:
     top_k: int
     use_grouped_topk: bool = False
-    topk_group: int = 0
-    num_expert_group: int = 0
+    topk_group: Optional[int] = None
+    num_expert_group: Optional[int] = None
     renormalize: bool = True
     num_fused_shared_experts: int = 0
     custom_routing_function: Optional[Callable] = None
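The TopKConfig change above relaxes the group arguments from int defaults of 0 to Optional defaults of None. A small sketch (illustrative dataclass, not the real one) of how callers can now distinguish "not set" from a concrete group size:

from dataclasses import dataclass
from typing import Optional


@dataclass
class _TopKConfigSketch:
    top_k: int
    use_grouped_topk: bool = False
    topk_group: Optional[int] = None        # was: int = 0
    num_expert_group: Optional[int] = None  # was: int = 0


# Non-grouped routing can leave the group arguments unset...
plain = _TopKConfigSketch(top_k=8)
assert plain.topk_group is None and plain.num_expert_group is None

# ...while grouped routing still passes concrete values.
grouped = _TopKConfigSketch(
    top_k=8, use_grouped_topk=True, topk_group=4, num_expert_group=8
)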
@@ -189,8 +189,8 @@ class TopK(CustomOp):
         top_k: int,
         *,
         use_grouped_topk: bool = False,
-        topk_group: int = 0,
-        num_expert_group: int = 0,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         renormalize: bool = True,
         num_fused_shared_experts: int = 0,
         custom_routing_function: Optional[Callable] = None,
@@ -427,8 +427,8 @@ def grouped_topk_gpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -492,8 +492,8 @@ def grouped_topk_cpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -522,8 +522,8 @@ def biased_grouped_topk_impl(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -615,8 +615,8 @@ def biased_grouped_topk_gpu(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -690,8 +690,8 @@ def biased_grouped_topk_cpu(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     compiled: bool = True,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
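All of the grouped top-k entry points above now accept Optional group arguments, so downstream code that needs concrete values has to check for None explicitly. A hypothetical helper sketching that check (the real change uses an assert in Fp8MoEMethod, shown further down):

from typing import Optional


def _resolve_groups(
    num_expert_group: Optional[int], topk_group: Optional[int]
) -> tuple[int, int]:
    # Illustrative guard mirroring the assert added in Fp8MoEMethod below.
    if num_expert_group is None or topk_group is None:
        raise ValueError(
            "grouped top-k requires num_expert_group and topk_group to be set"
        )
    return num_expert_group, topk_group


print(_resolve_groups(8, 4))  # -> (8, 4)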
@@ -445,7 +445,6 @@ class Fp8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         if self.use_marlin:
             return apply_fp8_marlin_linear(
                 input=x,
@@ -1087,7 +1086,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         topk_output: TopKOutput,
         moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-
         activation = moe_runner_config.activation
         routed_scaling_factor = moe_runner_config.routed_scaling_factor
@@ -1105,9 +1103,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         # NOTE: scales of hidden states have to be transposed!
         a_sf_t = a_sf.t().contiguous()

+        assert (
+            topk_config.num_expert_group is not None
+            and topk_config.topk_group is not None
+        ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"
+        if topk_config.correction_bias is not None:
+            correction_bias = topk_config.correction_bias.to(x.dtype)
+        else:
+            correction_bias = None
+
         return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
-            routing_bias=layer.correction_bias.to(x.dtype),
+            routing_bias=correction_bias,
             hidden_states=a_q,
             hidden_states_scale=a_sf_t,
             gemm1_weights=layer.w13_weight,
@@ -1121,9 +1128,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             intermediate_size=layer.w2_weight.shape[2],
             local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
             local_num_experts=layer.num_local_experts,
-            routed_scaling_factor=routed_scaling_factor,
+            routed_scaling_factor=(
+                routed_scaling_factor if routed_scaling_factor is not None else 1.0
+            ),
             tile_tokens_dim=get_tile_tokens_dim(
-                x.shape[0], layer.top_k, layer.num_experts
+                x.shape[0], topk_config.top_k, layer.num_experts
             ),
             routing_method_type=2,  # DeepSeek-styled routing method
             use_shuffled_weight=False,
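The last two hunks normalize two previously unguarded arguments before calling trtllm_fp8_block_scale_moe: the routing bias may be absent, and routed_scaling_factor may be None. A self-contained sketch of the same normalization, with hypothetical inputs (the real code reads them from topk_config and moe_runner_config):

from typing import Optional

import torch


def _normalize_kernel_args(
    correction_bias: Optional[torch.Tensor],
    routed_scaling_factor: Optional[float],
    dtype: torch.dtype,
) -> tuple[Optional[torch.Tensor], float]:
    # Bias is cast to the activation dtype only when it exists; otherwise the
    # kernel receives None as routing_bias.
    bias = correction_bias.to(dtype) if correction_bias is not None else None
    # The kernel expects a concrete scaling factor, so None falls back to 1.0.
    scale = routed_scaling_factor if routed_scaling_factor is not None else 1.0
    return bias, scale


bias, scale = _normalize_kernel_args(None, None, torch.bfloat16)
assert bias is None and scale == 1.0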