Unverified Commit ccfe5c00 authored by fzyzcjy, committed by GitHub

Support redundant experts in expert parallel (#6461)

parent a071dc40
@@ -163,8 +163,7 @@ class ExpertLocationMetadata:
         num_physical_experts = (
             model_config_for_expert_location.num_logical_experts
-            # TODO pr-chain: enable this later
-            # + server_args.ep_num_redundant_experts
+            + server_args.ep_num_redundant_experts
         )
         ep_size = server_args.ep_size
         assert num_physical_experts % ep_size == 0
...
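For orientation, a rough sketch of the arithmetic this hunk enables (the concrete numbers are illustrative, not taken from the PR): redundant experts are extra physical copies of logical experts, so the physical expert count becomes the logical count plus --ep-num-redundant-experts, and the assert requires that total to split evenly across EP ranks.

    # Illustrative numbers only; a sketch of the counting logic above.
    num_logical_experts = 256          # routed experts from the model config
    ep_num_redundant_experts = 32      # extra physical replicas (new server arg)
    ep_size = 8                        # expert-parallel world size

    num_physical_experts = num_logical_experts + ep_num_redundant_experts
    assert num_physical_experts % ep_size == 0
    num_local_physical_experts = num_physical_experts // ep_size  # 36 slots per rank

Each rank then hosts num_local_physical_experts slots; as the PR title suggests, the extra slots exist so heavily loaded logical experts can be replicated and their traffic spread across ranks.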
@@ -90,6 +90,7 @@ global_server_args_dict = {
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "torchao_config": ServerArgs.torchao_config,
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
+    "ep_num_redundant_experts": ServerArgs.ep_num_redundant_experts,
 }
 logger = logging.getLogger(__name__)
...
@@ -206,6 +206,7 @@ class ModelRunner:
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "use_mla_backend": self.use_mla_backend,
                 "mm_attention_backend": server_args.mm_attention_backend,
+                "ep_num_redundant_experts": server_args.ep_num_redundant_experts,
             }
         )
...
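The two hunks above wire the new value through the usual two-step plumbing: the module-level global_server_args_dict carries the ServerArgs default, and ModelRunner overwrites it with the value actually passed at launch so model code can read it. A minimal sketch of that pattern with simplified stand-in classes (not the real SGLang modules):

    from dataclasses import dataclass

    @dataclass
    class ServerArgs:
        ep_num_redundant_experts: int = 0

    # Module-level defaults, visible to model code before any runner exists.
    global_server_args_dict = {
        "ep_num_redundant_experts": ServerArgs.ep_num_redundant_experts,
    }

    class ModelRunner:
        def __init__(self, server_args: ServerArgs):
            # The launch-time value replaces the dataclass default.
            global_server_args_dict.update(
                {"ep_num_redundant_experts": server_args.ep_num_redundant_experts}
            )

    ModelRunner(ServerArgs(ep_num_redundant_experts=32))
    assert global_server_args_dict["ep_num_redundant_experts"] == 32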
@@ -243,7 +243,9 @@ class DeepseekV2MoE(nn.Module):
         self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
         self.experts = get_moe_impl_class()(
-            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
+            num_experts=config.n_routed_experts
+            + self.n_share_experts_fusion
+            + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
...
@@ -285,7 +287,10 @@ class DeepseekV2MoE(nn.Module):
         if global_server_args_dict["enable_deepep_moe"]:
             # TODO: we will support tp < ep in the future
             self.ep_size = get_tensor_model_parallel_world_size()
-            self.num_experts = config.n_routed_experts
+            self.num_experts = (
+                config.n_routed_experts
+                + global_server_args_dict["ep_num_redundant_experts"]
+            )
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
             self.num_expert_group = config.n_group
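With redundant experts, the router still produces logical expert ids in [0, n_routed_experts), but a logical expert may now be backed by more than one physical slot. How SGLang resolves that mapping lives in the expert-location machinery (largely in later PRs of this chain); the snippet below is only a conceptual sketch with made-up names, not the actual dispatch code:

    import random

    def resolve_physical_expert(logical_id, logical_to_physical):
        # A logical expert may own several physical slots; picking one of
        # them (here uniformly at random) spreads its traffic across ranks.
        return random.choice(logical_to_physical[logical_id])

    # Logical expert 7 keeps its primary slot and gains one redundant replica.
    logical_to_physical = {7: [7, 260]}
    assert resolve_physical_expert(7, logical_to_physical) in (7, 260)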
...
@@ -299,7 +304,7 @@ class DeepseekV2MoE(nn.Module):
                 group=parallel_state.get_tp_group().device_group,
                 router_topk=self.top_k,
                 permute_fusion=True,
-                num_experts=config.n_routed_experts,
+                num_experts=self.num_experts,
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
...
@@ -170,6 +170,7 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
+    ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic"]] = None
     init_expert_location: str = "trivial"
     expert_distribution_recorder_mode: Optional[
...
@@ -1273,6 +1274,12 @@ class ServerArgs:
             default="auto",
             help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
         )
+        parser.add_argument(
+            "--ep-num-redundant-experts",
+            type=int,
+            default=ServerArgs.ep_num_redundant_experts,
+            help="Allocate this number of redundant experts in expert parallel.",
+        )
         parser.add_argument(
             "--ep-dispatch-algorithm",
             type=str,
...
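The new flag follows the same dash-to-underscore convention as the other EP options, so passing --ep-num-redundant-experts 32 on the launch command line (for example alongside --enable-deepep-moe; treat the surrounding flags as assumptions, only the redundant-experts flag is added here) lands in ServerArgs.ep_num_redundant_experts. A self-contained argparse sketch of just this flag:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ep-num-redundant-experts",
        type=int,
        default=0,
        help="Allocate this number of redundant experts in expert parallel.",
    )
    args = parser.parse_args(["--ep-num-redundant-experts", "32"])
    assert args.ep_num_redundant_experts == 32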