Unverified Commit 3c2c9f6c authored by Jiaqi Gu, committed by GitHub

[Bug] Fix input arguments of flashinfer_trtllm_moe (#9317)

parent a31ea448
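In short: the flashinfer TRT-LLM MoE path was still reading routing arguments from attributes on the layer itself (self.activation, self.renormalize, layer.top_k, layer.correction_bias), which the surrounding code no longer provides in that form; the fix reads them from moe_runner_config and the topk_config carried by TopKOutput instead. A minimal before/after sketch of the argument re-sourcing (the helper functions and stand-in objects here are hypothetical, used only for illustration, not the real sglang types):

```python
from typing import Optional
import torch

# Hypothetical helpers that only illustrate where each argument comes from.
def build_routing_args_old(layer, x: torch.Tensor) -> dict:
    # pre-fix: everything read off the layer itself
    return {
        "routing_bias": layer.correction_bias.to(x.dtype),
        "top_k": layer.top_k,
    }

def build_routing_args_new(topk_config, x: torch.Tensor) -> dict:
    # post-fix: routing options come from the TopKConfig carried by TopKOutput
    bias: Optional[torch.Tensor] = topk_config.correction_bias
    return {
        "routing_bias": bias.to(x.dtype) if bias is not None else None,
        "top_k": topk_config.top_k,
    }
```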
......@@ -932,11 +932,11 @@ class FlashInferFusedMoE(FusedMoE):
     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         assert self.use_flashinfer_trtllm_moe
         assert (
-            self.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
         assert self.quant_method is not None
         assert (
-            self.renormalize
+            topk_output.topk_config.renormalize
         ), "Renormalize is required for flashinfer blockscale fp8 moe"
         assert (
             self.num_fused_shared_experts == 0
......
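The assertions above now read through the refactored config objects rather than attributes on the layer. A rough sketch of the shapes involved (these dataclasses are simplified stand-ins inferred from the diff, not sglang's actual class definitions):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MoeRunnerConfigSketch:
    # carries runner-level options such as the activation function
    activation: str = "silu"
    routed_scaling_factor: Optional[float] = None

@dataclass
class TopKConfigSketch:
    # carries routing options; compare the TopKConfig hunk below
    top_k: int = 8
    renormalize: bool = True
    topk_group: Optional[int] = None
    num_expert_group: Optional[int] = None

@dataclass
class TopKOutputSketch:
    topk_config: TopKConfigSketch

# The forward() assertions boil down to checks like these:
moe_runner_config = MoeRunnerConfigSketch()
topk_output = TopKOutputSketch(TopKConfigSketch())
assert moe_runner_config.activation == "silu"
assert topk_output.topk_config.renormalize
```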
......@@ -85,8 +85,8 @@ if _is_npu:
 class TopKConfig:
     top_k: int
     use_grouped_topk: bool = False
-    topk_group: int = 0
-    num_expert_group: int = 0
+    topk_group: Optional[int] = None
+    num_expert_group: Optional[int] = None
     renormalize: bool = True
     num_fused_shared_experts: int = 0
     custom_routing_function: Optional[Callable] = None
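Switching the defaults from 0 to None lets callers distinguish "grouped top-k was never configured" from an explicit group count, which the kernel-facing code further down relies on when it asserts that both values are set. A small sketch of the pattern (the validation helper is hypothetical, added only to illustrate the distinction):

```python
from typing import Optional

def require_group_args(
    num_expert_group: Optional[int], topk_group: Optional[int]
) -> tuple[int, int]:
    # Hypothetical helper: with None defaults, "never configured" is detectable,
    # whereas a default of 0 would silently look like a (meaningless) real value.
    if num_expert_group is None or topk_group is None:
        raise ValueError("grouped top-k requires num_expert_group and topk_group to be set")
    return num_expert_group, topk_group

# require_group_args(None, 4) raises ValueError: the caller forgot to configure groups.
# require_group_args(8, 4) returns (8, 4).
```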
......@@ -189,8 +189,8 @@ class TopK(CustomOp):
         top_k: int,
         *,
         use_grouped_topk: bool = False,
-        topk_group: int = 0,
-        num_expert_group: int = 0,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         renormalize: bool = True,
         num_fused_shared_experts: int = 0,
         custom_routing_function: Optional[Callable] = None,
......@@ -427,8 +427,8 @@ def grouped_topk_gpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
......@@ -492,8 +492,8 @@ def grouped_topk_cpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
......@@ -522,8 +522,8 @@ def biased_grouped_topk_impl(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
......@@ -615,8 +615,8 @@ def biased_grouped_topk_gpu(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
......@@ -690,8 +690,8 @@ def biased_grouped_topk_cpu(
     correction_bias: torch.Tensor,
     topk: int,
     renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     compiled: bool = True,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
......
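All of the grouped/biased top-k variants above pick up the same Optional signature, so callers simply forward the values from TopKConfig and the kernels decide whether None is acceptable. A hedged sketch of that flow (the dispatcher below is hypothetical and stands in for whatever dispatch sglang actually uses; the fallback behavior is an assumption for illustration):

```python
from typing import Optional
import torch

def dispatch_grouped_topk_sketch(
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
):
    # Hypothetical dispatcher: when grouping was never configured (None), fall
    # back to plain softmax top-k; the grouped path is elided in this sketch.
    if num_expert_group is None or topk_group is None:
        weights, ids = torch.topk(torch.softmax(gating_output, dim=-1), k=topk, dim=-1)
        if renormalize:
            weights = weights / weights.sum(dim=-1, keepdim=True)
        return weights, ids
    raise NotImplementedError("grouped selection elided in this sketch")
```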
......@@ -445,7 +445,6 @@ class Fp8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         if self.use_marlin:
             return apply_fp8_marlin_linear(
                 input=x,
......@@ -1087,7 +1086,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         topk_output: TopKOutput,
         moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-
         activation = moe_runner_config.activation
         routed_scaling_factor = moe_runner_config.routed_scaling_factor
......@@ -1105,9 +1103,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             # NOTE: scales of hidden states have to be transposed!
             a_sf_t = a_sf.t().contiguous()
+            assert (
+                topk_config.num_expert_group is not None
+                and topk_config.topk_group is not None
+            ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"
+            if topk_config.correction_bias is not None:
+                correction_bias = topk_config.correction_bias.to(x.dtype)
+            else:
+                correction_bias = None
             return trtllm_fp8_block_scale_moe(
                 routing_logits=router_logits.to(torch.float32),
-                routing_bias=layer.correction_bias.to(x.dtype),
+                routing_bias=correction_bias,
                 hidden_states=a_q,
                 hidden_states_scale=a_sf_t,
                 gemm1_weights=layer.w13_weight,
......@@ -1121,9 +1128,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 intermediate_size=layer.w2_weight.shape[2],
                 local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
                 local_num_experts=layer.num_local_experts,
-                routed_scaling_factor=routed_scaling_factor,
+                routed_scaling_factor=(
+                    routed_scaling_factor if routed_scaling_factor is not None else 1.0
+                ),
                 tile_tokens_dim=get_tile_tokens_dim(
-                    x.shape[0], layer.top_k, layer.num_experts
+                    x.shape[0], topk_config.top_k, layer.num_experts
                 ),
                 routing_method_type=2,  # DeepSeek-styled routing method
                 use_shuffled_weight=False,
......
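The defensive handling at the call site (the bias converted only when present, routed_scaling_factor defaulted to 1.0, top_k taken from topk_config) can be exercised in isolation. A minimal sketch, assuming only the argument-normalization logic shown in the diff; the helper function name is hypothetical and the kernel call itself is left out:

```python
from typing import Optional
import torch

def normalize_trtllm_moe_args(
    correction_bias: Optional[torch.Tensor],
    routed_scaling_factor: Optional[float],
    dtype: torch.dtype,
) -> tuple[Optional[torch.Tensor], float]:
    # Mirrors the normalization added in the diff: the bias is cast only when it
    # exists, and a missing scaling factor falls back to a neutral 1.0.
    bias = correction_bias.to(dtype) if correction_bias is not None else None
    scale = routed_scaling_factor if routed_scaling_factor is not None else 1.0
    return bias, scale

# No bias and no explicit scaling factor.
bias, scale = normalize_trtllm_moe_args(None, None, torch.bfloat16)
assert bias is None and scale == 1.0

# A float32 bias tensor gets cast to the activation dtype.
bias, scale = normalize_trtllm_moe_args(torch.zeros(8), 2.5, torch.bfloat16)
assert bias.dtype == torch.bfloat16 and scale == 2.5
```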