Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1eb61ab3
Unverified
Commit
1eb61ab3
authored
Jan 12, 2026
by
Ilya Markov
Committed by
GitHub
Jan 12, 2026
Browse files
[Refactor] EPLB rebalance algo to NumPy (#30697)
Signed-off-by:
ilmarkov
<
markovilya197@gmail.com
>
parent
3d962d72
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
128 additions
and
130 deletions
+128
-130
tests/distributed/test_eplb_algo.py
tests/distributed/test_eplb_algo.py
+16
-15
tests/distributed/test_eplb_execute.py
tests/distributed/test_eplb_execute.py
+1
-1
vllm/distributed/eplb/policy/default.py
vllm/distributed/eplb/policy/default.py
+111
-114
No files found.
tests/distributed/test_eplb_algo.py
View file @
1eb61ab3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
import
pytest
import
torch
...
...
@@ -312,9 +313,9 @@ if __name__ == "__main__":
test_basic_rebalance
()
def
_make_phy_replicas_idx_from_phy2log
(
phy2log
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Create replicas indices mapping from phy2log"""
pr
=
torch
.
zeros_like
(
phy2log
)
def
_make_phy_replicas_idx_from_phy2log
(
phy2log
:
np
.
ndarray
)
->
np
.
ndarray
:
"""Create replicas indices mapping from phy2log
.
"""
pr
=
np
.
zeros_like
(
phy2log
,
dtype
=
np
.
int64
)
for
layer
in
range
(
phy2log
.
shape
[
0
]):
seen
:
dict
[
int
,
int
]
=
{}
row
=
phy2log
[
layer
].
tolist
()
...
...
@@ -326,11 +327,11 @@ def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor:
def
_validate_intragpu_rearrangement
(
old_global_expert_indices
:
torch
.
Tensor
,
new_phy2log
:
torch
.
Tensor
,
new_phy_replicas_idx
:
torch
.
Tensor
,
post_phy2log
:
torch
.
Tensor
,
post_phy_replicas_idx
:
torch
.
Tensor
,
old_global_expert_indices
:
np
.
ndarray
,
new_phy2log
:
np
.
ndarray
,
new_phy_replicas_idx
:
np
.
ndarray
,
post_phy2log
:
np
.
ndarray
,
post_phy_replicas_idx
:
np
.
ndarray
,
num_ranks
:
int
,
slots_per_gpu
:
int
,
):
...
...
@@ -345,7 +346,7 @@ def _validate_intragpu_rearrangement(
post_rnk
=
post_phy_replicas_idx
[
0
,
start
:
end
]
# Pairwise equality for (expert, rank) pairs to ensure nothing is lost
def
sorted_pairs
(
seg
:
torch
.
Tensor
,
rnk
:
torch
.
Tensor
):
def
sorted_pairs
(
seg
,
rnk
):
pairs
=
list
(
zip
(
seg
.
tolist
(),
rnk
.
tolist
()))
pairs
.
sort
()
return
pairs
...
...
@@ -386,8 +387,8 @@ def _validate_intragpu_rearrangement(
# GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3]
2
,
4
,
torch
.
tensor
([[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]]),
torch
.
tensor
([[
1
,
5
,
0
,
4
,
6
,
2
,
7
,
3
]]),
np
.
array
([[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]]),
np
.
array
([[
1
,
5
,
0
,
4
,
6
,
2
,
7
,
3
]]),
id
=
"simple"
,
),
pytest
.
param
(
...
...
@@ -401,8 +402,8 @@ def _validate_intragpu_rearrangement(
# GPU1 new -> [6, 2, 3, 2, 1] (expert 2 duplicated)
2
,
5
,
torch
.
tensor
([[
0
,
1
,
0
,
2
,
3
,
4
,
5
,
6
,
1
,
2
]]),
torch
.
tensor
([[
0
,
5
,
4
,
0
,
1
,
6
,
2
,
3
,
2
,
1
]]),
np
.
array
([[
0
,
1
,
0
,
2
,
3
,
4
,
5
,
6
,
1
,
2
]]),
np
.
array
([[
0
,
5
,
4
,
0
,
1
,
6
,
2
,
3
,
2
,
1
]]),
id
=
"duplicates"
,
),
pytest
.
param
(
...
...
@@ -418,8 +419,8 @@ def _validate_intragpu_rearrangement(
# GPU2 new -> [1, 2, 3, 0]
3
,
4
,
torch
.
tensor
([[
0
,
1
,
2
,
3
,
0
,
1
,
2
,
3
,
0
,
1
,
2
,
3
]]),
torch
.
tensor
([[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
3
,
0
]]),
np
.
array
([[
0
,
1
,
2
,
3
,
0
,
1
,
2
,
3
,
0
,
1
,
2
,
3
]]),
np
.
array
([[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
3
,
0
]]),
id
=
"skewed_expert"
,
),
],
...
...
tests/distributed/test_eplb_execute.py
View file @
1eb61ab3
...
...
@@ -311,7 +311,7 @@ def _test_async_transfer_layer_without_mtp_worker(
is_unchanged
=
is_unchanged
,
is_received_locally
=
is_received_locally
,
recv_metadata
=
recv_metadata
,
new_indices
=
new_indices_cpu
[
layer_idx
],
new_indices
=
new_indices_cpu
[
layer_idx
]
.
numpy
()
,
ep_rank
=
ep_rank
,
)
...
...
vllm/distributed/eplb/policy/default.py
View file @
1eb61ab3
...
...
@@ -21,8 +21,8 @@ from .abstract import AbstractEplbPolicy
class
DefaultEplbPolicy
(
AbstractEplbPolicy
):
@
classmethod
def
balanced_packing
(
cls
,
weight
:
torch
.
Tensor
,
num_packs
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
cls
,
weight
:
np
.
ndarray
,
num_packs
:
int
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
"""
Pack n weighted objects to m packs, such that each bin contains exactly
n/m objects and the weights of all packs are as balanced as possible.
...
...
@@ -39,50 +39,43 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
assert
num_groups
%
num_packs
==
0
groups_per_pack
=
num_groups
//
num_packs
device
=
weight
.
device
if
groups_per_pack
==
1
:
pack_index
=
torch
.
arange
(
weight
.
size
(
-
1
),
dtype
=
torch
.
int64
,
device
=
device
).
expand
(
weight
.
shape
)
rank_in_pack
=
torch
.
zeros_like
(
weight
,
dtype
=
torch
.
int64
,
device
=
device
)
pack_index
=
np
.
tile
(
np
.
arange
(
num_groups
,
dtype
=
np
.
int64
),
(
num_layers
,
1
))
rank_in_pack
=
np
.
zeros_like
(
pack_index
,
dtype
=
np
.
int64
)
return
pack_index
,
rank_in_pack
weight_np
=
weight
.
cpu
().
numpy
()
# Sort and get indices in decending order
indices_np
=
np
.
argsort
(
-
weight_np
,
axis
=-
1
)
indices
=
np
.
argsort
(
-
weight
,
axis
=-
1
)
pack_index
=
np
.
full
((
num_layers
,
num_groups
),
-
1
,
dtype
=
np
.
int64
)
rank_in_pack
=
np
.
full
((
num_layers
,
num_groups
),
-
1
,
dtype
=
np
.
int64
)
pack_
index_np
=
np
.
full
((
num_layers
,
num_
groups
),
-
1
,
dtype
=
np
.
in
t64
)
rank_in_pack_np
=
np
.
full
((
num_layers
,
num_
groups
),
-
1
,
dtype
=
np
.
int64
)
pack_
weights
=
np
.
zeros
((
num_layers
,
num_
packs
)
,
dtype
=
np
.
floa
t64
)
pack_items
=
np
.
zeros
((
num_layers
,
num_
packs
)
,
dtype
=
np
.
int64
)
# Run the packing algorithm
for
i
in
range
(
num_layers
):
pack_weights
=
[
0.0
]
*
num_packs
pack_items
=
[
0
]
*
num_packs
for
group
in
indices_np
[
i
]:
# Find a pack with capacity that has the lowest weight
pack
=
min
(
(
j
for
j
in
range
(
num_packs
)
if
pack_items
[
j
]
<
groups_per_pack
),
key
=
pack_weights
.
__getitem__
,
)
for
layer_idx
in
range
(
num_layers
):
weights_row
=
pack_weights
[
layer_idx
]
items_row
=
pack_items
[
layer_idx
]
assert
pack_items
[
pack
]
<
groups_per_pack
pack_index_np
[
i
,
group
]
=
pack
rank_in_pack_np
[
i
,
group
]
=
pack_items
[
pack
]
pack_weights
[
pack
]
+=
weight_np
[
i
,
group
]
pack_items
[
pack
]
+=
1
for
group
in
indices
[
layer_idx
]:
# Pick the lightest pack; full packs are masked out by inf.
pack
=
int
(
np
.
argmin
(
weights_row
))
pack_index
=
torch
.
from_numpy
(
pack_index_np
).
to
(
device
)
rank_in_pack
=
torch
.
from_numpy
(
rank_in_pack_np
).
to
(
device
)
pack_index
[
layer_idx
,
group
]
=
pack
rank_in_pack
[
layer_idx
,
group
]
=
items_row
[
pack
]
weights_row
[
pack
]
+=
weight
[
layer_idx
,
group
]
items_row
[
pack
]
+=
1
if
items_row
[
pack
]
==
groups_per_pack
:
# Mark as unavailable for future selections.
weights_row
[
pack
]
=
np
.
inf
return
pack_index
,
rank_in_pack
@
classmethod
def
replicate_experts
(
cls
,
weight
:
torch
.
Tensor
,
num_phy
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
cls
,
weight
:
np
.
ndarray
,
num_phy
:
int
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized.
...
...
@@ -99,13 +92,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
n
,
num_log
=
weight
.
shape
num_redundant
=
num_phy
-
num_log
assert
num_redundant
>=
0
device
=
weight
.
device
phy2log
=
torch
.
arange
(
num_phy
,
dtype
=
torch
.
int64
,
device
=
device
).
repeat
(
n
,
1
)
replica_idx
=
torch
.
zeros
(
n
,
num_phy
,
dtype
=
torch
.
int64
,
device
=
device
)
logcnt
=
torch
.
ones
(
n
,
num_log
,
dtype
=
torch
.
int64
,
device
=
device
)
arangen
=
torch
.
arange
(
n
,
dtype
=
torch
.
int64
,
device
=
device
)
phy2log
=
np
.
tile
(
np
.
arange
(
num_phy
,
dtype
=
np
.
int64
),
(
n
,
1
))
replica_idx
=
np
.
zeros
((
n
,
num_phy
),
dtype
=
np
.
int64
)
logcnt
=
np
.
ones
((
n
,
num_log
),
dtype
=
np
.
int64
)
arangen
=
np
.
arange
(
n
,
dtype
=
np
.
int64
)
for
i
in
range
(
num_log
,
num_phy
):
redundant_indices
=
(
weight
/
logcnt
).
max
(
dim
=-
1
).
indices
redundant_indices
=
np
.
argmax
(
weight
/
logcnt
,
axis
=-
1
)
phy2log
[:,
i
]
=
redundant_indices
replica_idx
[:,
i
]
=
logcnt
[
arangen
,
redundant_indices
]
logcnt
[
arangen
,
redundant_indices
]
+=
1
...
...
@@ -114,12 +106,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
@
classmethod
def
rebalance_experts_hierarchical
(
cls
,
weight
:
torch
.
Tensor
,
weight
:
np
.
ndarray
,
num_physical_experts
:
int
,
num_groups
:
int
,
num_nodes
:
int
,
num_gpus
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
...
...
@@ -146,35 +138,33 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
assert
num_physical_experts
%
num_gpus
==
0
phy_experts_per_gpu
=
num_physical_experts
//
num_gpus
def
inverse
(
perm
:
torch
.
Tensor
)
->
torch
.
Tensor
:
inv
=
torch
.
empty_like
(
perm
)
inv
.
scatter_
(
1
,
perm
,
torch
.
arange
(
perm
.
size
(
1
),
dtype
=
torch
.
int64
,
device
=
perm
.
device
).
expand
(
perm
.
shape
),
)
def
inverse
(
perm
:
np
.
ndarray
)
->
np
.
ndarray
:
inv
=
np
.
empty_like
(
perm
)
row_idx
=
np
.
arange
(
perm
.
shape
[
0
])[:,
None
]
col_idx
=
np
.
arange
(
perm
.
shape
[
1
],
dtype
=
np
.
int64
)
inv
[
row_idx
,
perm
]
=
col_idx
return
inv
# Step 1: pack groups to nodes
tokens_per_group
=
weight
.
unflatten
(
-
1
,
(
num_groups
,
group_size
)).
sum
(
-
1
)
tokens_per_group
=
weight
.
reshape
(
num_layers
,
num_groups
,
group_size
).
sum
(
axis
=-
1
)
group_pack_index
,
group_rank_in_pack
=
cls
.
balanced_packing
(
tokens_per_group
,
num_nodes
)
# Map each logical expert into a node-local ordering based on packed groups.
log2mlog
=
(
(
(
group_pack_index
*
groups_per_node
+
group_rank_in_pack
)
*
group_size
).
unsqueeze
(
-
1
)
+
torch
.
arange
(
group_size
,
dtype
=
torch
.
int64
,
device
=
group_pack_index
.
device
(
group_pack_index
*
groups_per_node
+
group_rank_in_pack
)[...,
None
]
*
group_size
)
).
flatten
(
-
2
)
+
np
.
arange
(
group_size
,
dtype
=
np
.
int64
)
).
reshape
(
num_layers
,
num_logical_experts
)
mlog2log
=
inverse
(
log2mlog
)
# Step 2: construct redundant experts within nodes
#
[num_layers * num_nodes, num_logical_experts // num_
node
s]
tokens_per_mlog
=
weight
.
gather
(
-
1
,
mlog2log
).
view
(
#
Reorder weights into the node-local layout so replication is done per
node
.
tokens_per_mlog
=
np
.
take_along_axis
(
weight
,
mlog2log
,
axis
=
1
).
reshape
(
-
1
,
num_logical_experts
//
num_nodes
)
phy2mlog
,
replicas_idx
,
mlogcnt
=
cls
.
replicate_experts
(
...
...
@@ -182,39 +172,43 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
)
# Step 3: pack physical_experts to GPUs
#
[num_layers * num_nodes, num_physical_experts // num_nodes]
tokens_per_phy
=
(
tokens_per_mlog
/
mlogcnt
).
gather
(
-
1
,
phy2mlog
)
#
Effective per-physical load = logical load divided by replica count.
tokens_per_phy
=
np
.
take_along_axis
(
tokens_per_mlog
/
mlogcnt
,
phy2mlog
,
axis
=
1
)
pack_index
,
rank_in_pack
=
cls
.
balanced_packing
(
tokens_per_phy
,
num_gpus
//
num_nodes
)
phy2pphy
=
pack_index
*
phy_experts_per_gpu
+
rank_in_pack
pphy2phy
=
inverse
(
phy2pphy
)
pphy2mlog
=
phy2mlog
.
gather
(
-
1
,
pphy2phy
)
# [num_layers * num_nodes, num_log_per_nodes]
# Reorder node-local logical indices into the post-packing physical order.
pphy2mlog
=
np
.
take_along_axis
(
phy2mlog
,
pphy2phy
,
axis
=
1
)
pphy2mlog
=
(
pphy2mlog
.
view
(
num_layers
,
num_nodes
,
-
1
)
+
torch
.
arange
(
pphy2mlog
.
reshape
(
num_layers
,
num_nodes
,
-
1
)
+
np
.
arange
(
0
,
num_logical_experts
,
num_logical_experts
//
num_nodes
,
device
=
group_pack_index
.
device
,
).
view
(
1
,
-
1
,
1
)
).
flatten
(
-
2
)
pphy2log
=
mlog2log
.
gather
(
-
1
,
pphy2mlog
)
pphy_replicas_idx
=
replicas_idx
.
gather
(
-
1
,
pphy2phy
).
view
(
num_layers
,
-
1
)
logcnt
=
mlogcnt
.
view
(
num_layers
,
-
1
).
gather
(
-
1
,
log2mlog
)
dtype
=
np
.
int64
,
)[
None
,
:,
None
]
).
reshape
(
num_layers
,
-
1
)
# Map node-local logical indices back to global logical expert ids.
pphy2log
=
np
.
take_along_axis
(
mlog2log
,
pphy2mlog
,
axis
=
1
)
# Reorder replica ranks to the post-packing physical ordering.
pphy_replicas_idx
=
np
.
take_along_axis
(
replicas_idx
,
pphy2phy
,
axis
=
1
).
reshape
(
num_layers
,
-
1
)
# Convert replica counts back to the original logical ordering.
logcnt
=
np
.
take_along_axis
(
mlogcnt
.
reshape
(
num_layers
,
-
1
),
log2mlog
,
axis
=
1
)
return
pphy2log
,
pphy_replicas_idx
,
logcnt
@
classmethod
def
preserve_intragpu_slots
(
cls
,
phy2log
:
torch
.
Tensor
,
phy_replicas_idx
:
torch
.
Tensor
,
phy2log
:
np
.
ndarray
,
phy_replicas_idx
:
np
.
ndarray
,
num_ranks
:
int
,
old_
global_expert_indices
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
old_
phy2log
:
np
.
ndarray
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
"""
Reorder the new mapping per GPU so that experts that remain on the same GPU
keep their previous slot positions when possible. Incoming experts to that GPU
...
...
@@ -222,29 +216,24 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
is unchanged and the slots per GPU remain the same between
the old and new mappings.
"""
device
=
phy2log
.
device
num_phy_experts
=
phy2log
.
shape
[
1
]
if
num_ranks
<=
0
or
num_phy_experts
%
num_ranks
!=
0
:
return
phy2log
,
phy_replicas_idx
# Move to CPU and convert to NumPy for processing
new_phy2log_np
=
phy2log
.
cpu
().
numpy
()
replicas_idx_np
=
phy_replicas_idx
.
cpu
().
numpy
()
old_phy2log_np
=
old_global_expert_indices
.
cpu
().
numpy
()
slots_per_gpu
=
num_phy_experts
//
num_ranks
num_layers
=
new_
phy2log
_np
.
shape
[
0
]
num_layers
=
phy2log
.
shape
[
0
]
post_phy2log
_np
=
new_
phy2log
_np
.
copy
()
post_phy_replicas_idx
_np
=
replicas_idx
_np
.
copy
()
post_phy2log
=
phy2log
.
copy
()
post_phy_replicas_idx
=
phy_
replicas_idx
.
copy
()
for
gpu_idx
in
range
(
num_ranks
):
start
=
gpu_idx
*
slots_per_gpu
end
=
start
+
slots_per_gpu
# Experts across all layers for this GPU
old_local
=
old_phy2log
_np
[:,
start
:
end
]
# [layers, slots]
new_local
=
new_
phy2log
_np
[:,
start
:
end
]
# [layers, slots]
new_ridx
=
replicas_idx
_np
[:,
start
:
end
]
# [layers, slots]
old_local
=
old_phy2log
[:,
start
:
end
]
# [layers, slots]
new_local
=
phy2log
[:,
start
:
end
]
# [layers, slots]
new_ridx
=
phy_
replicas_idx
[:,
start
:
end
]
# [layers, slots]
used_new_indices
=
np
.
zeros
((
num_layers
,
slots_per_gpu
),
dtype
=
bool
)
preserved_positions
=
np
.
zeros
((
num_layers
,
slots_per_gpu
),
dtype
=
bool
)
...
...
@@ -261,12 +250,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
first_idx
=
np
.
argmax
(
matches
,
axis
=
1
)
layer_indices
=
np
.
nonzero
(
has_any
)[
0
]
matched_new_positions
=
first_idx
[
layer_indices
]
post_phy2log_np
[
layer_indices
,
start
+
slot_idx
]
=
new_local
[
post_phy2log
[
layer_indices
,
start
+
slot_idx
]
=
new_local
[
layer_indices
,
matched_new_positions
]
post_phy_replicas_idx
[
layer_indices
,
start
+
slot_idx
]
=
new_ridx
[
layer_indices
,
matched_new_positions
]
post_phy_replicas_idx_np
[
layer_indices
,
start
+
slot_idx
]
=
(
new_ridx
[
layer_indices
,
matched_new_positions
]
)
used_new_indices
[
layer_indices
,
matched_new_positions
]
=
True
preserved_positions
[
layer_indices
,
slot_idx
]
=
True
...
...
@@ -295,16 +284,13 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
continue
src_pos
=
remaining_indices
[
layer_idx
,
:
k
]
dst_pos
=
fill_indices
[
layer_idx
,
:
k
]
post_phy2log
_np
[
layer_idx
,
start
+
dst_pos
]
=
new_local
[
post_phy2log
[
layer_idx
,
start
+
dst_pos
]
=
new_local
[
layer_idx
,
src_pos
]
post_phy_replicas_idx
_np
[
layer_idx
,
start
+
dst_pos
]
=
new_ridx
[
post_phy_replicas_idx
[
layer_idx
,
start
+
dst_pos
]
=
new_ridx
[
layer_idx
,
src_pos
]
# Convert back to torch and move to original device
post_phy2log
=
torch
.
from_numpy
(
post_phy2log_np
).
to
(
device
)
post_phy_replicas_idx
=
torch
.
from_numpy
(
post_phy_replicas_idx_np
).
to
(
device
)
return
post_phy2log
,
post_phy_replicas_idx
@
classmethod
...
...
@@ -340,40 +326,51 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
device
=
weight
.
device
num_layers
,
num_logical_experts
=
weight
.
shape
weight
=
weight
.
float
()
weight_np
=
weight
.
float
().
cpu
().
numpy
()
old_phy2log_np
=
(
old_global_expert_indices
.
cpu
().
numpy
()
if
old_global_expert_indices
is
not
None
else
None
)
if
num_groups
%
num_nodes
==
0
:
# use hierarchical load-balance policy
phy2log
,
phy_replicas_idx
,
logcnt
=
cls
.
rebalance_experts_hierarchical
(
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_ranks
phy2log_np
,
phy_replicas_idx_np
,
logcnt_np
=
(
cls
.
rebalance_experts_hierarchical
(
weight_np
,
num_replicas
,
num_groups
,
num_nodes
,
num_ranks
)
)
else
:
# use global load-balance policy
phy2log
,
phy_replicas_idx
,
logcnt
=
cls
.
rebalance_experts_hierarchical
(
weight
,
num_replicas
,
1
,
1
,
num_ranks
phy2log_np
,
phy_replicas_idx_np
,
logcnt_np
=
(
cls
.
rebalance_experts_hierarchical
(
weight_np
,
num_replicas
,
1
,
1
,
num_ranks
)
)
# Optional postprocessing to preserve slots for experts moving
# within the same GPU
# Only apply when the number of GPUs and slots per GPU remain unchanged.
# Helps to avoid unnecessary weight copying when experts move
# within the same GPU.
if
old_global_expert_indices
is
not
None
:
phy2log
,
phy_replicas_idx
=
cls
.
preserve_intragpu_slots
(
phy2log
,
phy_replicas_idx
,
num_ranks
,
old_
global_expert_indices
phy2log
_np
,
phy_replicas_idx
_np
=
cls
.
preserve_intragpu_slots
(
phy2log
_np
,
phy_replicas_idx
_np
,
num_ranks
,
old_
phy2log_np
)
num_redundant_experts
=
num_replicas
-
num_logical_experts
maxlogcnt
=
num_redundant_experts
+
1
log2phy
:
torch
.
Tensor
=
torch
.
full
(
(
num_layers
,
num_logical_experts
,
maxlogcnt
),
-
1
,
dtype
=
torch
.
int64
,
device
=
logcnt
.
device
,
log2phy_np
=
np
.
full
(
(
num_layers
,
num_logical_experts
,
maxlogcnt
),
-
1
,
dtype
=
np
.
int64
)
log2phy
.
view
(
num_layers
,
-
1
).
scatter_
(
-
1
,
phy2log
*
maxlogcnt
+
phy_replicas_idx
,
torch
.
arange
(
num_replicas
,
dtype
=
torch
.
int64
,
device
=
log2phy
.
device
).
expand
(
num_layers
,
-
1
),
layer_indices
=
np
.
arange
(
num_layers
)[:,
None
]
replica_indices
=
np
.
tile
(
np
.
arange
(
num_replicas
,
dtype
=
np
.
int64
),
(
num_layers
,
1
)
)
log2phy_np
[
layer_indices
,
phy2log_np
,
phy_replicas_idx_np
]
=
replica_indices
phy2log
=
torch
.
from_numpy
(
phy2log_np
).
to
(
device
)
log2phy
=
torch
.
from_numpy
(
log2phy_np
).
to
(
device
)
logcnt
=
torch
.
from_numpy
(
logcnt_np
).
to
(
device
)
return
phy2log
,
log2phy
,
logcnt
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment