Unverified Commit 6170d47d authored by Ilya Markov's avatar Ilya Markov Committed by GitHub
Browse files

[EPLB] Optimize EPLB with numpy (#29499)


Signed-off-by: default avatarilmarkov <markovilya197@gmail.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
parent 0ada960a
...@@ -310,3 +310,143 @@ if __name__ == "__main__": ...@@ -310,3 +310,143 @@ if __name__ == "__main__":
print(phy2log) print(phy2log)
test_basic_rebalance() test_basic_rebalance()
def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor:
"""Create replicas indices mapping from phy2log"""
pr = torch.zeros_like(phy2log)
for layer in range(phy2log.shape[0]):
seen: dict[int, int] = {}
row = phy2log[layer].tolist()
for i, expert in enumerate(row):
r = seen.get(expert, 0)
pr[layer, i] = r
seen[expert] = r + 1
return pr
def _validate_intragpu_rearrangement(
old_global_expert_indices: torch.Tensor,
new_phy2log: torch.Tensor,
new_phy_replicas_idx: torch.Tensor,
post_phy2log: torch.Tensor,
post_phy_replicas_idx: torch.Tensor,
num_ranks: int,
slots_per_gpu: int,
):
# Per-GPU checks
for gpu_idx in range(num_ranks):
start = gpu_idx * slots_per_gpu
end = start + slots_per_gpu
old_seg = old_global_expert_indices[0, start:end]
new_seg = new_phy2log[0, start:end]
new_rnk = new_phy_replicas_idx[0, start:end]
post_seg = post_phy2log[0, start:end]
post_rnk = post_phy_replicas_idx[0, start:end]
# Pairwise equality for (expert, rank) pairs to ensure nothing is lost
def sorted_pairs(seg: torch.Tensor, rnk: torch.Tensor):
pairs = list(zip(seg.tolist(), rnk.tolist()))
pairs.sort()
return pairs
assert sorted_pairs(post_seg, post_rnk) == sorted_pairs(new_seg, new_rnk), (
f"Per-GPU pairs of (expert,rank) must match new mapping for GPU {gpu_idx}"
)
# For experts that remain on the same GPU, the old slot is preserved
# for at least one occurrence; rank at that slot must be valid for that expert
old_list = old_seg.tolist()
new_list = new_seg.tolist()
post_list = post_seg.tolist()
remained = set(old_list) & set(new_list)
new_ranks_for_expert: dict[int, list[int]] = {}
for v, r in zip(new_list, new_rnk.tolist()):
new_ranks_for_expert.setdefault(v, []).append(r)
for expert in remained:
old_pos = old_list.index(expert)
assert post_list[old_pos] == expert, (
f"Expert {expert} on GPU {gpu_idx} should stay at old slot {old_pos}"
)
# Rank at preserved slot must be one of the ranks
# the expert has in new mapping
assert post_rnk.tolist()[old_pos] in new_ranks_for_expert[expert], (
f"Rank for expert {expert} at preserved slot on GPU {gpu_idx} "
"must come from new mapping"
)
@pytest.mark.parametrize(
"num_ranks, slots_per_gpu, old_phy2log, new_phy2log",
[
pytest.param(
# Setup: 2 GPUs, 4 slots each, 1 layer
# Old mapping: GPU0 -> [0,1,2,3], GPU1 -> [4,5,6,7]
# New mapping shuffles within GPU0 and brings 4,5 into GPU0.
# GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3]
2,
4,
torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]),
torch.tensor([[1, 5, 0, 4, 6, 2, 7, 3]]),
id="simple",
),
pytest.param(
# Setup: 2 GPUs, 5 slots each (total 10 physical experts), 1 layer
# Old mapping:
# GPU0 -> [0, 1, 0, 2, 3] (expert 0 duplicated)
# GPU1 -> [4, 5, 6, 1, 2]
# New mapping reorders within GPUs and moves some experts across GPUs,
# while still including duplicates:
# GPU0 new -> [0, 5, 4, 0, 1] (expert 0 duplicated, 4/5 incoming)
# GPU1 new -> [6, 2, 3, 2, 1] (expert 2 duplicated)
2,
5,
torch.tensor([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]),
torch.tensor([[0, 5, 4, 0, 1, 6, 2, 3, 2, 1]]),
id="duplicates",
),
pytest.param(
# Setup: 3 GPUs, 4 slots each (total 12 physical experts), 1 layer
# Old mapping:
# GPU0 -> [0, 1, 2, 3]
# GPU1 -> [0, 1, 2, 3]
# GPU2 -> [0, 1, 2, 3]
# New mapping decides to use one expert on 2 GPUs and shuffles
# experts on the third GPU,
# GPU0 new -> [0, 0, 0, 0]
# GPU1 new -> [0, 0, 0, 0]
# GPU2 new -> [1, 2, 3, 0]
3,
4,
torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0]]),
id="skewed_expert",
),
],
)
def test_preserve_intragpu_slots(
num_ranks: int,
slots_per_gpu: int,
old_phy2log: torch.Tensor,
new_phy2log: torch.Tensor,
):
"""Experts that stay on a GPU keep their old slots; incoming not lost."""
phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log)
post_phy2log, post_phy_replicas_idx = DefaultEplbPolicy.preserve_intragpu_slots(
new_phy2log, phy_replicas_idx, num_ranks, old_phy2log
)
# Shapes preserved
assert post_phy2log.shape == new_phy2log.shape
assert post_phy_replicas_idx.shape == phy_replicas_idx.shape
_validate_intragpu_rearrangement(
old_phy2log,
new_phy2log,
phy_replicas_idx,
post_phy2log,
post_phy_replicas_idx,
num_ranks,
slots_per_gpu,
)
...@@ -286,15 +286,17 @@ def _test_async_transfer_layer_without_mtp_worker( ...@@ -286,15 +286,17 @@ def _test_async_transfer_layer_without_mtp_worker(
device, device,
old_indices, old_indices,
) )
old_indices_cpu = old_indices.cpu()
new_indices_cpu = new_indices.cpu()
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]] expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
cuda_stream = torch.cuda.Stream(device=device) cuda_stream = torch.cuda.Stream(device=device)
for layer_idx in range(num_layers): for layer_idx in range(num_layers):
is_unchanged, is_received_locally, experts_recv_loc = asyncio.run( is_unchanged, is_received_locally, recv_metadata = asyncio.run(
transfer_layer( transfer_layer(
old_global_expert_indices=old_indices, old_global_expert_indices=old_indices_cpu,
new_global_expert_indices=new_indices, new_global_expert_indices=new_indices_cpu,
expert_weights=expert_weights, expert_weights=expert_weights,
expert_weights_buffer=expert_buffer, expert_weights_buffer=expert_buffer,
ep_group=ep_group, ep_group=ep_group,
...@@ -302,16 +304,15 @@ def _test_async_transfer_layer_without_mtp_worker( ...@@ -302,16 +304,15 @@ def _test_async_transfer_layer_without_mtp_worker(
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
) )
) )
cuda_stream.synchronize() cuda_stream.synchronize()
move_from_buffer( move_from_buffer(
expert_weights=expert_weights[layer_idx], expert_weights=expert_weights[layer_idx],
expert_weights_buffer=expert_buffer, expert_weights_buffers=expert_buffer,
is_unchanged=is_unchanged, is_unchanged=is_unchanged,
is_received_locally=is_received_locally, is_received_locally=is_received_locally,
experts_recv_loc=experts_recv_loc, recv_metadata=recv_metadata,
new_indices=new_indices[layer_idx].tolist(), new_indices=new_indices_cpu[layer_idx],
ep_group=ep_group, ep_rank=ep_rank,
) )
verify_expert_weights_after_shuffle( verify_expert_weights_after_shuffle(
......
...@@ -69,6 +69,10 @@ class EPLBConfig: ...@@ -69,6 +69,10 @@ class EPLBConfig:
Log the balancedness each step of expert parallelism. Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead. This is turned off by default since it will cause communication overhead.
""" """
log_balancedness_interval: int = 1
"""
Interval for logging the balancedness.
"""
use_async: bool = False use_async: bool = False
""" """
Whether to use non-blocking EPLB. Whether to use non-blocking EPLB.
...@@ -77,6 +81,14 @@ class EPLBConfig: ...@@ -77,6 +81,14 @@ class EPLBConfig:
policy: EPLBPolicyOption = "default" policy: EPLBPolicyOption = "default"
"""The policy type for expert parallel load balancing (EPLB).""" """The policy type for expert parallel load balancing (EPLB)."""
@model_validator(mode="after")
def _validate_eplb_config(self) -> Self:
if self.use_async and self.policy != "default":
raise ValueError("Async EPLB is only supported with the default policy.")
if self.log_balancedness and self.log_balancedness_interval <= 0:
raise ValueError("log_balancedness_interval must be greater than 0.")
return self
@config @config
@dataclass @dataclass
......
...@@ -89,7 +89,7 @@ async def transfer_run_periodically( ...@@ -89,7 +89,7 @@ async def transfer_run_periodically(
( (
model_state.is_unchanged, model_state.is_unchanged,
model_state.is_received_locally, model_state.is_received_locally,
model_state.experts_recv_loc, model_state.recv_metadata,
) = await transfer_layer( ) = await transfer_layer(
old_global_expert_indices=model_state.physical_to_logical_map, old_global_expert_indices=model_state.physical_to_logical_map,
new_global_expert_indices=model_state.new_physical_to_logical_map, new_global_expert_indices=model_state.new_physical_to_logical_map,
......
...@@ -27,10 +27,10 @@ physical experts. ...@@ -27,10 +27,10 @@ physical experts.
""" """
import threading import threading
import time
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np
import torch import torch
from torch.distributed import ProcessGroup, all_reduce from torch.distributed import ProcessGroup, all_reduce
...@@ -46,7 +46,11 @@ from vllm.model_executor.models.interfaces import MixtureOfExperts ...@@ -46,7 +46,11 @@ from vllm.model_executor.models.interfaces import MixtureOfExperts
from .async_worker import start_async_worker from .async_worker import start_async_worker
from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace from .rebalance_execute import (
RecvMetadata,
move_from_buffer,
rearrange_expert_weights_inplace,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -164,20 +168,19 @@ class EplbModelState: ...@@ -164,20 +168,19 @@ class EplbModelState:
""" """
Whether the async EPLB needs to poll peers for buffer readiness. Whether the async EPLB needs to poll peers for buffer readiness.
""" """
is_unchanged: list[bool] is_unchanged: np.ndarray
""" """
intermediate variable between `move_to_buffer` and `move_to_workspace`. intermediate variable between `move_to_buffer` and `move_to_workspace`.
The size is same as the num of physical experts in the current layer. The size is same as the num of physical experts in the current layer.
""" """
is_received_locally: list[bool] is_received_locally: np.ndarray
""" """
intermediate variable between `move_to_buffer` and `move_to_workspace`. intermediate variable between `move_to_buffer` and `move_to_workspace`.
The size is same as the num of physical experts in the current layer. The size is same as the num of physical experts in the current layer.
""" """
experts_recv_loc: dict[int, int] recv_metadata: RecvMetadata
""" """
intermediate variable between `move_to_buffer` and `move_to_workspace`. intermediate variable between `move_to_buffer` and `move_to_workspace`.
The size is same as the num of physical experts in the current layer.
""" """
is_async_enabled: bool is_async_enabled: bool
""" """
...@@ -507,9 +510,14 @@ class EplbState: ...@@ -507,9 +510,14 @@ class EplbState:
layer_to_transfer=0, layer_to_transfer=0,
rebalanced=False, rebalanced=False,
pending_global_ready_check=False, pending_global_ready_check=False,
is_unchanged=[], is_unchanged=np.array([]),
is_received_locally=[], is_received_locally=np.array([]),
experts_recv_loc={}, recv_metadata=RecvMetadata(
recv_primary_mask=np.array([]),
recv_count=0,
recv_expert_ids=np.array([]),
recv_dst_rows=np.array([]),
),
is_async_enabled=self.is_async, is_async_enabled=self.is_async,
cuda_device_index=self.cuda_device_index, cuda_device_index=self.cuda_device_index,
new_physical_to_logical_map=new_physical_to_logical_map, new_physical_to_logical_map=new_physical_to_logical_map,
...@@ -553,7 +561,12 @@ class EplbState: ...@@ -553,7 +561,12 @@ class EplbState:
for eplb_model_state in self.model_states.values(): for eplb_model_state in self.model_states.values():
eplb_model_state.expert_load_pass.zero_() eplb_model_state.expert_load_pass.zero_()
if log_stats: if (
log_stats
and self.expert_rearrangement_step
% self.parallel_config.eplb_config.log_balancedness_interval
== 0
):
# Sync the expert load pass for each model (main and drafter). # Sync the expert load pass for each model (main and drafter).
# expert_load_pass: (num_moe_layers, num_physical_experts) # expert_load_pass: (num_moe_layers, num_physical_experts)
expert_load_pass_list = self._sync_load_pass() expert_load_pass_list = self._sync_load_pass()
...@@ -586,12 +599,15 @@ class EplbState: ...@@ -586,12 +599,15 @@ class EplbState:
if ep_group.rank() == 0: if ep_group.rank() == 0:
logger.info( logger.info(
"EPLB step: %d for model %s: avg_tokens=%.2f, " "EPLB step: %d for model %s: avg_tokens=%.2f, "
"max_tokens=%d, balancedness=%.4f", "max_tokens=%d, balancedness=%.4f, "
"steps until the next rearrangement: %d",
self.expert_rearrangement_step, self.expert_rearrangement_step,
eplb_model_state.model_name, eplb_model_state.model_name,
avg_tokens, avg_tokens,
max_tokens, max_tokens,
balancedness, balancedness,
self.expert_rearrangement_step_interval
- self.expert_rearrangement_step,
) )
# Update the expert load sliding window # Update the expert load sliding window
...@@ -684,11 +700,14 @@ class EplbState: ...@@ -684,11 +700,14 @@ class EplbState:
ep_group = get_ep_group().device_group ep_group = get_ep_group().device_group
ep_rank = ep_group.rank() ep_rank = ep_group.rank()
time_start = None start_event = None
end_event = None
is_main_rank = ep_rank == 0 is_main_rank = ep_rank == 0
if is_main_rank: if is_main_rank:
torch.cuda.synchronize() if not self.is_async or is_profile:
time_start = time.perf_counter() start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
logger.info( logger.info(
"Rearranging experts %s %s...", "Rearranging experts %s %s...",
"(async mode)" if self.is_async else "sync mode", "(async mode)" if self.is_async else "sync mode",
...@@ -800,6 +819,7 @@ class EplbState: ...@@ -800,6 +819,7 @@ class EplbState:
num_groups, num_groups,
num_nodes, num_nodes,
num_gpus, num_gpus,
eplb_model_state.physical_to_logical_map,
) )
if not eplb_model_state.is_async_enabled or is_profile: if not eplb_model_state.is_async_enabled or is_profile:
...@@ -848,17 +868,17 @@ class EplbState: ...@@ -848,17 +868,17 @@ class EplbState:
new_logical_replica_count new_logical_replica_count
) )
if is_main_rank: if is_main_rank:
assert time_start is not None assert start_event is not None
torch.cuda.synchronize() assert end_event is not None
time_end = time.perf_counter() end_event.record()
end_event.synchronize()
gpu_elapsed = start_event.elapsed_time(end_event) / 1000.0
logger.info( logger.info(
"Rearranged experts%sin %.2f seconds.", "Rearranged experts %s in %.2f s.",
" (profile) " if is_profile else " ", " (profile) " if is_profile else " ",
time_end - time_start, gpu_elapsed,
) )
else: else:
device = eplb_model_state.physical_to_logical_map.device
new_physical = new_physical_to_logical_map.to(device)
max_slots = eplb_model_state.logical_to_physical_map.shape[-1] max_slots = eplb_model_state.logical_to_physical_map.shape[-1]
padded_logical = torch.nn.functional.pad( padded_logical = torch.nn.functional.pad(
new_logical_to_physical_map, new_logical_to_physical_map,
...@@ -869,7 +889,10 @@ class EplbState: ...@@ -869,7 +889,10 @@ class EplbState:
eplb_model_state.logical_replica_count.device eplb_model_state.logical_replica_count.device
) )
eplb_model_state.new_physical_to_logical_map = new_physical # Move map to cpu in advance
eplb_model_state.new_physical_to_logical_map = (
new_physical_to_logical_map.cpu()
)
eplb_model_state.new_logical_to_physical_map = padded_logical eplb_model_state.new_logical_to_physical_map = padded_logical
eplb_model_state.new_logical_replica_count = new_replica eplb_model_state.new_logical_replica_count = new_replica
...@@ -968,25 +991,30 @@ class EplbState: ...@@ -968,25 +991,30 @@ class EplbState:
stream = torch.cuda.current_stream(device=device_index) stream = torch.cuda.current_stream(device=device_index)
stream.wait_event(model_state.buffer_ready_event) stream.wait_event(model_state.buffer_ready_event)
model_state.buffer_ready_event = None model_state.buffer_ready_event = None
move_from_buffer( expert_weights = model_state.model.expert_weights[
expert_weights=model_state.model.expert_weights[
model_state.layer_to_transfer model_state.layer_to_transfer
], ]
expert_weights_buffer=model_state.expert_buffer, expert_weights_buffer = model_state.expert_buffer
new_indices = (
model_state.new_physical_to_logical_map[model_state.layer_to_transfer]
.cpu()
.numpy()
)
move_from_buffer(
expert_weights=expert_weights,
expert_weights_buffers=expert_weights_buffer,
is_unchanged=model_state.is_unchanged, is_unchanged=model_state.is_unchanged,
is_received_locally=model_state.is_received_locally, is_received_locally=model_state.is_received_locally,
experts_recv_loc=model_state.experts_recv_loc, recv_metadata=model_state.recv_metadata,
new_indices=model_state.new_physical_to_logical_map[ new_indices=new_indices,
model_state.layer_to_transfer ep_rank=ep_group.rank(),
].tolist(),
ep_group=ep_group,
) )
transferred_layer = model_state.layer_to_transfer transferred_layer = model_state.layer_to_transfer
self._update_layer_mapping_from_new(model_state, transferred_layer) self._update_layer_mapping_from_new(model_state, transferred_layer)
# After the main thread consumes, advance layer_to_transfer # After the main thread consumes, advance layer_to_transfer
model_state.layer_to_transfer += 1 model_state.layer_to_transfer += 1
model_state.ep_buffer_ready = 0 model_state.ep_buffer_ready = 0
logger.info( logger.debug(
"model %s successfully move_to_workspace layer %d", "model %s successfully move_to_workspace layer %d",
model_state.model_name, model_state.model_name,
transferred_layer, transferred_layer,
......
...@@ -16,6 +16,7 @@ class AbstractEplbPolicy(ABC): ...@@ -16,6 +16,7 @@ class AbstractEplbPolicy(ABC):
num_groups: int, num_groups: int,
num_nodes: int, num_nodes: int,
num_ranks: int, num_ranks: int,
old_global_expert_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Entry point for expert-parallelism load balancer. Entry point for expert-parallelism load balancer.
...@@ -28,7 +29,9 @@ class AbstractEplbPolicy(ABC): ...@@ -28,7 +29,9 @@ class AbstractEplbPolicy(ABC):
num_groups: number of expert groups num_groups: number of expert groups
num_nodes: number of server nodes num_nodes: number of server nodes
num_ranks: number of ranks, must be a multiple of `num_nodes` num_ranks: number of ranks, must be a multiple of `num_nodes`
old_global_expert_indices: [layers, num_logical_experts], the old global
expert indices. Used to avoid unnecessary weight copying
for experts moving within one rank.
Returns: Returns:
physical_to_logical_map: [layers, num_replicas], the expert physical_to_logical_map: [layers, num_replicas], the expert
index of each replica index of each replica
......
...@@ -93,7 +93,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -93,7 +93,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
Returns: Returns:
phy2log: [X, num_phy], logical expert id of each physical expert phy2log: [X, num_phy], logical expert id of each physical expert
rank: [X, num_phy], the replica rank replica_idx: [X, num_phy], the index of the replica for each logical expert
logcnt: [X, num_log], number of replicas for each logical expert logcnt: [X, num_log], number of replicas for each logical expert
""" """
n, num_log = weight.shape n, num_log = weight.shape
...@@ -101,15 +101,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -101,15 +101,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
assert num_redundant >= 0 assert num_redundant >= 0
device = weight.device device = weight.device
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) replica_idx = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
arangen = torch.arange(n, dtype=torch.int64, device=device) arangen = torch.arange(n, dtype=torch.int64, device=device)
for i in range(num_log, num_phy): for i in range(num_log, num_phy):
redundant_indices = (weight / logcnt).max(dim=-1).indices redundant_indices = (weight / logcnt).max(dim=-1).indices
phy2log[:, i] = redundant_indices phy2log[:, i] = redundant_indices
rank[:, i] = logcnt[arangen, redundant_indices] replica_idx[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1 logcnt[arangen, redundant_indices] += 1
return phy2log, rank, logcnt return phy2log, replica_idx, logcnt
@classmethod @classmethod
def rebalance_experts_hierarchical( def rebalance_experts_hierarchical(
...@@ -132,7 +132,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -132,7 +132,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
Returns: Returns:
phy2log: [layers, num_replicas], the expert phy2log: [layers, num_replicas], the expert
index of each replica index of each replica
log2phy: [layers, num_logical_experts, X], pphy_replicas_idx: [layers, num_logical_experts, X],
the replica indices for each expert the replica indices for each expert
logcnt: [layers, num_logical_experts], number of logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert physical replicas for each logical expert
...@@ -177,7 +177,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -177,7 +177,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
tokens_per_mlog = weight.gather(-1, mlog2log).view( tokens_per_mlog = weight.gather(-1, mlog2log).view(
-1, num_logical_experts // num_nodes -1, num_logical_experts // num_nodes
) )
phy2mlog, phyrank, mlogcnt = cls.replicate_experts( phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
tokens_per_mlog, num_physical_experts // num_nodes tokens_per_mlog, num_physical_experts // num_nodes
) )
...@@ -203,9 +203,109 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -203,9 +203,109 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
).view(1, -1, 1) ).view(1, -1, 1)
).flatten(-2) ).flatten(-2)
pphy2log = mlog2log.gather(-1, pphy2mlog) pphy2log = mlog2log.gather(-1, pphy2mlog)
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) pphy_replicas_idx = replicas_idx.gather(-1, pphy2phy).view(num_layers, -1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
return pphy2log, pphyrank, logcnt return pphy2log, pphy_replicas_idx, logcnt
@classmethod
def preserve_intragpu_slots(
cls,
phy2log: torch.Tensor,
phy_replicas_idx: torch.Tensor,
num_ranks: int,
old_global_expert_indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Reorder the new mapping per GPU so that experts that remain on the same GPU
keep their previous slot positions when possible. Incoming experts to that GPU
fill any remaining available slots. This is applied only when the number of GPUs
is unchanged and the slots per GPU remain the same between
the old and new mappings.
"""
device = phy2log.device
num_phy_experts = phy2log.shape[1]
if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
return phy2log, phy_replicas_idx
# Move to CPU and convert to NumPy for processing
new_phy2log_np = phy2log.cpu().numpy()
replicas_idx_np = phy_replicas_idx.cpu().numpy()
old_phy2log_np = old_global_expert_indices.cpu().numpy()
slots_per_gpu = num_phy_experts // num_ranks
num_layers = new_phy2log_np.shape[0]
post_phy2log_np = new_phy2log_np.copy()
post_phy_replicas_idx_np = replicas_idx_np.copy()
for gpu_idx in range(num_ranks):
start = gpu_idx * slots_per_gpu
end = start + slots_per_gpu
# Experts across all layers for this GPU
old_local = old_phy2log_np[:, start:end] # [layers, slots]
new_local = new_phy2log_np[:, start:end] # [layers, slots]
new_ridx = replicas_idx_np[:, start:end] # [layers, slots]
used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)
# First pass: preserve same-logical experts in their previous slots
for slot_idx in range(slots_per_gpu):
# matches: [layers, slots], True where new local experts have
# the same logical value as the old from 'slot_idx' and not checked yet
matches = (new_local == old_local[:, slot_idx][:, None]) & (
~used_new_indices
)
has_any = matches.any(axis=1)
if np.any(has_any):
first_idx = np.argmax(matches, axis=1)
layer_indices = np.nonzero(has_any)[0]
matched_new_positions = first_idx[layer_indices]
post_phy2log_np[layer_indices, start + slot_idx] = new_local[
layer_indices, matched_new_positions
]
post_phy_replicas_idx_np[layer_indices, start + slot_idx] = (
new_ridx[layer_indices, matched_new_positions]
)
used_new_indices[layer_indices, matched_new_positions] = True
preserved_positions[layer_indices, slot_idx] = True
# Second pass: fill remaining slots with remaining new experts
remaining_mask = ~used_new_indices # [layers, slots]
fill_mask = ~preserved_positions # [layers, slots]
if remaining_mask.any() and fill_mask.any():
idx_base = np.tile(np.arange(slots_per_gpu), (num_layers, 1))
# Sentinel value for unavailable positions.
large = slots_per_gpu + 1
# Priorities: keep original index for available spots, set sentinel
# for unavailable; lower is earlier.
remaining_priority = np.where(remaining_mask, idx_base, large)
fill_priority = np.where(fill_mask, idx_base, large)
# Sort to get ordered indices of available src/dst positions per layer.
remaining_indices = np.argsort(remaining_priority, axis=1)
fill_indices = np.argsort(fill_priority, axis=1)
# Fill count per layer (cannot exceed either side).
remaining_counts = remaining_mask.sum(axis=1)
fill_counts = fill_mask.sum(axis=1)
take_counts = np.minimum(remaining_counts, fill_counts)
# Assign remaining new experts to remaining slots per layer.
for layer_idx in range(num_layers):
k = int(take_counts[layer_idx])
if k <= 0:
continue
src_pos = remaining_indices[layer_idx, :k]
dst_pos = fill_indices[layer_idx, :k]
post_phy2log_np[layer_idx, start + dst_pos] = new_local[
layer_idx, src_pos
]
post_phy_replicas_idx_np[layer_idx, start + dst_pos] = new_ridx[
layer_idx, src_pos
]
# Convert back to torch and move to original device
post_phy2log = torch.from_numpy(post_phy2log_np).to(device)
post_phy_replicas_idx = torch.from_numpy(post_phy_replicas_idx_np).to(device)
return post_phy2log, post_phy_replicas_idx
@classmethod @classmethod
def rebalance_experts( def rebalance_experts(
...@@ -215,6 +315,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -215,6 +315,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
num_groups: int, num_groups: int,
num_nodes: int, num_nodes: int,
num_ranks: int, num_ranks: int,
old_global_expert_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Entry point for expert-parallelism load balancer. Entry point for expert-parallelism load balancer.
...@@ -228,7 +329,9 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -228,7 +329,9 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
num_nodes: number of server nodes, where the intra-node network num_nodes: number of server nodes, where the intra-node network
(e.g, NVLink) is faster (e.g, NVLink) is faster
num_ranks: number of ranks, must be a multiple of `num_nodes` num_ranks: number of ranks, must be a multiple of `num_nodes`
old_global_expert_indices: [layers, num_logical_experts], the old global
expert indices. Used to avoid unnecessary weight copying
for experts moving within one rank.
Returns: Returns:
phy2log: [layers, num_replicas], the expert phy2log: [layers, num_replicas], the expert
index of each replica index of each replica
...@@ -241,14 +344,23 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -241,14 +344,23 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
weight = weight.float() weight = weight.float()
if num_groups % num_nodes == 0: if num_groups % num_nodes == 0:
# use hierarchical load-balance policy # use hierarchical load-balance policy
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical( phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_ranks weight, num_replicas, num_groups, num_nodes, num_ranks
) )
else: else:
# use global load-balance policy # use global load-balance policy
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical( phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
weight, num_replicas, 1, 1, num_ranks weight, num_replicas, 1, 1, num_ranks
) )
# Optional postprocessing to preserve slots for experts moving
# within the same GPU
# Only apply when the number of GPUs and slots per GPU remain unchanged.
# Helps to avoid unnecessary weight copying when experts move
# within the same GPU.
if old_global_expert_indices is not None:
phy2log, phy_replicas_idx = cls.preserve_intragpu_slots(
phy2log, phy_replicas_idx, num_ranks, old_global_expert_indices
)
num_redundant_experts = num_replicas - num_logical_experts num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1 maxlogcnt = num_redundant_experts + 1
log2phy: torch.Tensor = torch.full( log2phy: torch.Tensor = torch.full(
...@@ -259,7 +371,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -259,7 +371,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
) )
log2phy.view(num_layers, -1).scatter_( log2phy.view(num_layers, -1).scatter_(
-1, -1,
phy2log * maxlogcnt + phyrank, phy2log * maxlogcnt + phy_replicas_idx,
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
num_layers, -1 num_layers, -1
), ),
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment