Unverified Commit 1eb61ab3 authored by Ilya Markov's avatar Ilya Markov Committed by GitHub
Browse files

[Refactor] EPLB rebalance algo to NumPy (#30697)


Signed-off-by: default avatarilmarkov <markovilya197@gmail.com>
parent 3d962d72
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import pytest import pytest
import torch import torch
...@@ -312,9 +313,9 @@ if __name__ == "__main__": ...@@ -312,9 +313,9 @@ if __name__ == "__main__":
test_basic_rebalance() test_basic_rebalance()
def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor: def _make_phy_replicas_idx_from_phy2log(phy2log: np.ndarray) -> np.ndarray:
"""Create replicas indices mapping from phy2log""" """Create replicas indices mapping from phy2log."""
pr = torch.zeros_like(phy2log) pr = np.zeros_like(phy2log, dtype=np.int64)
for layer in range(phy2log.shape[0]): for layer in range(phy2log.shape[0]):
seen: dict[int, int] = {} seen: dict[int, int] = {}
row = phy2log[layer].tolist() row = phy2log[layer].tolist()
...@@ -326,11 +327,11 @@ def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor: ...@@ -326,11 +327,11 @@ def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor:
def _validate_intragpu_rearrangement( def _validate_intragpu_rearrangement(
old_global_expert_indices: torch.Tensor, old_global_expert_indices: np.ndarray,
new_phy2log: torch.Tensor, new_phy2log: np.ndarray,
new_phy_replicas_idx: torch.Tensor, new_phy_replicas_idx: np.ndarray,
post_phy2log: torch.Tensor, post_phy2log: np.ndarray,
post_phy_replicas_idx: torch.Tensor, post_phy_replicas_idx: np.ndarray,
num_ranks: int, num_ranks: int,
slots_per_gpu: int, slots_per_gpu: int,
): ):
...@@ -345,7 +346,7 @@ def _validate_intragpu_rearrangement( ...@@ -345,7 +346,7 @@ def _validate_intragpu_rearrangement(
post_rnk = post_phy_replicas_idx[0, start:end] post_rnk = post_phy_replicas_idx[0, start:end]
# Pairwise equality for (expert, rank) pairs to ensure nothing is lost # Pairwise equality for (expert, rank) pairs to ensure nothing is lost
def sorted_pairs(seg: torch.Tensor, rnk: torch.Tensor): def sorted_pairs(seg, rnk):
pairs = list(zip(seg.tolist(), rnk.tolist())) pairs = list(zip(seg.tolist(), rnk.tolist()))
pairs.sort() pairs.sort()
return pairs return pairs
...@@ -386,8 +387,8 @@ def _validate_intragpu_rearrangement( ...@@ -386,8 +387,8 @@ def _validate_intragpu_rearrangement(
# GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3] # GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3]
2, 2,
4, 4,
torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]), np.array([[0, 1, 2, 3, 4, 5, 6, 7]]),
torch.tensor([[1, 5, 0, 4, 6, 2, 7, 3]]), np.array([[1, 5, 0, 4, 6, 2, 7, 3]]),
id="simple", id="simple",
), ),
pytest.param( pytest.param(
...@@ -401,8 +402,8 @@ def _validate_intragpu_rearrangement( ...@@ -401,8 +402,8 @@ def _validate_intragpu_rearrangement(
# GPU1 new -> [6, 2, 3, 2, 1] (expert 2 duplicated) # GPU1 new -> [6, 2, 3, 2, 1] (expert 2 duplicated)
2, 2,
5, 5,
torch.tensor([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]), np.array([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]),
torch.tensor([[0, 5, 4, 0, 1, 6, 2, 3, 2, 1]]), np.array([[0, 5, 4, 0, 1, 6, 2, 3, 2, 1]]),
id="duplicates", id="duplicates",
), ),
pytest.param( pytest.param(
...@@ -418,8 +419,8 @@ def _validate_intragpu_rearrangement( ...@@ -418,8 +419,8 @@ def _validate_intragpu_rearrangement(
# GPU2 new -> [1, 2, 3, 0] # GPU2 new -> [1, 2, 3, 0]
3, 3,
4, 4,
torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]), np.array([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0]]), np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0]]),
id="skewed_expert", id="skewed_expert",
), ),
], ],
......
...@@ -311,7 +311,7 @@ def _test_async_transfer_layer_without_mtp_worker( ...@@ -311,7 +311,7 @@ def _test_async_transfer_layer_without_mtp_worker(
is_unchanged=is_unchanged, is_unchanged=is_unchanged,
is_received_locally=is_received_locally, is_received_locally=is_received_locally,
recv_metadata=recv_metadata, recv_metadata=recv_metadata,
new_indices=new_indices_cpu[layer_idx], new_indices=new_indices_cpu[layer_idx].numpy(),
ep_rank=ep_rank, ep_rank=ep_rank,
) )
......
...@@ -21,8 +21,8 @@ from .abstract import AbstractEplbPolicy ...@@ -21,8 +21,8 @@ from .abstract import AbstractEplbPolicy
class DefaultEplbPolicy(AbstractEplbPolicy): class DefaultEplbPolicy(AbstractEplbPolicy):
@classmethod @classmethod
def balanced_packing( def balanced_packing(
cls, weight: torch.Tensor, num_packs: int cls, weight: np.ndarray, num_packs: int
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[np.ndarray, np.ndarray]:
""" """
Pack n weighted objects to m packs, such that each bin contains exactly Pack n weighted objects to m packs, such that each bin contains exactly
n/m objects and the weights of all packs are as balanced as possible. n/m objects and the weights of all packs are as balanced as possible.
...@@ -39,50 +39,43 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -39,50 +39,43 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
assert num_groups % num_packs == 0 assert num_groups % num_packs == 0
groups_per_pack = num_groups // num_packs groups_per_pack = num_groups // num_packs
device = weight.device
if groups_per_pack == 1: if groups_per_pack == 1:
pack_index = torch.arange( pack_index = np.tile(np.arange(num_groups, dtype=np.int64), (num_layers, 1))
weight.size(-1), dtype=torch.int64, device=device rank_in_pack = np.zeros_like(pack_index, dtype=np.int64)
).expand(weight.shape)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
return pack_index, rank_in_pack return pack_index, rank_in_pack
weight_np = weight.cpu().numpy()
# Sort and get indices in decending order # Sort and get indices in decending order
indices_np = np.argsort(-weight_np, axis=-1) indices = np.argsort(-weight, axis=-1)
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64) pack_index = np.full((num_layers, num_groups), -1, dtype=np.int64)
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64) rank_in_pack = np.full((num_layers, num_groups), -1, dtype=np.int64)
# Run the packing algorithm
for i in range(num_layers):
pack_weights = [0.0] * num_packs
pack_items = [0] * num_packs
for group in indices_np[i]:
# Find a pack with capacity that has the lowest weight
pack = min(
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
key=pack_weights.__getitem__,
)
assert pack_items[pack] < groups_per_pack pack_weights = np.zeros((num_layers, num_packs), dtype=np.float64)
pack_index_np[i, group] = pack pack_items = np.zeros((num_layers, num_packs), dtype=np.int64)
rank_in_pack_np[i, group] = pack_items[pack]
pack_weights[pack] += weight_np[i, group]
pack_items[pack] += 1
pack_index = torch.from_numpy(pack_index_np).to(device) # Run the packing algorithm
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device) for layer_idx in range(num_layers):
weights_row = pack_weights[layer_idx]
items_row = pack_items[layer_idx]
for group in indices[layer_idx]:
# Pick the lightest pack; full packs are masked out by inf.
pack = int(np.argmin(weights_row))
pack_index[layer_idx, group] = pack
rank_in_pack[layer_idx, group] = items_row[pack]
weights_row[pack] += weight[layer_idx, group]
items_row[pack] += 1
if items_row[pack] == groups_per_pack:
# Mark as unavailable for future selections.
weights_row[pack] = np.inf
return pack_index, rank_in_pack return pack_index, rank_in_pack
@classmethod @classmethod
def replicate_experts( def replicate_experts(
cls, weight: torch.Tensor, num_phy: int cls, weight: np.ndarray, num_phy: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
""" """
Replicate `num_log` experts to `num_phy` replicas, such that the maximum Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized. load of all replicas is minimized.
...@@ -99,13 +92,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -99,13 +92,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
n, num_log = weight.shape n, num_log = weight.shape
num_redundant = num_phy - num_log num_redundant = num_phy - num_log
assert num_redundant >= 0 assert num_redundant >= 0
device = weight.device phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1))
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) replica_idx = np.zeros((n, num_phy), dtype=np.int64)
replica_idx = torch.zeros(n, num_phy, dtype=torch.int64, device=device) logcnt = np.ones((n, num_log), dtype=np.int64)
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) arangen = np.arange(n, dtype=np.int64)
arangen = torch.arange(n, dtype=torch.int64, device=device)
for i in range(num_log, num_phy): for i in range(num_log, num_phy):
redundant_indices = (weight / logcnt).max(dim=-1).indices redundant_indices = np.argmax(weight / logcnt, axis=-1)
phy2log[:, i] = redundant_indices phy2log[:, i] = redundant_indices
replica_idx[:, i] = logcnt[arangen, redundant_indices] replica_idx[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1 logcnt[arangen, redundant_indices] += 1
...@@ -114,12 +106,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -114,12 +106,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
@classmethod @classmethod
def rebalance_experts_hierarchical( def rebalance_experts_hierarchical(
cls, cls,
weight: torch.Tensor, weight: np.ndarray,
num_physical_experts: int, num_physical_experts: int,
num_groups: int, num_groups: int,
num_nodes: int, num_nodes: int,
num_gpus: int, num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
""" """
Parameters: Parameters:
weight: [num_moe_layers, num_logical_experts] weight: [num_moe_layers, num_logical_experts]
...@@ -146,35 +138,33 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -146,35 +138,33 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
assert num_physical_experts % num_gpus == 0 assert num_physical_experts % num_gpus == 0
phy_experts_per_gpu = num_physical_experts // num_gpus phy_experts_per_gpu = num_physical_experts // num_gpus
def inverse(perm: torch.Tensor) -> torch.Tensor: def inverse(perm: np.ndarray) -> np.ndarray:
inv = torch.empty_like(perm) inv = np.empty_like(perm)
inv.scatter_( row_idx = np.arange(perm.shape[0])[:, None]
1, col_idx = np.arange(perm.shape[1], dtype=np.int64)
perm, inv[row_idx, perm] = col_idx
torch.arange(
perm.size(1), dtype=torch.int64, device=perm.device
).expand(perm.shape),
)
return inv return inv
# Step 1: pack groups to nodes # Step 1: pack groups to nodes
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) tokens_per_group = weight.reshape(num_layers, num_groups, group_size).sum(
axis=-1
)
group_pack_index, group_rank_in_pack = cls.balanced_packing( group_pack_index, group_rank_in_pack = cls.balanced_packing(
tokens_per_group, num_nodes tokens_per_group, num_nodes
) )
# Map each logical expert into a node-local ordering based on packed groups.
log2mlog = ( log2mlog = (
( (
(group_pack_index * groups_per_node + group_rank_in_pack) * group_size (group_pack_index * groups_per_node + group_rank_in_pack)[..., None]
).unsqueeze(-1) * group_size
+ torch.arange(
group_size, dtype=torch.int64, device=group_pack_index.device
) )
).flatten(-2) + np.arange(group_size, dtype=np.int64)
).reshape(num_layers, num_logical_experts)
mlog2log = inverse(log2mlog) mlog2log = inverse(log2mlog)
# Step 2: construct redundant experts within nodes # Step 2: construct redundant experts within nodes
# [num_layers * num_nodes, num_logical_experts // num_nodes] # Reorder weights into the node-local layout so replication is done per node.
tokens_per_mlog = weight.gather(-1, mlog2log).view( tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=1).reshape(
-1, num_logical_experts // num_nodes -1, num_logical_experts // num_nodes
) )
phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts( phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
...@@ -182,39 +172,43 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -182,39 +172,43 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
) )
# Step 3: pack physical_experts to GPUs # Step 3: pack physical_experts to GPUs
# [num_layers * num_nodes, num_physical_experts // num_nodes] # Effective per-physical load = logical load divided by replica count.
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) tokens_per_phy = np.take_along_axis(tokens_per_mlog / mlogcnt, phy2mlog, axis=1)
pack_index, rank_in_pack = cls.balanced_packing( pack_index, rank_in_pack = cls.balanced_packing(
tokens_per_phy, num_gpus // num_nodes tokens_per_phy, num_gpus // num_nodes
) )
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
pphy2phy = inverse(phy2pphy) pphy2phy = inverse(phy2pphy)
pphy2mlog = phy2mlog.gather( # Reorder node-local logical indices into the post-packing physical order.
-1, pphy2phy pphy2mlog = np.take_along_axis(phy2mlog, pphy2phy, axis=1)
) # [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog = ( pphy2mlog = (
pphy2mlog.view(num_layers, num_nodes, -1) pphy2mlog.reshape(num_layers, num_nodes, -1)
+ torch.arange( + np.arange(
0, 0,
num_logical_experts, num_logical_experts,
num_logical_experts // num_nodes, num_logical_experts // num_nodes,
device=group_pack_index.device, dtype=np.int64,
).view(1, -1, 1) )[None, :, None]
).flatten(-2) ).reshape(num_layers, -1)
pphy2log = mlog2log.gather(-1, pphy2mlog) # Map node-local logical indices back to global logical expert ids.
pphy_replicas_idx = replicas_idx.gather(-1, pphy2phy).view(num_layers, -1) pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) # Reorder replica ranks to the post-packing physical ordering.
pphy_replicas_idx = np.take_along_axis(replicas_idx, pphy2phy, axis=1).reshape(
num_layers, -1
)
# Convert replica counts back to the original logical ordering.
logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=1)
return pphy2log, pphy_replicas_idx, logcnt return pphy2log, pphy_replicas_idx, logcnt
@classmethod @classmethod
def preserve_intragpu_slots( def preserve_intragpu_slots(
cls, cls,
phy2log: torch.Tensor, phy2log: np.ndarray,
phy_replicas_idx: torch.Tensor, phy_replicas_idx: np.ndarray,
num_ranks: int, num_ranks: int,
old_global_expert_indices: torch.Tensor, old_phy2log: np.ndarray,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[np.ndarray, np.ndarray]:
""" """
Reorder the new mapping per GPU so that experts that remain on the same GPU Reorder the new mapping per GPU so that experts that remain on the same GPU
keep their previous slot positions when possible. Incoming experts to that GPU keep their previous slot positions when possible. Incoming experts to that GPU
...@@ -222,29 +216,24 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -222,29 +216,24 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
is unchanged and the slots per GPU remain the same between is unchanged and the slots per GPU remain the same between
the old and new mappings. the old and new mappings.
""" """
device = phy2log.device
num_phy_experts = phy2log.shape[1] num_phy_experts = phy2log.shape[1]
if num_ranks <= 0 or num_phy_experts % num_ranks != 0: if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
return phy2log, phy_replicas_idx return phy2log, phy_replicas_idx
# Move to CPU and convert to NumPy for processing # Move to CPU and convert to NumPy for processing
new_phy2log_np = phy2log.cpu().numpy()
replicas_idx_np = phy_replicas_idx.cpu().numpy()
old_phy2log_np = old_global_expert_indices.cpu().numpy()
slots_per_gpu = num_phy_experts // num_ranks slots_per_gpu = num_phy_experts // num_ranks
num_layers = new_phy2log_np.shape[0] num_layers = phy2log.shape[0]
post_phy2log_np = new_phy2log_np.copy() post_phy2log = phy2log.copy()
post_phy_replicas_idx_np = replicas_idx_np.copy() post_phy_replicas_idx = phy_replicas_idx.copy()
for gpu_idx in range(num_ranks): for gpu_idx in range(num_ranks):
start = gpu_idx * slots_per_gpu start = gpu_idx * slots_per_gpu
end = start + slots_per_gpu end = start + slots_per_gpu
# Experts across all layers for this GPU # Experts across all layers for this GPU
old_local = old_phy2log_np[:, start:end] # [layers, slots] old_local = old_phy2log[:, start:end] # [layers, slots]
new_local = new_phy2log_np[:, start:end] # [layers, slots] new_local = phy2log[:, start:end] # [layers, slots]
new_ridx = replicas_idx_np[:, start:end] # [layers, slots] new_ridx = phy_replicas_idx[:, start:end] # [layers, slots]
used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool) used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool) preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)
...@@ -261,12 +250,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -261,12 +250,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
first_idx = np.argmax(matches, axis=1) first_idx = np.argmax(matches, axis=1)
layer_indices = np.nonzero(has_any)[0] layer_indices = np.nonzero(has_any)[0]
matched_new_positions = first_idx[layer_indices] matched_new_positions = first_idx[layer_indices]
post_phy2log_np[layer_indices, start + slot_idx] = new_local[ post_phy2log[layer_indices, start + slot_idx] = new_local[
layer_indices, matched_new_positions
]
post_phy_replicas_idx[layer_indices, start + slot_idx] = new_ridx[
layer_indices, matched_new_positions layer_indices, matched_new_positions
] ]
post_phy_replicas_idx_np[layer_indices, start + slot_idx] = (
new_ridx[layer_indices, matched_new_positions]
)
used_new_indices[layer_indices, matched_new_positions] = True used_new_indices[layer_indices, matched_new_positions] = True
preserved_positions[layer_indices, slot_idx] = True preserved_positions[layer_indices, slot_idx] = True
...@@ -295,16 +284,13 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -295,16 +284,13 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
continue continue
src_pos = remaining_indices[layer_idx, :k] src_pos = remaining_indices[layer_idx, :k]
dst_pos = fill_indices[layer_idx, :k] dst_pos = fill_indices[layer_idx, :k]
post_phy2log_np[layer_idx, start + dst_pos] = new_local[ post_phy2log[layer_idx, start + dst_pos] = new_local[
layer_idx, src_pos layer_idx, src_pos
] ]
post_phy_replicas_idx_np[layer_idx, start + dst_pos] = new_ridx[ post_phy_replicas_idx[layer_idx, start + dst_pos] = new_ridx[
layer_idx, src_pos layer_idx, src_pos
] ]
# Convert back to torch and move to original device
post_phy2log = torch.from_numpy(post_phy2log_np).to(device)
post_phy_replicas_idx = torch.from_numpy(post_phy_replicas_idx_np).to(device)
return post_phy2log, post_phy_replicas_idx return post_phy2log, post_phy_replicas_idx
@classmethod @classmethod
...@@ -340,40 +326,51 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ...@@ -340,40 +326,51 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
logcnt: [layers, num_logical_experts], number of logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert physical replicas for each logical expert
""" """
device = weight.device
num_layers, num_logical_experts = weight.shape num_layers, num_logical_experts = weight.shape
weight = weight.float() weight_np = weight.float().cpu().numpy()
old_phy2log_np = (
old_global_expert_indices.cpu().numpy()
if old_global_expert_indices is not None
else None
)
if num_groups % num_nodes == 0: if num_groups % num_nodes == 0:
# use hierarchical load-balance policy # use hierarchical load-balance policy
phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical( phy2log_np, phy_replicas_idx_np, logcnt_np = (
weight, num_replicas, num_groups, num_nodes, num_ranks cls.rebalance_experts_hierarchical(
weight_np, num_replicas, num_groups, num_nodes, num_ranks
)
) )
else: else:
# use global load-balance policy # use global load-balance policy
phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical( phy2log_np, phy_replicas_idx_np, logcnt_np = (
weight, num_replicas, 1, 1, num_ranks cls.rebalance_experts_hierarchical(
weight_np, num_replicas, 1, 1, num_ranks
)
) )
# Optional postprocessing to preserve slots for experts moving # Optional postprocessing to preserve slots for experts moving
# within the same GPU # within the same GPU
# Only apply when the number of GPUs and slots per GPU remain unchanged. # Only apply when the number of GPUs and slots per GPU remain unchanged.
# Helps to avoid unnecessary weight copying when experts move # Helps to avoid unnecessary weight copying when experts move
# within the same GPU. # within the same GPU.
if old_global_expert_indices is not None: if old_global_expert_indices is not None:
phy2log, phy_replicas_idx = cls.preserve_intragpu_slots( phy2log_np, phy_replicas_idx_np = cls.preserve_intragpu_slots(
phy2log, phy_replicas_idx, num_ranks, old_global_expert_indices phy2log_np, phy_replicas_idx_np, num_ranks, old_phy2log_np
) )
num_redundant_experts = num_replicas - num_logical_experts num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1 maxlogcnt = num_redundant_experts + 1
log2phy: torch.Tensor = torch.full( log2phy_np = np.full(
(num_layers, num_logical_experts, maxlogcnt), (num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int64
-1,
dtype=torch.int64,
device=logcnt.device,
) )
log2phy.view(num_layers, -1).scatter_( layer_indices = np.arange(num_layers)[:, None]
-1, replica_indices = np.tile(
phy2log * maxlogcnt + phy_replicas_idx, np.arange(num_replicas, dtype=np.int64), (num_layers, 1)
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
num_layers, -1
),
) )
log2phy_np[layer_indices, phy2log_np, phy_replicas_idx_np] = replica_indices
phy2log = torch.from_numpy(phy2log_np).to(device)
log2phy = torch.from_numpy(log2phy_np).to(device)
logcnt = torch.from_numpy(logcnt_np).to(device)
return phy2log, log2phy, logcnt return phy2log, log2phy, logcnt
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment