Commit feadffce authored by laibao
Browse files

[BUGFIX] 修复 compressed tensors FP8 MoE 路径未透传 i_q/i_s 参数的问题

parent 7bf17aa2
......@@ -1111,8 +1111,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert self.kernel is not None
......@@ -1132,6 +1135,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
quant_config=self.moe_quant_config,
use_fused_gate=use_fused_gate,
use_nn_moe=False,
i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
......
......@@ -199,6 +199,8 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None,
):
return fused_experts_impl_fp8_marlin(
......@@ -220,7 +222,9 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
routed_scaling_factor=routed_scaling_factor,
i_q=i_q,
i_s=i_s)
def apply(
self,
......@@ -243,6 +247,8 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
......@@ -259,7 +265,10 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output, )
shared_output=shared_output,
i_q=i_q,
i_s=i_s,
)
class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
def __init__(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment