Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1eb61ab3
Unverified
Commit
1eb61ab3
authored
Jan 12, 2026
by
Ilya Markov
Committed by
GitHub
Jan 12, 2026
Browse files
[Refactor] EPLB rebalance algo to NumPy (#30697)
Signed-off-by:
ilmarkov
<
markovilya197@gmail.com
>
parent
3d962d72
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
128 additions
and
130 deletions
+128
-130
tests/distributed/test_eplb_algo.py
tests/distributed/test_eplb_algo.py
+16
-15
tests/distributed/test_eplb_execute.py
tests/distributed/test_eplb_execute.py
+1
-1
vllm/distributed/eplb/policy/default.py
vllm/distributed/eplb/policy/default.py
+111
-114
No files found.
tests/distributed/test_eplb_algo.py
View file @
1eb61ab3
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
import
pytest
import
pytest
import
torch
import
torch
...
@@ -312,9 +313,9 @@ if __name__ == "__main__":
...
@@ -312,9 +313,9 @@ if __name__ == "__main__":
test_basic_rebalance
()
test_basic_rebalance
()
def
_make_phy_replicas_idx_from_phy2log
(
phy2log
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_make_phy_replicas_idx_from_phy2log
(
phy2log
:
np
.
ndarray
)
->
np
.
ndarray
:
"""Create replicas indices mapping from phy2log"""
"""Create replicas indices mapping from phy2log
.
"""
pr
=
torch
.
zeros_like
(
phy2log
)
pr
=
np
.
zeros_like
(
phy2log
,
dtype
=
np
.
int64
)
for
layer
in
range
(
phy2log
.
shape
[
0
]):
for
layer
in
range
(
phy2log
.
shape
[
0
]):
seen
:
dict
[
int
,
int
]
=
{}
seen
:
dict
[
int
,
int
]
=
{}
row
=
phy2log
[
layer
].
tolist
()
row
=
phy2log
[
layer
].
tolist
()
...
@@ -326,11 +327,11 @@ def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor:
...
@@ -326,11 +327,11 @@ def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor:
def
_validate_intragpu_rearrangement
(
def
_validate_intragpu_rearrangement
(
old_global_expert_indices
:
torch
.
Tensor
,
old_global_expert_indices
:
np
.
ndarray
,
new_phy2log
:
torch
.
Tensor
,
new_phy2log
:
np
.
ndarray
,
new_phy_replicas_idx
:
torch
.
Tensor
,
new_phy_replicas_idx
:
np
.
ndarray
,
post_phy2log
:
torch
.
Tensor
,
post_phy2log
:
np
.
ndarray
,
post_phy_replicas_idx
:
torch
.
Tensor
,
post_phy_replicas_idx
:
np
.
ndarray
,
num_ranks
:
int
,
num_ranks
:
int
,
slots_per_gpu
:
int
,
slots_per_gpu
:
int
,
):
):
...
@@ -345,7 +346,7 @@ def _validate_intragpu_rearrangement(
...
@@ -345,7 +346,7 @@ def _validate_intragpu_rearrangement(
post_rnk
=
post_phy_replicas_idx
[
0
,
start
:
end
]
post_rnk
=
post_phy_replicas_idx
[
0
,
start
:
end
]
# Pairwise equality for (expert, rank) pairs to ensure nothing is lost
# Pairwise equality for (expert, rank) pairs to ensure nothing is lost
def
sorted_pairs
(
seg
:
torch
.
Tensor
,
rnk
:
torch
.
Tensor
):
def
sorted_pairs
(
seg
,
rnk
):
pairs
=
list
(
zip
(
seg
.
tolist
(),
rnk
.
tolist
()))
pairs
=
list
(
zip
(
seg
.
tolist
(),
rnk
.
tolist
()))
pairs
.
sort
()
pairs
.
sort
()
return
pairs
return
pairs
...
@@ -386,8 +387,8 @@ def _validate_intragpu_rearrangement(
...
@@ -386,8 +387,8 @@ def _validate_intragpu_rearrangement(
# GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3]
# GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3]
2
,
2
,
4
,
4
,
torch
.
tensor
([[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]]),
np
.
array
([[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]]),
torch
.
tensor
([[
1
,
5
,
0
,
4
,
6
,
2
,
7
,
3
]]),
np
.
array
([[
1
,
5
,
0
,
4
,
6
,
2
,
7
,
3
]]),
id
=
"simple"
,
id
=
"simple"
,
),
),
pytest
.
param
(
pytest
.
param
(
...
@@ -401,8 +402,8 @@ def _validate_intragpu_rearrangement(
...
@@ -401,8 +402,8 @@ def _validate_intragpu_rearrangement(
# GPU1 new -> [6, 2, 3, 2, 1] (expert 2 duplicated)
# GPU1 new -> [6, 2, 3, 2, 1] (expert 2 duplicated)
2
,
2
,
5
,
5
,
torch
.
tensor
([[
0
,
1
,
0
,
2
,
3
,
4
,
5
,
6
,
1
,
2
]]),
np
.
array
([[
0
,
1
,
0
,
2
,
3
,
4
,
5
,
6
,
1
,
2
]]),
torch
.
tensor
([[
0
,
5
,
4
,
0
,
1
,
6
,
2
,
3
,
2
,
1
]]),
np
.
array
([[
0
,
5
,
4
,
0
,
1
,
6
,
2
,
3
,
2
,
1
]]),
id
=
"duplicates"
,
id
=
"duplicates"
,
),
),
pytest
.
param
(
pytest
.
param
(
...
@@ -418,8 +419,8 @@ def _validate_intragpu_rearrangement(
...
@@ -418,8 +419,8 @@ def _validate_intragpu_rearrangement(
# GPU2 new -> [1, 2, 3, 0]
# GPU2 new -> [1, 2, 3, 0]
3
,
3
,
4
,
4
,
torch
.
tensor
([[
0
,
1
,
2
,
3
,
0
,
1
,
2
,
3
,
0
,
1
,
2
,
3
]]),
np
.
array
([[
0
,
1
,
2
,
3
,
0
,
1
,
2
,
3
,
0
,
1
,
2
,
3
]]),
torch
.
tensor
([[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
3
,
0
]]),
np
.
array
([[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
3
,
0
]]),
id
=
"skewed_expert"
,
id
=
"skewed_expert"
,
),
),
],
],
...
...
tests/distributed/test_eplb_execute.py
View file @
1eb61ab3
...
@@ -311,7 +311,7 @@ def _test_async_transfer_layer_without_mtp_worker(
...
@@ -311,7 +311,7 @@ def _test_async_transfer_layer_without_mtp_worker(
is_unchanged
=
is_unchanged
,
is_unchanged
=
is_unchanged
,
is_received_locally
=
is_received_locally
,
is_received_locally
=
is_received_locally
,
recv_metadata
=
recv_metadata
,
recv_metadata
=
recv_metadata
,
new_indices
=
new_indices_cpu
[
layer_idx
],
new_indices
=
new_indices_cpu
[
layer_idx
]
.
numpy
()
,
ep_rank
=
ep_rank
,
ep_rank
=
ep_rank
,
)
)
...
...
vllm/distributed/eplb/policy/default.py
View file @
1eb61ab3
...
@@ -21,8 +21,8 @@ from .abstract import AbstractEplbPolicy
...
@@ -21,8 +21,8 @@ from .abstract import AbstractEplbPolicy
class
DefaultEplbPolicy
(
AbstractEplbPolicy
):
class
DefaultEplbPolicy
(
AbstractEplbPolicy
):
@
classmethod
@
classmethod
def
balanced_packing
(
def
balanced_packing
(
cls
,
weight
:
torch
.
Tensor
,
num_packs
:
int
cls
,
weight
:
np
.
ndarray
,
num_packs
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
"""
"""
Pack n weighted objects to m packs, such that each bin contains exactly
Pack n weighted objects to m packs, such that each bin contains exactly
n/m objects and the weights of all packs are as balanced as possible.
n/m objects and the weights of all packs are as balanced as possible.
...
@@ -39,50 +39,43 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
...
@@ -39,50 +39,43 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
assert
num_groups
%
num_packs
==
0
assert
num_groups
%
num_packs
==
0
groups_per_pack
=
num_groups
//
num_packs
groups_per_pack
=
num_groups
//
num_packs
device
=
weight
.
device
if
groups_per_pack
==
1
:
if
groups_per_pack
==
1
:
pack_index
=
torch
.
arange
(
pack_index
=
np
.
tile
(
np
.
arange
(
num_groups
,
dtype
=
np
.
int64
),
(
num_layers
,
1
))
weight
.
size
(
-
1
),
dtype
=
torch
.
int64
,
device
=
device
rank_in_pack
=
np
.
zeros_like
(
pack_index
,
dtype
=
np
.
int64
)
).
expand
(
weight
.
shape
)
rank_in_pack
=
torch
.
zeros_like
(
weight
,
dtype
=
torch
.
int64
,
device
=
device
)
return
pack_index
,
rank_in_pack
return
pack_index
,
rank_in_pack
weight_np
=
weight
.
cpu
().
numpy
()
# Sort and get indices in decending order
# Sort and get indices in decending order
indices_np
=
np
.
argsort
(
-
weight_np
,
axis
=-
1
)
indices
=
np
.
argsort
(
-
weight
,
axis
=-
1
)
pack_index
=
np
.
full
((
num_layers
,
num_groups
),
-
1
,
dtype
=
np
.
int64
)
rank_in_pack
=
np
.
full
((
num_layers
,
num_groups
),
-
1
,
dtype
=
np
.
int64
)
pack_
index_np
=
np
.
full
((
num_layers
,
num_
groups
),
-
1
,
dtype
=
np
.
in
t64
)
pack_
weights
=
np
.
zeros
((
num_layers
,
num_
packs
)
,
dtype
=
np
.
floa
t64
)
rank_in_pack_np
=
np
.
full
((
num_layers
,
num_
groups
),
-
1
,
dtype
=
np
.
int64
)
pack_items
=
np
.
zeros
((
num_layers
,
num_
packs
)
,
dtype
=
np
.
int64
)
# Run the packing algorithm
# Run the packing algorithm
for
i
in
range
(
num_layers
):
for
layer_idx
in
range
(
num_layers
):
pack_weights
=
[
0.0
]
*
num_packs
weights_row
=
pack_weights
[
layer_idx
]
pack_items
=
[
0
]
*
num_packs
items_row
=
pack_items
[
layer_idx
]
for
group
in
indices_np
[
i
]:
# Find a pack with capacity that has the lowest weight
pack
=
min
(
(
j
for
j
in
range
(
num_packs
)
if
pack_items
[
j
]
<
groups_per_pack
),
key
=
pack_weights
.
__getitem__
,
)
assert
pack_items
[
pack
]
<
groups_per_pack
for
group
in
indices
[
layer_idx
]:
pack_index_np
[
i
,
group
]
=
pack
# Pick the lightest pack; full packs are masked out by inf.
rank_in_pack_np
[
i
,
group
]
=
pack_items
[
pack
]
pack
=
int
(
np
.
argmin
(
weights_row
))
pack_weights
[
pack
]
+=
weight_np
[
i
,
group
]
pack_items
[
pack
]
+=
1
pack_index
=
torch
.
from_numpy
(
pack_index_np
).
to
(
device
)
pack_index
[
layer_idx
,
group
]
=
pack
rank_in_pack
=
torch
.
from_numpy
(
rank_in_pack_np
).
to
(
device
)
rank_in_pack
[
layer_idx
,
group
]
=
items_row
[
pack
]
weights_row
[
pack
]
+=
weight
[
layer_idx
,
group
]
items_row
[
pack
]
+=
1
if
items_row
[
pack
]
==
groups_per_pack
:
# Mark as unavailable for future selections.
weights_row
[
pack
]
=
np
.
inf
return
pack_index
,
rank_in_pack
return
pack_index
,
rank_in_pack
@
classmethod
@
classmethod
def
replicate_experts
(
def
replicate_experts
(
cls
,
weight
:
torch
.
Tensor
,
num_phy
:
int
cls
,
weight
:
np
.
ndarray
,
num_phy
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
"""
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized.
load of all replicas is minimized.
...
@@ -99,13 +92,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
...
@@ -99,13 +92,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
n
,
num_log
=
weight
.
shape
n
,
num_log
=
weight
.
shape
num_redundant
=
num_phy
-
num_log
num_redundant
=
num_phy
-
num_log
assert
num_redundant
>=
0
assert
num_redundant
>=
0
device
=
weight
.
device
phy2log
=
np
.
tile
(
np
.
arange
(
num_phy
,
dtype
=
np
.
int64
),
(
n
,
1
))
phy2log
=
torch
.
arange
(
num_phy
,
dtype
=
torch
.
int64
,
device
=
device
).
repeat
(
n
,
1
)
replica_idx
=
np
.
zeros
((
n
,
num_phy
),
dtype
=
np
.
int64
)
replica_idx
=
torch
.
zeros
(
n
,
num_phy
,
dtype
=
torch
.
int64
,
device
=
device
)
logcnt
=
np
.
ones
((
n
,
num_log
),
dtype
=
np
.
int64
)
logcnt
=
torch
.
ones
(
n
,
num_log
,
dtype
=
torch
.
int64
,
device
=
device
)
arangen
=
np
.
arange
(
n
,
dtype
=
np
.
int64
)
arangen
=
torch
.
arange
(
n
,
dtype
=
torch
.
int64
,
device
=
device
)
for
i
in
range
(
num_log
,
num_phy
):
for
i
in
range
(
num_log
,
num_phy
):
redundant_indices
=
(
weight
/
logcnt
).
max
(
dim
=-
1
).
indices
redundant_indices
=
np
.
argmax
(
weight
/
logcnt
,
axis
=-
1
)
phy2log
[:,
i
]
=
redundant_indices
phy2log
[:,
i
]
=
redundant_indices
replica_idx
[:,
i
]
=
logcnt
[
arangen
,
redundant_indices
]
replica_idx
[:,
i
]
=
logcnt
[
arangen
,
redundant_indices
]
logcnt
[
arangen
,
redundant_indices
]
+=
1
logcnt
[
arangen
,
redundant_indices
]
+=
1
...
@@ -114,12 +106,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
...
@@ -114,12 +106,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
@
classmethod
@
classmethod
def
rebalance_experts_hierarchical
(
def
rebalance_experts_hierarchical
(
cls
,
cls
,
weight
:
torch
.
Tensor
,
weight
:
np
.
ndarray
,
num_physical_experts
:
int
,
num_physical_experts
:
int
,
num_groups
:
int
,
num_groups
:
int
,
num_nodes
:
int
,
num_nodes
:
int
,
num_gpus
:
int
,
num_gpus
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
"""
"""
Parameters:
Parameters:
weight: [num_moe_layers, num_logical_experts]
weight: [num_moe_layers, num_logical_experts]
...
@@ -146,35 +138,33 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
...
@@ -146,35 +138,33 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
assert
num_physical_experts
%
num_gpus
==
0
assert
num_physical_experts
%
num_gpus
==
0
phy_experts_per_gpu
=
num_physical_experts
//
num_gpus
phy_experts_per_gpu
=
num_physical_experts
//
num_gpus
def
inverse
(
perm
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
inverse
(
perm
:
np
.
ndarray
)
->
np
.
ndarray
:
inv
=
torch
.
empty_like
(
perm
)
inv
=
np
.
empty_like
(
perm
)
inv
.
scatter_
(
row_idx
=
np
.
arange
(
perm
.
shape
[
0
])[:,
None
]
1
,
col_idx
=
np
.
arange
(
perm
.
shape
[
1
],
dtype
=
np
.
int64
)
perm
,
inv
[
row_idx
,
perm
]
=
col_idx
torch
.
arange
(
perm
.
size
(
1
),
dtype
=
torch
.
int64
,
device
=
perm
.
device
).
expand
(
perm
.
shape
),
)
return
inv
return
inv
# Step 1: pack groups to nodes
# Step 1: pack groups to nodes
tokens_per_group
=
weight
.
unflatten
(
-
1
,
(
num_groups
,
group_size
)).
sum
(
-
1
)
tokens_per_group
=
weight
.
reshape
(
num_layers
,
num_groups
,
group_size
).
sum
(
axis
=-
1
)
group_pack_index
,
group_rank_in_pack
=
cls
.
balanced_packing
(
group_pack_index
,
group_rank_in_pack
=
cls
.
balanced_packing
(
tokens_per_group
,
num_nodes
tokens_per_group
,
num_nodes
)
)
# Map each logical expert into a node-local ordering based on packed groups.
log2mlog
=
(
log2mlog
=
(
(
(
(
group_pack_index
*
groups_per_node
+
group_rank_in_pack
)
*
group_size
(
group_pack_index
*
groups_per_node
+
group_rank_in_pack
)[...,
None
]
).
unsqueeze
(
-
1
)
*
group_size
+
torch
.
arange
(
group_size
,
dtype
=
torch
.
int64
,
device
=
group_pack_index
.
device
)
)
).
flatten
(
-
2
)
+
np
.
arange
(
group_size
,
dtype
=
np
.
int64
)
).
reshape
(
num_layers
,
num_logical_experts
)
mlog2log
=
inverse
(
log2mlog
)
mlog2log
=
inverse
(
log2mlog
)
# Step 2: construct redundant experts within nodes
# Step 2: construct redundant experts within nodes
#
[num_layers * num_nodes, num_logical_experts // num_
node
s]
#
Reorder weights into the node-local layout so replication is done per
node
.
tokens_per_mlog
=
weight
.
gather
(
-
1
,
mlog2log
).
view
(
tokens_per_mlog
=
np
.
take_along_axis
(
weight
,
mlog2log
,
axis
=
1
).
reshape
(
-
1
,
num_logical_experts
//
num_nodes
-
1
,
num_logical_experts
//
num_nodes
)
)
phy2mlog
,
replicas_idx
,
mlogcnt
=
cls
.
replicate_experts
(
phy2mlog
,
replicas_idx
,
mlogcnt
=
cls
.
replicate_experts
(
...
@@ -182,39 +172,43 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
...
@@ -182,39 +172,43 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
)
)
# Step 3: pack physical_experts to GPUs
# Step 3: pack physical_experts to GPUs
#
[num_layers * num_nodes, num_physical_experts // num_nodes]
#
Effective per-physical load = logical load divided by replica count.
tokens_per_phy
=
(
tokens_per_mlog
/
mlogcnt
).
gather
(
-
1
,
phy2mlog
)
tokens_per_phy
=
np
.
take_along_axis
(
tokens_per_mlog
/
mlogcnt
,
phy2mlog
,
axis
=
1
)
pack_index
,
rank_in_pack
=
cls
.
balanced_packing
(
pack_index
,
rank_in_pack
=
cls
.
balanced_packing
(
tokens_per_phy
,
num_gpus
//
num_nodes
tokens_per_phy
,
num_gpus
//
num_nodes
)
)
phy2pphy
=
pack_index
*
phy_experts_per_gpu
+
rank_in_pack
phy2pphy
=
pack_index
*
phy_experts_per_gpu
+
rank_in_pack
pphy2phy
=
inverse
(
phy2pphy
)
pphy2phy
=
inverse
(
phy2pphy
)
pphy2mlog
=
phy2mlog
.
gather
(
# Reorder node-local logical indices into the post-packing physical order.
-
1
,
pphy2phy
pphy2mlog
=
np
.
take_along_axis
(
phy2mlog
,
pphy2phy
,
axis
=
1
)
)
# [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog
=
(
pphy2mlog
=
(
pphy2mlog
.
view
(
num_layers
,
num_nodes
,
-
1
)
pphy2mlog
.
reshape
(
num_layers
,
num_nodes
,
-
1
)
+
torch
.
arange
(
+
np
.
arange
(
0
,
0
,
num_logical_experts
,
num_logical_experts
,
num_logical_experts
//
num_nodes
,
num_logical_experts
//
num_nodes
,
device
=
group_pack_index
.
device
,
dtype
=
np
.
int64
,
).
view
(
1
,
-
1
,
1
)
)[
None
,
:,
None
]
).
flatten
(
-
2
)
).
reshape
(
num_layers
,
-
1
)
pphy2log
=
mlog2log
.
gather
(
-
1
,
pphy2mlog
)
# Map node-local logical indices back to global logical expert ids.
pphy_replicas_idx
=
replicas_idx
.
gather
(
-
1
,
pphy2phy
).
view
(
num_layers
,
-
1
)
pphy2log
=
np
.
take_along_axis
(
mlog2log
,
pphy2mlog
,
axis
=
1
)
logcnt
=
mlogcnt
.
view
(
num_layers
,
-
1
).
gather
(
-
1
,
log2mlog
)
# Reorder replica ranks to the post-packing physical ordering.
pphy_replicas_idx
=
np
.
take_along_axis
(
replicas_idx
,
pphy2phy
,
axis
=
1
).
reshape
(
num_layers
,
-
1
)
# Convert replica counts back to the original logical ordering.
logcnt
=
np
.
take_along_axis
(
mlogcnt
.
reshape
(
num_layers
,
-
1
),
log2mlog
,
axis
=
1
)
return
pphy2log
,
pphy_replicas_idx
,
logcnt
return
pphy2log
,
pphy_replicas_idx
,
logcnt
@
classmethod
@
classmethod
def
preserve_intragpu_slots
(
def
preserve_intragpu_slots
(
cls
,
cls
,
phy2log
:
torch
.
Tensor
,
phy2log
:
np
.
ndarray
,
phy_replicas_idx
:
torch
.
Tensor
,
phy_replicas_idx
:
np
.
ndarray
,
num_ranks
:
int
,
num_ranks
:
int
,
old_
global_expert_indices
:
torch
.
Tensor
,
old_
phy2log
:
np
.
ndarray
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
"""
"""
Reorder the new mapping per GPU so that experts that remain on the same GPU
Reorder the new mapping per GPU so that experts that remain on the same GPU
keep their previous slot positions when possible. Incoming experts to that GPU
keep their previous slot positions when possible. Incoming experts to that GPU
...
@@ -222,29 +216,24 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
...
@@ -222,29 +216,24 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
is unchanged and the slots per GPU remain the same between
is unchanged and the slots per GPU remain the same between
the old and new mappings.
the old and new mappings.
"""
"""
device
=
phy2log
.
device
num_phy_experts
=
phy2log
.
shape
[
1
]
num_phy_experts
=
phy2log
.
shape
[
1
]
if
num_ranks
<=
0
or
num_phy_experts
%
num_ranks
!=
0
:
if
num_ranks
<=
0
or
num_phy_experts
%
num_ranks
!=
0
:
return
phy2log
,
phy_replicas_idx
return
phy2log
,
phy_replicas_idx
# Move to CPU and convert to NumPy for processing
# Move to CPU and convert to NumPy for processing
new_phy2log_np
=
phy2log
.
cpu
().
numpy
()
replicas_idx_np
=
phy_replicas_idx
.
cpu
().
numpy
()
old_phy2log_np
=
old_global_expert_indices
.
cpu
().
numpy
()
slots_per_gpu
=
num_phy_experts
//
num_ranks
slots_per_gpu
=
num_phy_experts
//
num_ranks
num_layers
=
new_
phy2log
_np
.
shape
[
0
]
num_layers
=
phy2log
.
shape
[
0
]
post_phy2log
_np
=
new_
phy2log
_np
.
copy
()
post_phy2log
=
phy2log
.
copy
()
post_phy_replicas_idx
_np
=
replicas_idx
_np
.
copy
()
post_phy_replicas_idx
=
phy_
replicas_idx
.
copy
()
for
gpu_idx
in
range
(
num_ranks
):
for
gpu_idx
in
range
(
num_ranks
):
start
=
gpu_idx
*
slots_per_gpu
start
=
gpu_idx
*
slots_per_gpu
end
=
start
+
slots_per_gpu
end
=
start
+
slots_per_gpu
# Experts across all layers for this GPU
# Experts across all layers for this GPU
old_local
=
old_phy2log
_np
[:,
start
:
end
]
# [layers, slots]
old_local
=
old_phy2log
[:,
start
:
end
]
# [layers, slots]
new_local
=
new_
phy2log
_np
[:,
start
:
end
]
# [layers, slots]
new_local
=
phy2log
[:,
start
:
end
]
# [layers, slots]
new_ridx
=
replicas_idx
_np
[:,
start
:
end
]
# [layers, slots]
new_ridx
=
phy_
replicas_idx
[:,
start
:
end
]
# [layers, slots]
used_new_indices
=
np
.
zeros
((
num_layers
,
slots_per_gpu
),
dtype
=
bool
)
used_new_indices
=
np
.
zeros
((
num_layers
,
slots_per_gpu
),
dtype
=
bool
)
preserved_positions
=
np
.
zeros
((
num_layers
,
slots_per_gpu
),
dtype
=
bool
)
preserved_positions
=
np
.
zeros
((
num_layers
,
slots_per_gpu
),
dtype
=
bool
)
...
@@ -261,12 +250,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
...
@@ -261,12 +250,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
first_idx
=
np
.
argmax
(
matches
,
axis
=
1
)
first_idx
=
np
.
argmax
(
matches
,
axis
=
1
)
layer_indices
=
np
.
nonzero
(
has_any
)[
0
]
layer_indices
=
np
.
nonzero
(
has_any
)[
0
]
matched_new_positions
=
first_idx
[
layer_indices
]
matched_new_positions
=
first_idx
[
layer_indices
]
post_phy2log_np
[
layer_indices
,
start
+
slot_idx
]
=
new_local
[
post_phy2log
[
layer_indices
,
start
+
slot_idx
]
=
new_local
[
layer_indices
,
matched_new_positions
]
post_phy_replicas_idx
[
layer_indices
,
start
+
slot_idx
]
=
new_ridx
[
layer_indices
,
matched_new_positions
layer_indices
,
matched_new_positions
]
]
post_phy_replicas_idx_np
[
layer_indices
,
start
+
slot_idx
]
=
(
new_ridx
[
layer_indices
,
matched_new_positions
]
)
used_new_indices
[
layer_indices
,
matched_new_positions
]
=
True
used_new_indices
[
layer_indices
,
matched_new_positions
]
=
True
preserved_positions
[
layer_indices
,
slot_idx
]
=
True
preserved_positions
[
layer_indices
,
slot_idx
]
=
True
...
@@ -295,16 +284,13 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
...
@@ -295,16 +284,13 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
continue
continue
src_pos
=
remaining_indices
[
layer_idx
,
:
k
]
src_pos
=
remaining_indices
[
layer_idx
,
:
k
]
dst_pos
=
fill_indices
[
layer_idx
,
:
k
]
dst_pos
=
fill_indices
[
layer_idx
,
:
k
]
post_phy2log
_np
[
layer_idx
,
start
+
dst_pos
]
=
new_local
[
post_phy2log
[
layer_idx
,
start
+
dst_pos
]
=
new_local
[
layer_idx
,
src_pos
layer_idx
,
src_pos
]
]
post_phy_replicas_idx
_np
[
layer_idx
,
start
+
dst_pos
]
=
new_ridx
[
post_phy_replicas_idx
[
layer_idx
,
start
+
dst_pos
]
=
new_ridx
[
layer_idx
,
src_pos
layer_idx
,
src_pos
]
]
# Convert back to torch and move to original device
post_phy2log
=
torch
.
from_numpy
(
post_phy2log_np
).
to
(
device
)
post_phy_replicas_idx
=
torch
.
from_numpy
(
post_phy_replicas_idx_np
).
to
(
device
)
return
post_phy2log
,
post_phy_replicas_idx
return
post_phy2log
,
post_phy_replicas_idx
@
classmethod
@
classmethod
...
@@ -340,40 +326,51 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
...
@@ -340,40 +326,51 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
logcnt: [layers, num_logical_experts], number of
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
physical replicas for each logical expert
"""
"""
device
=
weight
.
device
num_layers
,
num_logical_experts
=
weight
.
shape
num_layers
,
num_logical_experts
=
weight
.
shape
weight
=
weight
.
float
()
weight_np
=
weight
.
float
().
cpu
().
numpy
()
old_phy2log_np
=
(
old_global_expert_indices
.
cpu
().
numpy
()
if
old_global_expert_indices
is
not
None
else
None
)
if
num_groups
%
num_nodes
==
0
:
if
num_groups
%
num_nodes
==
0
:
# use hierarchical load-balance policy
# use hierarchical load-balance policy
phy2log
,
phy_replicas_idx
,
logcnt
=
cls
.
rebalance_experts_hierarchical
(
phy2log_np
,
phy_replicas_idx_np
,
logcnt_np
=
(
weight
,
num_replicas
,
num_groups
,
num_nodes
,
num_ranks
cls
.
rebalance_experts_hierarchical
(
weight_np
,
num_replicas
,
num_groups
,
num_nodes
,
num_ranks
)
)
)
else
:
else
:
# use global load-balance policy
# use global load-balance policy
phy2log
,
phy_replicas_idx
,
logcnt
=
cls
.
rebalance_experts_hierarchical
(
phy2log_np
,
phy_replicas_idx_np
,
logcnt_np
=
(
weight
,
num_replicas
,
1
,
1
,
num_ranks
cls
.
rebalance_experts_hierarchical
(
weight_np
,
num_replicas
,
1
,
1
,
num_ranks
)
)
)
# Optional postprocessing to preserve slots for experts moving
# Optional postprocessing to preserve slots for experts moving
# within the same GPU
# within the same GPU
# Only apply when the number of GPUs and slots per GPU remain unchanged.
# Only apply when the number of GPUs and slots per GPU remain unchanged.
# Helps to avoid unnecessary weight copying when experts move
# Helps to avoid unnecessary weight copying when experts move
# within the same GPU.
# within the same GPU.
if
old_global_expert_indices
is
not
None
:
if
old_global_expert_indices
is
not
None
:
phy2log
,
phy_replicas_idx
=
cls
.
preserve_intragpu_slots
(
phy2log
_np
,
phy_replicas_idx
_np
=
cls
.
preserve_intragpu_slots
(
phy2log
,
phy_replicas_idx
,
num_ranks
,
old_
global_expert_indices
phy2log
_np
,
phy_replicas_idx
_np
,
num_ranks
,
old_
phy2log_np
)
)
num_redundant_experts
=
num_replicas
-
num_logical_experts
num_redundant_experts
=
num_replicas
-
num_logical_experts
maxlogcnt
=
num_redundant_experts
+
1
maxlogcnt
=
num_redundant_experts
+
1
log2phy
:
torch
.
Tensor
=
torch
.
full
(
log2phy_np
=
np
.
full
(
(
num_layers
,
num_logical_experts
,
maxlogcnt
),
(
num_layers
,
num_logical_experts
,
maxlogcnt
),
-
1
,
dtype
=
np
.
int64
-
1
,
dtype
=
torch
.
int64
,
device
=
logcnt
.
device
,
)
)
log2phy
.
view
(
num_layers
,
-
1
).
scatter_
(
layer_indices
=
np
.
arange
(
num_layers
)[:,
None
]
-
1
,
replica_indices
=
np
.
tile
(
phy2log
*
maxlogcnt
+
phy_replicas_idx
,
np
.
arange
(
num_replicas
,
dtype
=
np
.
int64
),
(
num_layers
,
1
)
torch
.
arange
(
num_replicas
,
dtype
=
torch
.
int64
,
device
=
log2phy
.
device
).
expand
(
num_layers
,
-
1
),
)
)
log2phy_np
[
layer_indices
,
phy2log_np
,
phy_replicas_idx_np
]
=
replica_indices
phy2log
=
torch
.
from_numpy
(
phy2log_np
).
to
(
device
)
log2phy
=
torch
.
from_numpy
(
log2phy_np
).
to
(
device
)
logcnt
=
torch
.
from_numpy
(
logcnt_np
).
to
(
device
)
return
phy2log
,
log2phy
,
logcnt
return
phy2log
,
log2phy
,
logcnt
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment