Commit 3ab7d9b5 (unverified), authored May 29, 2025 by fzyzcjy, committed by GitHub on May 29, 2025

Support picking variants of EPLB algorithms (#6728)

Parent: 7e5071c9
Showing 5 changed files with 313 additions and 18 deletions (+313 −18):

python/sglang/srt/managers/eplb_algorithms/__init__.py      +63  −0
python/sglang/srt/managers/eplb_algorithms/deepseek.py      +223 −0
python/sglang/srt/managers/eplb_algorithms/deepseek_vec.py  +9   −13
python/sglang/srt/managers/expert_location.py               +11  −5
python/sglang/srt/server_args.py                            +7   −0
python/sglang/srt/managers/eplb_algorithms/__init__.py (new file, mode 100644)
```python
from enum import Enum, auto
from typing import Optional

import torch

from sglang.srt.managers.eplb_algorithms import deepseek, deepseek_vec


class EplbAlgorithm(Enum):
    deepseek = auto()
    deepseek_hierarchical = auto()
    deepseek_vec = auto()
    deepseek_vec_hierarchical = auto()
    # TODO: may add more algorithms later


def rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_groups: Optional[int],
    num_nodes: int,
    algorithm: EplbAlgorithm,
):
    if algorithm in [EplbAlgorithm.deepseek, EplbAlgorithm.deepseek_hierarchical]:
        return deepseek.rebalance_experts(
            weight=tokens_per_expert.sum(dim=0),
            num_replicas=num_physical_experts,
            num_groups=num_groups,
            num_nodes=num_nodes,
            num_gpus=num_physical_experts // num_local_physical_experts,
            enable_hierarchical=algorithm == EplbAlgorithm.deepseek_hierarchical,
        )

    if algorithm in [
        EplbAlgorithm.deepseek_vec,
        EplbAlgorithm.deepseek_vec_hierarchical,
    ]:
        return deepseek_vec.rebalance_experts(
            tokens_per_expert=tokens_per_expert,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
            num_groups=num_groups,
            num_nodes=num_nodes,
            enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical,
        )

    raise NotImplementedError


def compute_algorithm(
    raw_algorithm: str,
    num_groups: Optional[int],
    num_nodes: int,
) -> EplbAlgorithm:
    if raw_algorithm != "auto":
        return EplbAlgorithm[raw_algorithm]

    # TODO: test on real scenarios to learn which algorithms perform better
    if (num_groups is not None) and (num_groups % num_nodes == 0):
        return EplbAlgorithm.deepseek_hierarchical
    else:
        return EplbAlgorithm.deepseek
```
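Taken together, `compute_algorithm` resolves the user-facing algorithm string and `rebalance_experts` dispatches to the chosen backend. Below is an editor's usage sketch, not part of the commit; the sizes are toy values, and it assumes the leading dimension of `tokens_per_expert` indexes recorded passes (the `deepseek` variants reduce over it via `sum(dim=0)`):

```python
import torch

from sglang.srt.managers.eplb_algorithms import compute_algorithm, rebalance_experts

# 8 expert groups divide evenly across 2 nodes, so "auto" resolves to
# EplbAlgorithm.deepseek_hierarchical.
algorithm = compute_algorithm(raw_algorithm="auto", num_groups=8, num_nodes=2)

# Assumed layout: [num_recorded_passes, num_layers, num_logical_experts].
tokens_per_expert = torch.randint(0, 1000, (10, 2, 16))

phy2log, log2phy, logcnt = rebalance_experts(
    tokens_per_expert=tokens_per_expert,
    num_physical_experts=32,
    num_local_physical_experts=8,  # implies 32 // 8 = 4 GPUs
    num_groups=8,
    num_nodes=2,
    algorithm=algorithm,
)
```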
python/sglang/srt/managers/eplb_algorithms/deepseek.py (new file, mode 100644)
```python
# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package

from typing import Tuple

import torch

from sglang.srt.utils import get_bool_env_var


def balanced_packing(
    weight: torch.Tensor, num_packs: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pack n weighted objects to m packs, such that each pack contains exactly n/m objects and the weights of all packs
    are as balanced as possible.

    Parameters:
        weight: [X, n], the weight of each item
        num_packs: number of packs

    Returns:
        pack_index: [X, n], the pack index of each item
        rank_in_pack: [X, n], the rank of the item in the pack
    """
    num_layers, num_groups = weight.shape
    assert num_groups % num_packs == 0
    groups_per_pack = num_groups // num_packs

    if groups_per_pack == 1:
        pack_index = torch.arange(
            weight.size(-1), dtype=torch.int64, device=weight.device
        ).expand(weight.shape)
        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
        return pack_index, rank_in_pack

    indices = weight.float().sort(-1, descending=True).indices.cpu()
    pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu")
    rank_in_pack = torch.full_like(pack_index, fill_value=-1)
    for i in range(num_layers):
        pack_weights = [0] * num_packs
        pack_items = [0] * num_packs
        for group in indices[i]:
            pack = min(
                (i for i in range(num_packs) if pack_items[i] < groups_per_pack),
                key=pack_weights.__getitem__,
            )
            assert pack_items[pack] < groups_per_pack
            pack_index[i, group] = pack
            rank_in_pack[i, group] = pack_items[pack]
            pack_weights[pack] += weight[i, group]
            pack_items[pack] += 1
    return pack_index, rank_in_pack
```
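To see the greedy strategy in the loop above at work (an illustrative trace by the editor, not part of the commit): items are visited in decreasing weight, and each goes to the lightest pack that still has room.

```python
import torch

# 4 weighted groups, 1 layer, packed into 2 packs of exactly 2 groups each.
weight = torch.tensor([[9, 1, 4, 6]])
pack_index, rank_in_pack = balanced_packing(weight, num_packs=2)
# Visit order by weight: 9 -> pack 0, 6 -> pack 1, 4 -> pack 1 (lighter), 1 -> pack 0.
# pack_index   == tensor([[0, 0, 1, 1]])   (both packs end up with weight 10)
# rank_in_pack == tensor([[0, 1, 1, 0]])
```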
The file continues with `replicate_experts`:

```python
def replicate_experts(
    weight: torch.Tensor, num_phy: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.

    Parameters:
        weight: [X, num_log]
        num_phy: total number of experts after replication

    Returns:
        phy2log: [X, num_phy], logical expert id of each physical expert
        rank: [X, num_phy], the replica rank
        logcnt: [X, num_log], number of replicas for each logical expert
    """
    n, num_log = weight.shape
    num_redundant = num_phy - num_log
    assert num_redundant >= 0
    device = weight.device
    phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
    rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
    logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
    arangen = torch.arange(n, dtype=torch.int64, device=device)
    for i in range(num_log, num_phy):
        redundant_indices = (weight / logcnt).max(dim=-1).indices
        phy2log[:, i] = redundant_indices
        rank[:, i] = logcnt[arangen, redundant_indices]
        logcnt[arangen, redundant_indices] += 1
    return phy2log, rank, logcnt
```
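The replication loop above always duplicates the logical expert whose per-replica load `weight / logcnt` is currently highest. A small editor's trace (not part of the commit):

```python
import torch

# 3 logical experts replicated onto 5 physical slots.
weight = torch.tensor([[12.0, 3.0, 5.0]])
phy2log, rank, logcnt = replicate_experts(weight, num_phy=5)
# Slot 3: per-replica loads [12, 3, 5] -> duplicate expert 0 (its load drops to 6).
# Slot 4: per-replica loads [6, 3, 5]  -> duplicate expert 0 again (down to 4).
# phy2log == tensor([[0, 1, 2, 0, 0]]), logcnt == tensor([[3, 1, 1]])
```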
Next, the hierarchical balancer, which composes the two helpers above:

```python
def rebalance_experts_hierarchical(
    weight: torch.Tensor,
    num_physical_experts: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
):
    """
    Parameters:
        weight: [num_moe_layers, num_logical_experts]
        num_physical_experts: number of physical experts after replication
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [num_moe_layers, num_physical_experts]
        logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
        logical_count: [num_moe_layers, num_logical_experts]
    """
    num_layers, num_logical_experts = weight.shape
    assert num_logical_experts % num_groups == 0
    group_size = num_logical_experts // num_groups
    assert num_groups % num_nodes == 0
    groups_per_node = num_groups // num_nodes
    assert num_gpus % num_nodes == 0
    assert num_physical_experts % num_gpus == 0
    phy_experts_per_gpu = num_physical_experts // num_gpus

    def inverse(perm: torch.Tensor) -> torch.Tensor:
        inv = torch.empty_like(perm)
        inv.scatter_(
            1,
            perm,
            torch.arange(
                perm.size(1), dtype=torch.int64, device=perm.device
            ).expand(perm.shape),
        )
        return inv

    # Step 1: pack groups to nodes
    tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
    group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
    log2mlog = (
        ((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1)
        + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)
    ).flatten(-2)
    mlog2log = inverse(log2mlog)

    # Step 2: construct redundant experts within nodes
    # [num_layers * num_nodes, num_logical_experts // num_nodes]
    tokens_per_mlog = weight.gather(-1, mlog2log).view(
        -1, num_logical_experts // num_nodes
    )
    phy2mlog, phyrank, mlogcnt = replicate_experts(
        tokens_per_mlog, num_physical_experts // num_nodes
    )

    # Step 3: pack physical_experts to GPUs
    # [num_layers * num_nodes, num_physical_experts // num_nodes]
    tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
    pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
    phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
    pphy2phy = inverse(phy2pphy)

    pphy2mlog = phy2mlog.gather(-1, pphy2phy)  # [num_layers * num_nodes, num_log_per_nodes]
    pphy2mlog = (
        pphy2mlog.view(num_layers, num_nodes, -1)
        + torch.arange(
            0,
            num_logical_experts,
            num_logical_experts // num_nodes,
            device=group_pack_index.device,
        ).view(1, -1, 1)
    ).flatten(-2)
    pphy2log = mlog2log.gather(-1, pphy2mlog)
    pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
    logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
    return pphy2log, pphyrank, logcnt
```
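The three steps above compose packing and replication at two levels: groups onto nodes, replicas within nodes, then physical experts onto GPUs. A shape-level editor's sketch with toy sizes (not part of the commit):

```python
import torch

# 1 MoE layer, 8 logical experts in 4 groups, 12 physical experts,
# 4 GPUs spread over 2 nodes (2 groups and 6 physical experts per node).
weight = torch.randint(1, 100, (1, 8)).float()
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
    weight, num_physical_experts=12, num_groups=4, num_nodes=2, num_gpus=4
)
assert phy2log.shape == (1, 12)    # logical id behind each physical slot
assert logcnt.shape == (1, 8)      # replica count per logical expert
assert logcnt.sum().item() == 12   # every physical slot backs exactly one expert
```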
And the entry point:

```python
def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
    enable_hierarchical: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """
    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()
    if enable_hierarchical:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus
        )
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus
        )
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(
            num_replicas, dtype=torch.int64, device=log2phy.device
        ).expand(num_layers, -1),
    )
    return phy2log, log2phy, logcnt


__all__ = ["rebalance_experts"]
```
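For completeness, an editor's end-to-end sketch of the entry point (not part of the commit; toy sizes). With `enable_hierarchical=False`, the function reuses the hierarchical routine with a single group and node, which reduces it to a global policy:

```python
import torch

# 2 layers, 8 logical experts replicated onto 16 physical slots over 4 GPUs.
weight = torch.randint(1, 100, (2, 8))
phy2log, log2phy, logcnt = rebalance_experts(
    weight,
    num_replicas=16,
    num_groups=8,      # unused by the global policy
    num_nodes=1,
    num_gpus=4,
    enable_hierarchical=False,
)
# phy2log: [2, 16]; log2phy: [2, 8, logcnt.max()]; logcnt: [2, 8]
```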
python/sglang/srt/managers/deepseek_eplb.py → python/sglang/srt/managers/eplb_algorithms/deepseek_vec.py (renamed)

The moved file no longer decides the policy itself: the old condition (prefill phase, with groups dividing evenly across nodes) is replaced by an explicit enable_hierarchical parameter, and the divisibility heuristic now lives in compute_algorithm above.
```diff
 # This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package
 
-from typing import Literal, Optional, Tuple
+from typing import Optional, Tuple
 
 import torch
...
@@ -259,13 +258,9 @@ def rebalance_experts(
     num_local_physical_experts: int,
     num_groups: Optional[int],
     num_nodes: int,
-    phase: Literal["prefill", "decode", "null"],
+    enable_hierarchical: bool,
 ):
-    if (
-        (phase == "prefill")
-        and (num_groups is not None)
-        and (num_groups % num_nodes == 0)
-    ):
+    if enable_hierarchical:
         return prefill_rebalance_experts(
             tokens_per_expert=tokens_per_expert,
             num_physical_experts=num_physical_experts,
@@ -273,6 +268,7 @@ def rebalance_experts(
             num_groups=num_groups,
             num_nodes=num_nodes,
         )
-    return decode_rebalance_experts(
-        tokens_per_expert=tokens_per_expert,
-        num_physical_experts=num_physical_experts,
+    else:
+        return decode_rebalance_experts(
+            tokens_per_expert=tokens_per_expert,
+            num_physical_experts=num_physical_experts,
...
```
python/sglang/srt/managers/expert_location.py
```diff
@@ -22,7 +22,7 @@ import torch.distributed
 import torch.nn.functional as F
 
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.managers import deepseek_eplb
+from sglang.srt.managers import eplb_algorithms
 from sglang.srt.model_loader import get_model_architecture
 from sglang.srt.server_args import ServerArgs
...
@@ -134,15 +134,21 @@ class ExpertLocationMetadata:
         common = ExpertLocationMetadata._init_common(server_args, model_config)
         model_config_for_expert_location = common["model_config_for_expert_location"]
         num_physical_experts = common["num_physical_experts"]
+        num_groups = model_config_for_expert_location.num_groups
+        num_nodes = server_args.nnodes
 
         physical_to_logical_map, logical_to_all_physical_map, expert_count = (
-            deepseek_eplb.rebalance_experts(
+            eplb_algorithms.rebalance_experts(
                 tokens_per_expert=logical_count,
                 num_physical_experts=num_physical_experts,
                 num_local_physical_experts=num_physical_experts // common["ep_size"],
-                num_groups=model_config_for_expert_location.num_groups,
-                num_nodes=server_args.nnodes,
-                phase=server_args.disaggregation_mode,
+                num_groups=num_groups,
+                num_nodes=num_nodes,
+                algorithm=eplb_algorithms.compute_algorithm(
+                    raw_algorithm=server_args.eplb_algorithm,
+                    num_groups=num_groups,
+                    num_nodes=num_nodes,
+                ),
             )
         )
...
```
python/sglang/srt/server_args.py
```diff
@@ -175,6 +175,7 @@ class ServerArgs:
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
     init_expert_location: str = "trivial"
     enable_eplb: bool = False
+    eplb_algorithm: str = "auto"
     eplb_rebalance_num_iterations: int = 1000
     expert_distribution_recorder_mode: Optional[
         Literal["stat", "per_pass", "per_token"]
...
@@ -1328,6 +1329,12 @@ class ServerArgs:
             action="store_true",
             help="Enable EPLB algorithm",
         )
+        parser.add_argument(
+            "--eplb-algorithm",
+            type=str,
+            default=ServerArgs.eplb_algorithm,
+            help="Chosen EPLB algorithm",
+        )
         parser.add_argument(
             "--eplb-rebalance-num-iterations",
             type=int,
...
```
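As a reference for the new flag (an editor's sketch, not part of the commit): `--eplb-algorithm` accepts `"auto"` or any `EplbAlgorithm` member name, since `compute_algorithm` looks non-`"auto"` values up directly on the enum.

```python
from sglang.srt.managers.eplb_algorithms import EplbAlgorithm, compute_algorithm

# An explicit name is looked up on the enum by name.
assert (
    compute_algorithm(raw_algorithm="deepseek_vec", num_groups=None, num_nodes=1)
    is EplbAlgorithm.deepseek_vec
)

# "auto" picks the hierarchical variant only when groups divide evenly
# across nodes; otherwise it falls back to plain deepseek.
assert (
    compute_algorithm(raw_algorithm="auto", num_groups=None, num_nodes=2)
    is EplbAlgorithm.deepseek
)
```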