Unverified commit 5c7aa009 authored by fzyzcjy, committed by GitHub

Fix EPLB algorithm failing to run when using 3 nodes for prefill (#6629)

parent fe386aca
 # This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package
-from typing import Literal, Tuple
+from typing import Literal, Optional, Tuple
 import torch
@@ -257,11 +257,15 @@ def rebalance_experts(
     tokens_per_expert: torch.Tensor,
     num_physical_experts: int,
     num_local_physical_experts: int,
-    num_groups: int,
+    num_groups: Optional[int],
     num_nodes: int,
-    phase: Literal["prefill", "decode"],
+    phase: Literal["prefill", "decode", "null"],
 ):
-    if phase == "prefill":
+    if (
+        (phase == "prefill")
+        and (num_groups is not None)
+        and (num_groups % num_nodes == 0)
+    ):
         return prefill_rebalance_experts(
             tokens_per_expert=tokens_per_expert,
             num_physical_experts=num_physical_experts,
@@ -269,10 +273,8 @@ def rebalance_experts(
             num_groups=num_groups,
             num_nodes=num_nodes,
         )
-    if phase == "decode":
-        return decode_rebalance_experts(
-            tokens_per_expert=tokens_per_expert,
-            num_physical_experts=num_physical_experts,
-            num_local_physical_experts=num_local_physical_experts,
-        )
-    raise NotImplementedError
+    return decode_rebalance_experts(
+        tokens_per_expert=tokens_per_expert,
+        num_physical_experts=num_physical_experts,
+        num_local_physical_experts=num_local_physical_experts,
+    )
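
The core of the fix is the new dispatch condition above: the hierarchical (prefill) balancer is only used when the phase is "prefill" and the expert groups divide evenly across the nodes; every other case now falls through to the global (decode) balancer instead of raising NotImplementedError or failing inside the hierarchical path. A minimal sketch of that rule follows, using a hypothetical pick_policy helper and an assumed DeepSeek-style count of 8 expert groups (neither the helper name nor the number comes from this diff):

from typing import Literal, Optional


def pick_policy(
    phase: Literal["prefill", "decode", "null"],
    num_groups: Optional[int],
    num_nodes: int,
) -> str:
    # Hypothetical helper mirroring the new condition in rebalance_experts:
    # hierarchical balancing needs the expert groups to split evenly across
    # nodes; every other case uses the global balancer.
    if (
        (phase == "prefill")
        and (num_groups is not None)
        and (num_groups % num_nodes == 0)
    ):
        return "hierarchical"
    return "global"


# Assumed 8 expert groups (DeepSeek-style). With 3 prefill nodes, 8 % 3 != 0,
# so the hierarchical path is skipped rather than attempted and failing.
assert pick_policy("prefill", num_groups=8, num_nodes=3) == "global"
assert pick_policy("prefill", num_groups=8, num_nodes=4) == "hierarchical"
# "null" mode and a missing group count also fall back to the global balancer.
assert pick_policy("null", num_groups=8, num_nodes=2) == "global"
assert pick_policy("decode", num_groups=None, num_nodes=2) == "global"
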
@@ -135,10 +135,6 @@ class ExpertLocationMetadata:
         model_config_for_expert_location = common["model_config_for_expert_location"]
         num_physical_experts = common["num_physical_experts"]
-        phase = server_args.disaggregation_mode
-        if phase == "null" or model_config_for_expert_location.num_groups is None:
-            phase = "decode"
         physical_to_logical_map, logical_to_all_physical_map, expert_count = (
             deepseek_eplb.rebalance_experts(
                 tokens_per_expert=logical_count,
@@ -146,7 +142,7 @@ class ExpertLocationMetadata:
                 num_local_physical_experts=num_physical_experts // common["ep_size"],
                 num_groups=model_config_for_expert_location.num_groups,
                 num_nodes=server_args.nnodes,
-                phase=phase,
+                phase=server_args.disaggregation_mode,
             )
         )
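
The second hunk removes the phase normalization from ExpertLocationMetadata, since rebalance_experts now accepts the raw disaggregation mode (including "null") and a missing num_groups directly. A before/after sketch of the caller-side logic, using hypothetical dataclass stand-ins for server_args and the expert-location config (not the real SGLang types):

from dataclasses import dataclass
from typing import Optional


@dataclass
class ServerArgsStub:
    disaggregation_mode: str  # "prefill", "decode", or "null"
    nnodes: int


@dataclass
class ExpertLocationConfigStub:
    num_groups: Optional[int]


def phase_before(args: ServerArgsStub, cfg: ExpertLocationConfigStub) -> str:
    # Before this commit the caller coerced unsupported cases to "decode".
    phase = args.disaggregation_mode
    if phase == "null" or cfg.num_groups is None:
        phase = "decode"
    return phase


def phase_after(args: ServerArgsStub) -> str:
    # After this commit the raw mode is passed through; rebalance_experts
    # decides internally whether the hierarchical prefill path applies.
    return args.disaggregation_mode


args = ServerArgsStub(disaggregation_mode="null", nnodes=3)
cfg = ExpertLocationConfigStub(num_groups=None)
assert phase_before(args, cfg) == "decode"
assert phase_after(args) == "null"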