Unverified Commit 499f5e62 authored by Cheng Wan, committed by GitHub

Fix one missing arg in DeepEP (#6878)

parent 81964328
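
Context: DeepEPMoE.__init__ forwarded its arguments to EPMoE.__init__ positionally. After num_fused_shared_experts was inserted into EPMoE's parameter list between num_expert_group and topk_group, DeepEPMoE still omitted it, so every subsequent positional argument shifted one slot: topk_group landed in num_fused_shared_experts, quant_config in topk_group, and so on. The fix adds the missing parameter to DeepEPMoE and converts the super().__init__ call to keyword arguments, which is insensitive to parameter order. Below is a minimal sketch of this failure mode in plain Python; the class and parameter names are hypothetical, not sglang code.

# Sketch of the bug class, assuming nothing beyond plain Python.
# Base/BrokenChild/FixedChild and new_arg are hypothetical names.

class Base:
    # A new parameter (new_arg) was inserted into the middle of the signature.
    def __init__(self, a, b, new_arg=0, c=None):
        self.a, self.b, self.new_arg, self.c = a, b, new_arg, c

class BrokenChild(Base):
    def __init__(self, a, b, c):
        # Positional forwarding: c silently lands in new_arg; self.c stays None.
        super().__init__(a, b, c)

class FixedChild(Base):
    def __init__(self, a, b, new_arg=0, c=None):
        # Keyword forwarding survives insertions into the base signature.
        super().__init__(a=a, b=b, new_arg=new_arg, c=c)

assert BrokenChild(1, 2, 3).c is None       # bug: 3 was swallowed by new_arg
assert BrokenChild(1, 2, 3).new_arg == 3
assert FixedChild(1, 2, c=3).c == 3         # fix: value reaches the right slot

The assert added to EPMoE below attacks the same problem from the other side: fused shared experts are unsupported on the EP path, so the constructor now fails fast on a nonzero num_fused_shared_experts instead of silently carrying it along.
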
@@ -180,6 +180,9 @@ class EPMoE(torch.nn.Module):
         self.layer_id = layer_id
         self.num_experts = num_experts
         assert self.num_experts % self.tp_size == 0
+        assert (
+            num_fused_shared_experts == 0
+        ), "num_fused_shared_experts is not supported in EP"
         self.num_experts_per_partition = self.num_experts // self.tp_size
         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
@@ -191,7 +194,6 @@ class EPMoE(torch.nn.Module):
         if self.use_grouped_topk:
             assert num_expert_group is not None and topk_group is not None
         self.num_expert_group = num_expert_group
-        self.num_fused_shared_experts = num_fused_shared_experts
         self.topk_group = topk_group
         self.correction_bias = correction_bias
         self.custom_routing_function = custom_routing_function
@@ -252,7 +254,6 @@ class EPMoE(torch.nn.Module):
             renormalize=self.renormalize,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
             routed_scaling_factor=self.routed_scaling_factor,
@@ -886,6 +887,7 @@ class DeepEPMoE(EPMoE):
         renormalize: bool = True,
         use_grouped_topk: bool = False,
         num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
         topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -897,23 +899,24 @@ class DeepEPMoE(EPMoE):
         deepep_mode: DeepEPMode = DeepEPMode.auto,
     ):
         super().__init__(
-            num_experts,
-            top_k,
-            hidden_size,
-            intermediate_size,
-            layer_id,
-            params_dtype,
-            renormalize,
-            use_grouped_topk,
-            num_expert_group,
-            topk_group,
-            quant_config,
-            tp_size,
-            prefix,
-            correction_bias,
-            custom_routing_function,
-            activation,
-            routed_scaling_factor,
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            layer_id=layer_id,
+            params_dtype=params_dtype,
+            renormalize=renormalize,
+            use_grouped_topk=use_grouped_topk,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            topk_group=topk_group,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            prefix=prefix,
+            correction_bias=correction_bias,
+            custom_routing_function=custom_routing_function,
+            activation=activation,
+            routed_scaling_factor=routed_scaling_factor,
         )
         self.deepep_mode = deepep_mode
         if self.deepep_mode.enable_low_latency():