Unverified Commit 5c7aa009 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix EPLB algorithm fail to run when using 3 nodes for prefill (#6629)

parent fe386aca
# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package
from typing import Literal, Tuple
from typing import Literal, Optional, Tuple
import torch
......@@ -257,11 +257,15 @@ def rebalance_experts(
tokens_per_expert: torch.Tensor,
num_physical_experts: int,
num_local_physical_experts: int,
num_groups: int,
num_groups: Optional[int],
num_nodes: int,
phase: Literal["prefill", "decode"],
phase: Literal["prefill", "decode", "null"],
):
if phase == "prefill":
if (
(phase == "prefill")
and (num_groups is not None)
and (num_groups % num_nodes == 0)
):
return prefill_rebalance_experts(
tokens_per_expert=tokens_per_expert,
num_physical_experts=num_physical_experts,
......@@ -269,10 +273,8 @@ def rebalance_experts(
num_groups=num_groups,
num_nodes=num_nodes,
)
if phase == "decode":
return decode_rebalance_experts(
tokens_per_expert=tokens_per_expert,
num_physical_experts=num_physical_experts,
num_local_physical_experts=num_local_physical_experts,
)
raise NotImplementedError
return decode_rebalance_experts(
tokens_per_expert=tokens_per_expert,
num_physical_experts=num_physical_experts,
num_local_physical_experts=num_local_physical_experts,
)
......@@ -135,10 +135,6 @@ class ExpertLocationMetadata:
model_config_for_expert_location = common["model_config_for_expert_location"]
num_physical_experts = common["num_physical_experts"]
phase = server_args.disaggregation_mode
if phase == "null" or model_config_for_expert_location.num_groups is None:
phase = "decode"
physical_to_logical_map, logical_to_all_physical_map, expert_count = (
deepseek_eplb.rebalance_experts(
tokens_per_expert=logical_count,
......@@ -146,7 +142,7 @@ class ExpertLocationMetadata:
num_local_physical_experts=num_physical_experts // common["ep_size"],
num_groups=model_config_for_expert_location.num_groups,
num_nodes=server_args.nnodes,
phase=phase,
phase=server_args.disaggregation_mode,
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment