Unverified commit 0ca18117, authored by fzyzcjy, committed by GitHub
Browse files

Support fake perfectly balanced EP dispatch algorithm (#6571)

parent 2c3a6fe1
......@@ -18,6 +18,7 @@ from typing import Callable, Optional
import torch
import torch.nn.functional as F
from sglang.srt.managers import expert_location_dispatch
from sglang.srt.managers.expert_distribution import (
ExpertDistributionRecorder,
get_global_expert_distribution_recorder,
......@@ -310,6 +311,15 @@ def select_experts(
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
router_logits, correction_bias = (
expert_location_dispatch.transform_select_experts_inputs(
router_logits=router_logits,
correction_bias=correction_bias,
info=expert_location_dispatch_info,
)
)
# DeepSeek V2/V3/R1 series models use grouped_top_k
if use_grouped_topk:
assert topk_group is not None
......
......@@ -55,6 +55,18 @@ class ExpertLocationDispatchInfo:
)
def transform_select_experts_inputs(
    router_logits: torch.Tensor,
    correction_bias: Optional[torch.Tensor],
    info: Optional[ExpertLocationDispatchInfo],
):
    """Optionally rewrite the routing inputs before expert selection.

    When ``info`` requests the ``"fake"`` dispatch algorithm, the router
    logits are replaced with i.i.d. normal noise and any correction bias is
    zeroed, so token-to-expert assignment becomes random (and thus roughly
    balanced) instead of model-driven. For every other case the inputs are
    returned untouched.

    Returns:
        A ``(router_logits, correction_bias)`` pair.
    """
    faking = info is not None and info.ep_dispatch_algorithm == "fake"
    if not faking:
        return router_logits, correction_bias
    noisy_logits = torch.randn_like(router_logits)
    zeroed_bias = (
        None if correction_bias is None else torch.zeros_like(correction_bias)
    )
    return noisy_logits, zeroed_bias
def topk_ids_logical_to_physical(
topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
......@@ -63,7 +75,7 @@ def topk_ids_logical_to_physical(
if info.ep_dispatch_algorithm == "static":
return _topk_ids_logical_to_physical_static(topk_ids, info)
if info.ep_dispatch_algorithm == "dynamic":
if info.ep_dispatch_algorithm in ["dynamic", "fake"]:
return _topk_ids_logical_to_physical_dynamic(topk_ids, info)
raise NotImplementedError
......
......@@ -172,7 +172,7 @@ class ServerArgs:
enable_deepep_moe: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
ep_num_redundant_experts: int = 0
ep_dispatch_algorithm: Optional[Literal["static", "dynamic"]] = None
ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
init_expert_location: str = "trivial"
enable_eplb: bool = False
eplb_rebalance_num_iterations: int = 1000
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment