Unverified Commit cba1cdbc authored by fzyzcjy, committed by GitHub

Support DeepSeek EPLB algorithm with static distributions (#6387)

parent c471d39e
# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a PyPI package.
from typing import Literal, Tuple

import torch

def pack_groups(tokens_per_group: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Greedily pack expert groups onto nodes, balancing per-node token load."""
    num_layers, num_groups = tokens_per_group.shape
    assert num_groups % num_nodes == 0
    groups_per_rank = num_groups // num_nodes

    indices = tokens_per_group.float().sort(-1, descending=True).indices.cpu()
    ret = torch.full_like(
        tokens_per_group, fill_value=-1, dtype=torch.int64, device="cpu"
    )
    for layer in range(num_layers):
        node_tokens = [0] * num_nodes
        node_groups = [0] * num_nodes
        for group in indices[layer]:

            def key_func(rank: int) -> Tuple[int, int]:
                # Prefer nodes that still have a free slot, then the least-loaded one.
                if node_groups[rank] >= groups_per_rank:
                    return 1, 0
                else:
                    return 0, node_tokens[rank]

            rank = min(range(num_nodes), key=key_func)
            assert node_groups[rank] < groups_per_rank
            ret[layer, group] = rank * groups_per_rank + node_groups[rank]
            node_tokens[rank] += tokens_per_group[layer, group]
            node_groups[rank] += 1
    return ret
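
# A worked illustration (editor's addition, not part of the upstream EPLB file):
# with one layer, two nodes, and per-group token loads [9, 5, 3, 1], the greedy
# loop assigns the heaviest remaining group to the least-loaded node that still
# has a free slot (two slots per node here):
#
#   pack_groups(torch.tensor([[9.0, 5.0, 3.0, 1.0]]), num_nodes=2)
#   # group 0 (9 tokens) -> node 0 slot 0, group 1 (5) -> node 1 slot 0,
#   # group 2 (3) -> node 1 slot 1, group 3 (1) -> node 0 slot 1
#   # => tensor([[0, 2, 3, 1]])
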
def make_redundant_experts_chunkwise(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_physical_experts_per_chunk: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Greedily place redundant experts, chunk by chunk.

    Returns (physical_to_logical_map [layers, num_physical_experts],
    logical_to_physical_map [layers, num_logical_experts, num_redundancy_experts + 1],
    logical_count [layers, num_logical_experts]).
    """
    num_steps, num_moe_layers, num_logical_experts = tokens_per_expert.shape
    num_redundancy_experts = num_physical_experts - num_logical_experts

    physical_to_logical_map = torch.empty(
        num_moe_layers,
        num_physical_experts,
        dtype=torch.int,
        device=tokens_per_expert.device,
    )
    logical_to_physical_map = torch.full(
        (num_moe_layers, num_logical_experts, num_redundancy_experts + 1),
        -1,
        dtype=torch.int,
        device=tokens_per_expert.device,
    )
    logical_count = torch.ones(
        num_moe_layers,
        num_logical_experts,
        dtype=torch.int,
        device=tokens_per_expert.device,
    )

    assert num_physical_experts % num_physical_experts_per_chunk == 0
    num_chunks = num_physical_experts // num_physical_experts_per_chunk
    assert num_logical_experts % num_chunks == 0
    num_logical_experts_per_group = num_logical_experts // num_chunks
    assert num_redundancy_experts % num_chunks == 0
    num_redundancy_experts_per_group = num_redundancy_experts // num_chunks

    arange_num_moe_layers_num_groups = torch.arange(
        num_moe_layers * num_chunks, dtype=torch.int, device=tokens_per_expert.device
    )
    arange_num_logical_experts = torch.arange(
        num_logical_experts, dtype=torch.int, device=tokens_per_expert.device
    )
    arange_num_logical_experts_per_group = torch.arange(
        num_logical_experts_per_group, dtype=torch.int, device=tokens_per_expert.device
    )
    arange_num_groups = torch.arange(
        num_chunks, dtype=torch.int, device=tokens_per_expert.device
    )

    # The first num_logical_experts_per_group slots of each chunk hold one copy
    # of every logical expert belonging to that chunk.
    physical_to_logical_map.view(
        num_moe_layers, num_chunks, num_physical_experts_per_chunk
    )[:, :, :num_logical_experts_per_group] = arange_num_logical_experts.view(
        num_chunks, num_logical_experts_per_group
    )
    logical_to_physical_map[:, :, 0] = (
        arange_num_logical_experts_per_group.expand(
            num_chunks, num_logical_experts_per_group
        )
        + arange_num_groups[:, None] * num_physical_experts_per_chunk
    ).view(num_logical_experts)

    tokens_per_expert_all_diff = tokens_per_expert + arange_num_logical_experts * 1e-4
    for i in range(num_redundancy_experts_per_group):
        score = (
            tokens_per_expert_all_diff / logical_count
        )  # NOTE: Values in score must be different from each other
        score1 = tokens_per_expert / (logical_count + 1)
        score = score.view(
            num_steps, num_moe_layers, num_chunks, num_logical_experts_per_group
        )
        score1 = score1.view_as(score)
        values, indices = score.max(-1, keepdim=True)
        values = values.expand_as(score).contiguous()
        score.scatter_(-1, indices, score1.gather(-1, indices))
        values.scatter_(-1, indices, score.max(-1, keepdim=True).values)
        # Pick, per chunk, the expert whose duplication most reduces the
        # bottleneck (highest per-replica) load, summed over steps.
        redundancy_indices = values.sum(0).argmin(-1)
        physical_to_logical_map.view(
            num_moe_layers, num_chunks, num_physical_experts_per_chunk
        )[:, :, num_logical_experts_per_group + i] = (
            redundancy_indices + arange_num_groups * num_logical_experts_per_group
        )
        redundancy_count = (
            logical_count.view(
                num_moe_layers * num_chunks, num_logical_experts_per_group
            )
            .gather(-1, redundancy_indices.view(num_moe_layers * num_chunks, 1))
            .squeeze(1)
        )
        physical_redundancy_indices = (
            (
                arange_num_groups * num_physical_experts_per_chunk
                + num_logical_experts_per_group
                + i
            )
            .expand(num_moe_layers, num_chunks)
            .flatten()
        )
        logical_to_physical_map.view(
            num_moe_layers * num_chunks,
            num_logical_experts_per_group,
            num_redundancy_experts + 1,
        )[
            arange_num_moe_layers_num_groups,
            redundancy_indices.view(num_moe_layers * num_chunks),
            redundancy_count,
        ] = physical_redundancy_indices
        logical_count.view(num_moe_layers * num_chunks, num_logical_experts_per_group)[
            arange_num_moe_layers_num_groups,
            redundancy_indices.view(num_moe_layers * num_chunks),
        ] += 1

    if num_local_physical_experts > 1:
        # Load-balancing between GPUs
        physical_to_logical_map_int64 = physical_to_logical_map.to(torch.int64)
        counts = logical_count.gather(-1, physical_to_logical_map_int64)
        score = tokens_per_expert.sum(0).gather(-1, physical_to_logical_map_int64)
        score = score / counts
        score = score.view(num_moe_layers, num_chunks, num_physical_experts_per_chunk)
        indices = score.argsort(-1, descending=True)
        indices += torch.arange(
            0,
            num_physical_experts,
            num_physical_experts_per_chunk,
            dtype=indices.dtype,
            device=indices.device,
        )[None, :, None]
        assert num_physical_experts_per_chunk % num_local_physical_experts == 0
        num_local_groups = num_physical_experts_per_chunk // num_local_physical_experts
        indices = indices.view(
            num_moe_layers, num_chunks, num_local_physical_experts, num_local_groups
        )
        # Deal the load-sorted slots across GPUs in snake order so each GPU
        # receives a mix of hot and cold experts.
        indices[:, :, 1::2, :] = indices[:, :, 1::2, :].flip(-1)
        indices = indices.transpose(2, 3)
        indices = indices.reshape(num_moe_layers, num_physical_experts)
        physical_to_logical_map = physical_to_logical_map.gather(-1, indices)
        # Temporarily overwrite the -1 padding so the gather stays in range,
        # then restore it afterwards.
        mask = logical_to_physical_map == -1
        logical_to_physical_map[mask] = 0
        logical_to_physical_map = (
            indices.argsort(-1)
            .gather(
                -1, logical_to_physical_map.view(num_moe_layers, -1).to(torch.int64)
            )
            .view_as(logical_to_physical_map)
            .to(torch.int)
        )
        logical_to_physical_map[mask] = -1

    return physical_to_logical_map, logical_to_physical_map, logical_count
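
# Shape illustration (editor's addition, hedged): with one chunk, 4 logical and
# 6 physical experts (2 redundant slots), the first 4 physical slots hold one
# copy of each logical expert and the loop fills slots 4-5 with duplicates of
# the hottest experts. If logical expert 2 dominates the load, after the first
# iteration one would see, e.g.:
#
#   physical_to_logical_map[0]    -> [0, 1, 2, 3, 2, ...]
#   logical_to_physical_map[0, 2] -> [2, 4, -1]   (two replicas)
#   logical_count[0, 2]           -> 2
#
# (shown before the optional between-GPU shuffle that runs when
# num_local_physical_experts > 1).
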
def decode_rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
):
    # Decode phase: treat all physical experts as a single chunk.
    return make_redundant_experts_chunkwise(
        tokens_per_expert,
        num_physical_experts,
        num_local_physical_experts,
        num_physical_experts,
    )

def prefill_rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_groups: int,
    num_nodes: int,
):
    tokens_per_expert = tokens_per_expert.float().cpu()
    num_steps, _, num_logical_experts = tokens_per_expert.shape
    assert num_logical_experts % num_groups == 0
    group_size = num_logical_experts // num_groups
    assert num_groups % num_nodes == 0, f"{num_groups=} {num_nodes=}"

    tokens_per_group = tokens_per_expert.sum(0).unflatten(-1, (num_groups, -1)).sum(-1)
    group_perm = pack_groups(
        tokens_per_group, num_nodes
    )  # [num_moe_layers, num_groups]; entry g is group g's packed position after node assignment
    # log2mlog [layers, #logexp] -> [layers, #logexp]
    log2mlog = (
        (group_perm * group_size).unsqueeze(-1)
        + torch.arange(group_size, dtype=torch.int64, device=group_perm.device)
    ).flatten(-2)
    # mlog2log [layers, #logexp] -> [layers, #logexp], inverse of log2mlog
    mlog2log = torch.empty_like(log2mlog)
    arange = torch.arange(
        num_logical_experts, dtype=torch.int64, device=mlog2log.device
    )
    mlog2log.scatter_(1, log2mlog, arange.expand(log2mlog.size(0), -1))
    # tokens_per_mlog[i][j][k] = tokens_per_expert[i][j][mlog2log[j][k]]
    tokens_per_mlog = tokens_per_expert.gather(
        2, mlog2log.unsqueeze(0).expand(num_steps, -1, -1)
    )
    phy2mlog, mlog2phy, mlog_count = make_redundant_experts_chunkwise(
        tokens_per_mlog,
        num_physical_experts,
        num_local_physical_experts,
        num_physical_experts // num_nodes,
    )
    # phy2log[i][j] = mlog2log[i][phy2mlog[i][j]]
    phy2log = mlog2log.gather(1, phy2mlog.to(torch.int64))
    # mlog2phy: [num_moe_layers, num_logical_experts, ...]
    # log2phy[i][j][k] = mlog2phy[i][log2mlog[i][j]][k]
    log2phy = mlog2phy.gather(
        1, log2mlog.unsqueeze(-1).expand(-1, -1, mlog2phy.size(-1)).to(torch.int64)
    )
    # log_count[i][j] = mlog_count[i][log2mlog[i][j]]
    log_count = mlog_count.gather(1, log2mlog)
    return phy2log, log2phy, log_count
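
# Permutation illustration (editor's addition, hedged): log2mlog relabels
# logical experts so that each group's experts land contiguously in the node
# order chosen by pack_groups, and mlog2log is its inverse (built via
# scatter_). E.g. with group_size=2 and group_perm[0] = [1, 0] (group 0 packed
# second, group 1 packed first):
#
#   log2mlog[0] = [2, 3, 0, 1]
#   mlog2log[0] = [2, 3, 0, 1]   # this particular permutation is self-inverse
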
def rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_groups: int,
    num_nodes: int,
    phase: Literal["prefill", "decode"],
):
    if phase == "prefill":
        return prefill_rebalance_experts(
            tokens_per_expert=tokens_per_expert,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
            num_groups=num_groups,
            num_nodes=num_nodes,
        )
    if phase == "decode":
        return decode_rebalance_experts(
            tokens_per_expert=tokens_per_expert,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
    raise NotImplementedError
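
# Usage sketch (editor's addition, hedged; toy numbers, not from the PR):
# decode-phase rebalancing for 1 measurement step, 2 MoE layers, and 4 logical
# experts spread over 8 physical slots (e.g. 2 GPUs x 4 local experts):
#
#   stats = torch.randint(1, 100, (1, 2, 4)).float()
#   phy2log, log2phy, log_count = rebalance_experts(
#       tokens_per_expert=stats,
#       num_physical_experts=8,
#       num_local_physical_experts=4,
#       num_groups=4,
#       num_nodes=1,
#       phase="decode",
#   )
#   # phy2log:   [2, 8]    physical slot -> logical expert id
#   # log2phy:   [2, 4, 5] logical expert -> physical replica slots (-1 padded)
#   # log_count: [2, 4]    number of replicas per logical expert
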
@@ -117,6 +117,41 @@ class ExpertLocationMetadata:
             logical_to_all_physical_map=logical_to_all_physical_map,
         )
 
+    @staticmethod
+    def init_by_eplb(
+        server_args: ServerArgs, model_config: ModelConfig, logical_count: torch.Tensor
+    ):
+        if not isinstance(logical_count, torch.Tensor):
+            logical_count = torch.tensor(logical_count)
+        if len(logical_count.shape) == 2:
+            logical_count = logical_count.unsqueeze(0)
+        logical_count = logical_count.to(server_args.device)
+
+        common = ExpertLocationMetadata._init_common(server_args, model_config)
+        model_config_for_expert_location = common["model_config_for_expert_location"]
+        num_physical_experts = common["num_physical_experts"]
+
+        phase = server_args.disaggregation_mode
+        if phase == "null":
+            phase = "decode"
+
+        physical_to_logical_map, logical_to_all_physical_map, expert_count = (
+            deepseek_eplb.rebalance_experts(
+                tokens_per_expert=logical_count,
+                num_physical_experts=num_physical_experts,
+                num_local_physical_experts=num_physical_experts // common["ep_size"],
+                num_groups=model_config_for_expert_location.num_groups,
+                num_nodes=server_args.nnodes,
+                phase=phase,
+            )
+        )
+
+        return ExpertLocationMetadata._init_raw(
+            ep_size=common["ep_size"],
+            physical_to_logical_map=physical_to_logical_map,
+            logical_to_all_physical_map=logical_to_all_physical_map,
+        )
+
     @staticmethod
     def _init_common(server_args: ServerArgs, model_config: ModelConfig):
         model_config_for_expert_location = (
@@ -272,14 +307,12 @@ def compute_initial_expert_location_metadata(
             server_args, model_config, **data_dict
         )
     elif "logical_count" in data_dict:
-        # TODO pr-chain: enable this later
-        raise NotImplementedError
-        # logger.info(
-        #     "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
-        # )
-        # return ExpertLocationMetadata.init_by_eplb(
-        #     server_args, model_config, logical_count=data_dict["logical_count"]
-        # )
+        logger.info(
+            "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
+        )
+        return ExpertLocationMetadata.init_by_eplb(
+            server_args, model_config, logical_count=data_dict["logical_count"]
+        )
     else:
         raise NotImplementedError(
             f"Unknown init_expert_location format ({list(data_dict.keys())=})"
...
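
For context, a minimal sketch of the kind of static distribution that reaches the init_by_eplb branch above (the logical_count key is from the diff; the toy numbers and the assumption that init_expert_location has been parsed into data_dict are hypothetical):

    import torch

    # Hypothetical static distribution: 2 MoE layers x 4 logical experts,
    # counting how many tokens each expert historically received.
    data_dict = {"logical_count": torch.tensor([[10, 80, 5, 5], [25, 25, 25, 25]])}

    # init_by_eplb promotes a 2-D [num_moe_layers, num_logical_experts] tensor
    # to 3-D by unsqueezing a leading num_steps dimension, which is the shape
    # rebalance_experts expects.
    logical_count = data_dict["logical_count"]
    if logical_count.dim() == 2:
        logical_count = logical_count.unsqueeze(0)  # -> [1, 2, 4]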