Commit 3ab7d9b5 (unverified) · parent 7e5071c9
Authored May 29, 2025 by fzyzcjy, committed by GitHub on May 29, 2025

Support picking variants of EPLB algorithms (#6728)
Showing 5 changed files with 313 additions and 18 deletions (+313 −18).
- python/sglang/srt/managers/eplb_algorithms/__init__.py (+63, −0)
- python/sglang/srt/managers/eplb_algorithms/deepseek.py (+223, −0)
- python/sglang/srt/managers/eplb_algorithms/deepseek_vec.py (+9, −13)
- python/sglang/srt/managers/expert_location.py (+11, −5)
- python/sglang/srt/server_args.py (+7, −0)
python/sglang/srt/managers/eplb_algorithms/__init__.py (new file, mode 100644)
```python
from enum import Enum, auto
from typing import Optional

import torch

from sglang.srt.managers.eplb_algorithms import deepseek, deepseek_vec


class EplbAlgorithm(Enum):
    deepseek = auto()
    deepseek_hierarchical = auto()
    deepseek_vec = auto()
    deepseek_vec_hierarchical = auto()
    # TODO may have more algorithms later


def rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_groups: Optional[int],
    num_nodes: int,
    algorithm: EplbAlgorithm,
):
    if algorithm in [EplbAlgorithm.deepseek, EplbAlgorithm.deepseek_hierarchical]:
        return deepseek.rebalance_experts(
            weight=tokens_per_expert.sum(dim=0),
            num_replicas=num_physical_experts,
            num_groups=num_groups,
            num_nodes=num_nodes,
            num_gpus=num_physical_experts // num_local_physical_experts,
            enable_hierarchical=algorithm == EplbAlgorithm.deepseek_hierarchical,
        )

    if algorithm in [
        EplbAlgorithm.deepseek_vec,
        EplbAlgorithm.deepseek_vec_hierarchical,
    ]:
        return deepseek_vec.rebalance_experts(
            tokens_per_expert=tokens_per_expert,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
            num_groups=num_groups,
            num_nodes=num_nodes,
            enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical,
        )

    raise NotImplementedError


def compute_algorithm(
    raw_algorithm: str,
    num_groups: Optional[int],
    num_nodes: int,
) -> EplbAlgorithm:
    if raw_algorithm != "auto":
        return EplbAlgorithm[raw_algorithm]

    # TODO test on real scenarios and know which ones perform better
    if (num_groups is not None) and (num_groups % num_nodes == 0):
        return EplbAlgorithm.deepseek_hierarchical
    else:
        return EplbAlgorithm.deepseek
```
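For orientation, here is a minimal usage sketch of the new dispatch layer. It is not part of the commit, and all sizes and counts below are fabricated for illustration: `compute_algorithm` resolves the algorithm string, and `rebalance_experts` routes to the chosen backend.

```python
import torch

from sglang.srt.managers import eplb_algorithms

# Fabricated statistics: 10 recorded passes, 2 MoE layers, 8 logical experts.
tokens_per_expert = torch.randint(0, 100, (10, 2, 8))

# "auto" with 4 groups over 2 nodes resolves to the hierarchical DeepSeek variant.
algorithm = eplb_algorithms.compute_algorithm(
    raw_algorithm="auto", num_groups=4, num_nodes=2
)
assert algorithm == eplb_algorithms.EplbAlgorithm.deepseek_hierarchical

# 16 physical slots over 2 GPUs (8 local slots each); the deepseek backend
# collapses the pass dimension via tokens_per_expert.sum(dim=0).
phy2log, log2phy, logcnt = eplb_algorithms.rebalance_experts(
    tokens_per_expert=tokens_per_expert,
    num_physical_experts=16,
    num_local_physical_experts=8,
    num_groups=4,
    num_nodes=2,
    algorithm=algorithm,
)
```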
python/sglang/srt/managers/eplb_algorithms/deepseek.py (new file, mode 100644)
```python
# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package

from typing import Tuple

import torch

from sglang.srt.utils import get_bool_env_var


def balanced_packing(
    weight: torch.Tensor, num_packs: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pack n weighted objects to m packs, such that each pack contains exactly n/m
    objects and the weights of all packs are as balanced as possible.

    Parameters:
        weight: [X, n], the weight of each item
        num_packs: number of packs

    Returns:
        pack_index: [X, n], the pack index of each item
        rank_in_pack: [X, n], the rank of the item in the pack
    """
    num_layers, num_groups = weight.shape
    assert num_groups % num_packs == 0
    groups_per_pack = num_groups // num_packs

    if groups_per_pack == 1:
        pack_index = torch.arange(
            weight.size(-1), dtype=torch.int64, device=weight.device
        ).expand(weight.shape)
        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
        return pack_index, rank_in_pack

    indices = weight.float().sort(-1, descending=True).indices.cpu()
    pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu")
    rank_in_pack = torch.full_like(pack_index, fill_value=-1)
    for i in range(num_layers):
        pack_weights = [0] * num_packs
        pack_items = [0] * num_packs
        for group in indices[i]:
            pack = min(
                (i for i in range(num_packs) if pack_items[i] < groups_per_pack),
                key=pack_weights.__getitem__,
            )
            assert pack_items[pack] < groups_per_pack
            pack_index[i, group] = pack
            rank_in_pack[i, group] = pack_items[pack]
            pack_weights[pack] += weight[i, group]
            pack_items[pack] += 1
    return pack_index, rank_in_pack
```
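A tiny illustration of the greedy packing (my example, not part of the commit): items are visited in descending weight order, and each goes to the lightest pack that still has room.

```python
import torch

from sglang.srt.managers.eplb_algorithms.deepseek import balanced_packing

weight = torch.tensor([[9, 1, 5, 3]])  # one layer, four groups
pack_index, rank_in_pack = balanced_packing(weight, num_packs=2)
# Visit order 9, 5, 3, 1: 9 -> pack 0; 5 -> pack 1; 3 -> pack 1 (lighter);
# 1 -> pack 0 (pack 1 is full). Pack loads end up 10 vs 8.
print(pack_index)    # tensor([[0, 0, 1, 1]])
print(rank_in_pack)  # tensor([[0, 1, 0, 1]])
```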
```python
def replicate_experts(
    weight: torch.Tensor, num_phy: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Replicate `num_log` experts to `num_phy` replicas, such that the maximum load
    of all replicas is minimized.

    Parameters:
        weight: [X, num_log]
        num_phy: total number of experts after replication

    Returns:
        phy2log: [X, num_phy], logical expert id of each physical expert
        rank: [X, num_phy], the replica rank
        logcnt: [X, num_log], number of replicas for each logical expert
    """
    n, num_log = weight.shape
    num_redundant = num_phy - num_log
    assert num_redundant >= 0
    device = weight.device
    phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
    rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
    logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
    arangen = torch.arange(n, dtype=torch.int64, device=device)
    for i in range(num_log, num_phy):
        redundant_indices = (weight / logcnt).max(dim=-1).indices
        phy2log[:, i] = redundant_indices
        rank[:, i] = logcnt[arangen, redundant_indices]
        logcnt[arangen, redundant_indices] += 1
    return phy2log, rank, logcnt
```
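Another small illustration (mine, not from the commit): each redundant slot is granted to the expert with the highest load per existing replica, i.e. `argmax(weight / logcnt)`.

```python
import torch

from sglang.srt.managers.eplb_algorithms.deepseek import replicate_experts

weight = torch.tensor([[12.0, 5.0, 2.0]])  # one layer, three logical experts
phy2log, rank, logcnt = replicate_experts(weight, num_phy=5)
# Slot 3: argmax([12/1, 5/1, 2/1]) -> expert 0 (now 2 replicas).
# Slot 4: argmax([12/2, 5/1, 2/1]) -> expert 0 again (now 3 replicas).
print(phy2log)  # tensor([[0, 1, 2, 0, 0]])
print(rank)     # tensor([[0, 0, 0, 1, 2]])
print(logcnt)   # tensor([[3, 1, 1]])
```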
```python
def rebalance_experts_hierarchical(
    weight: torch.Tensor,
    num_physical_experts: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
):
    """
    Parameters:
        weight: [num_moe_layers, num_logical_experts]
        num_physical_experts: number of physical experts after replication
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [num_moe_layers, num_physical_experts]
        logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
        logical_count: [num_moe_layers, num_logical_experts]
    """
    num_layers, num_logical_experts = weight.shape
    assert num_logical_experts % num_groups == 0
    group_size = num_logical_experts // num_groups
    assert num_groups % num_nodes == 0
    groups_per_node = num_groups // num_nodes
    assert num_gpus % num_nodes == 0
    assert num_physical_experts % num_gpus == 0
    phy_experts_per_gpu = num_physical_experts // num_gpus

    def inverse(perm: torch.Tensor) -> torch.Tensor:
        inv = torch.empty_like(perm)
        inv.scatter_(
            1,
            perm,
            torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(
                perm.shape
            ),
        )
        return inv

    # Step 1: pack groups to nodes
    tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
    group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
    log2mlog = (
        (
            (group_pack_index * groups_per_node + group_rank_in_pack) * group_size
        ).unsqueeze(-1)
        + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)
    ).flatten(-2)
    mlog2log = inverse(log2mlog)

    # Step 2: construct redundant experts within nodes
    # [num_layers * num_nodes, num_logical_experts // num_nodes]
    tokens_per_mlog = weight.gather(-1, mlog2log).view(
        -1, num_logical_experts // num_nodes
    )
    phy2mlog, phyrank, mlogcnt = replicate_experts(
        tokens_per_mlog, num_physical_experts // num_nodes
    )

    # Step 3: pack physical_experts to GPUs
    # [num_layers * num_nodes, num_physical_experts // num_nodes]
    tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
    pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
    phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
    pphy2phy = inverse(phy2pphy)

    pphy2mlog = phy2mlog.gather(-1, pphy2phy)  # [num_layers * num_nodes, num_log_per_nodes]
    pphy2mlog = (
        pphy2mlog.view(num_layers, num_nodes, -1)
        + torch.arange(
            0,
            num_logical_experts,
            num_logical_experts // num_nodes,
            device=group_pack_index.device,
        ).view(1, -1, 1)
    ).flatten(-2)
    pphy2log = mlog2log.gather(-1, pphy2mlog)
    pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
    logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
    return pphy2log, pphyrank, logcnt
```
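A shape sketch for the hierarchical path (fabricated sizes, not from the commit): 2 layers, 8 logical experts in 4 groups, spread over 2 nodes, 4 GPUs, and 16 physical slots, so every divisibility assert above holds.

```python
import torch

from sglang.srt.managers.eplb_algorithms.deepseek import rebalance_experts_hierarchical

weight = torch.rand(2, 8)  # fabricated per-expert load statistics
pphy2log, pphyrank, logcnt = rebalance_experts_hierarchical(
    weight, num_physical_experts=16, num_groups=4, num_nodes=2, num_gpus=4
)
print(pphy2log.shape)  # torch.Size([2, 16]): logical id of each physical slot
print(pphyrank.shape)  # torch.Size([2, 16]): replica rank of each slot
print(logcnt.shape)    # torch.Size([2, 8]):  replica count per logical expert
```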
```python
def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
    enable_hierarchical: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """
    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()
    if enable_hierarchical:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus
        )
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus
        )
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
            num_layers, -1
        ),
    )
    return phy2log, log2phy, logcnt


__all__ = ["rebalance_experts"]
```
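An end-to-end sketch of the global (non-hierarchical) policy with fabricated numbers of my own: one layer, 4 logical experts, 6 physical slots on 2 GPUs of a single node. The hot expert absorbs the redundant slots, and `log2phy` is padded with -1 wherever an expert has fewer than `maxlogcnt` replicas.

```python
import torch

from sglang.srt.managers.eplb_algorithms.deepseek import rebalance_experts

weight = torch.tensor([[90.0, 30.0, 20.0, 10.0]])
phy2log, log2phy, logcnt = rebalance_experts(
    weight,
    num_replicas=6,
    num_groups=1,
    num_nodes=1,
    num_gpus=2,
    enable_hierarchical=False,
)
print(logcnt)         # tensor([[3, 1, 1, 1]]): expert 0 takes both extra slots
print(log2phy.shape)  # torch.Size([1, 4, 3]), padded with -1
```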
python/sglang/srt/managers/deepseek_eplb.py → python/sglang/srt/managers/eplb_algorithms/deepseek_vec.py (renamed)
```diff
 # This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package
-from typing import Literal, Optional, Tuple
+from typing import Optional, Tuple
 import torch
@@ -259,13 +258,9 @@ def rebalance_experts(
     num_local_physical_experts: int,
     num_groups: Optional[int],
     num_nodes: int,
-    phase: Literal["prefill", "decode", "null"],
+    enable_hierarchical: bool,
 ):
-    if (
-        (phase == "prefill")
-        and (num_groups is not None)
-        and (num_groups % num_nodes == 0)
-    ):
+    if enable_hierarchical:
         return prefill_rebalance_experts(
             tokens_per_expert=tokens_per_expert,
             num_physical_experts=num_physical_experts,
@@ -273,8 +268,9 @@ def rebalance_experts(
             num_groups=num_groups,
             num_nodes=num_nodes,
         )
-    return decode_rebalance_experts(
-        tokens_per_expert=tokens_per_expert,
-        num_physical_experts=num_physical_experts,
-        num_local_physical_experts=num_local_physical_experts,
-    )
+    else:
+        return decode_rebalance_experts(
+            tokens_per_expert=tokens_per_expert,
+            num_physical_experts=num_physical_experts,
+            num_local_physical_experts=num_local_physical_experts,
+        )
```
python/sglang/srt/managers/expert_location.py
```diff
@@ -22,7 +22,7 @@ import torch.distributed
 import torch.nn.functional as F
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.managers import deepseek_eplb
+from sglang.srt.managers import eplb_algorithms
 from sglang.srt.model_loader import get_model_architecture
 from sglang.srt.server_args import ServerArgs
@@ -134,15 +134,21 @@ class ExpertLocationMetadata:
         common = ExpertLocationMetadata._init_common(server_args, model_config)
         model_config_for_expert_location = common["model_config_for_expert_location"]
         num_physical_experts = common["num_physical_experts"]
+        num_groups = model_config_for_expert_location.num_groups
+        num_nodes = server_args.nnodes
         physical_to_logical_map, logical_to_all_physical_map, expert_count = (
-            deepseek_eplb.rebalance_experts(
+            eplb_algorithms.rebalance_experts(
                 tokens_per_expert=logical_count,
                 num_physical_experts=num_physical_experts,
                 num_local_physical_experts=num_physical_experts // common["ep_size"],
-                num_groups=model_config_for_expert_location.num_groups,
-                num_nodes=server_args.nnodes,
-                phase=server_args.disaggregation_mode,
+                num_groups=num_groups,
+                num_nodes=num_nodes,
+                algorithm=eplb_algorithms.compute_algorithm(
+                    raw_algorithm=server_args.eplb_algorithm,
+                    num_groups=num_groups,
+                    num_nodes=num_nodes,
+                ),
             )
         )
```
python/sglang/srt/server_args.py
```diff
@@ -175,6 +175,7 @@ class ServerArgs:
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
     init_expert_location: str = "trivial"
     enable_eplb: bool = False
+    eplb_algorithm: str = "auto"
     eplb_rebalance_num_iterations: int = 1000
     expert_distribution_recorder_mode: Optional[Literal["stat", "per_pass", "per_token"]
@@ -1328,6 +1329,12 @@ class ServerArgs:
             action="store_true",
             help="Enable EPLB algorithm",
         )
+        parser.add_argument(
+            "--eplb-algorithm",
+            type=str,
+            default=ServerArgs.eplb_algorithm,
+            help="Chosen EPLB algorithm",
+        )
         parser.add_argument(
             "--eplb-rebalance-num-iterations",
             type=int,
```
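To see how the new `--eplb-algorithm` flag behaves, here is a sketch of my own, grounded only in the `compute_algorithm` code above: a concrete name is a plain enum lookup, while "auto" picks the hierarchical DeepSeek variant exactly when the expert groups divide evenly across the nodes.

```python
from sglang.srt.managers.eplb_algorithms import EplbAlgorithm, compute_algorithm

# Explicit name: resolved via EplbAlgorithm["deepseek_vec"].
assert compute_algorithm("deepseek_vec", num_groups=8, num_nodes=2) \
    == EplbAlgorithm.deepseek_vec
# "auto" with groups divisible across nodes -> hierarchical variant.
assert compute_algorithm("auto", num_groups=8, num_nodes=2) \
    == EplbAlgorithm.deepseek_hierarchical
# "auto" without usable groups -> plain deepseek.
assert compute_algorithm("auto", num_groups=None, num_nodes=2) \
    == EplbAlgorithm.deepseek
```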