Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
cba1cdbc
Unverified
Commit
cba1cdbc
authored
May 20, 2025
by
fzyzcjy
Committed by
GitHub
May 19, 2025
Browse files
Support DeepSeek EPLB algorithm with static distributions (#6387)
parent
c471d39e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
319 additions
and
8 deletions
+319
-8
python/sglang/srt/managers/deepseek_eplb.py
python/sglang/srt/managers/deepseek_eplb.py
+278
-0
python/sglang/srt/managers/expert_location.py
python/sglang/srt/managers/expert_location.py
+41
-8
No files found.
python/sglang/srt/managers/deepseek_eplb.py
0 → 100644
View file @
cba1cdbc
# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package
from
typing
import
Literal
,
Tuple
import
torch
def
pack_groups
(
tokens_per_group
:
torch
.
Tensor
,
num_nodes
:
int
)
->
torch
.
Tensor
:
num_layers
,
num_groups
=
tokens_per_group
.
shape
assert
num_groups
%
num_nodes
==
0
groups_per_rank
=
num_groups
//
num_nodes
indices
=
tokens_per_group
.
float
().
sort
(
-
1
,
descending
=
True
).
indices
.
cpu
()
ret
=
torch
.
full_like
(
tokens_per_group
,
fill_value
=-
1
,
dtype
=
torch
.
int64
,
device
=
"cpu"
)
for
layer
in
range
(
num_layers
):
node_tokens
=
[
0
]
*
num_nodes
node_groups
=
[
0
]
*
num_nodes
for
group
in
indices
[
layer
]:
def
key_func
(
rank
:
int
)
->
int
:
if
node_groups
[
rank
]
>=
groups_per_rank
:
return
1
,
0
else
:
return
0
,
node_tokens
[
rank
]
rank
=
min
(
range
(
num_nodes
),
key
=
key_func
)
assert
node_groups
[
rank
]
<
groups_per_rank
ret
[
layer
,
group
]
=
rank
*
groups_per_rank
+
node_groups
[
rank
]
node_tokens
[
rank
]
+=
tokens_per_group
[
layer
,
group
]
node_groups
[
rank
]
+=
1
return
ret
def
make_redundant_experts_chunkwise
(
tokens_per_expert
:
torch
.
Tensor
,
num_physical_experts
:
int
,
num_local_physical_experts
:
int
,
num_physical_experts_per_chunk
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
num_steps
,
num_moe_layers
,
num_logical_experts
=
tokens_per_expert
.
shape
num_redundancy_experts
=
num_physical_experts
-
num_logical_experts
physical_to_logical_map
=
torch
.
empty
(
num_moe_layers
,
num_physical_experts
,
dtype
=
torch
.
int
,
device
=
tokens_per_expert
.
device
,
)
logical_to_physical_map
=
torch
.
full
(
(
num_moe_layers
,
num_logical_experts
,
num_redundancy_experts
+
1
),
-
1
,
dtype
=
torch
.
int
,
device
=
tokens_per_expert
.
device
,
)
logical_count
=
torch
.
ones
(
num_moe_layers
,
num_logical_experts
,
dtype
=
torch
.
int
,
device
=
tokens_per_expert
.
device
,
)
assert
num_physical_experts
%
num_physical_experts_per_chunk
==
0
num_chunks
=
num_physical_experts
//
num_physical_experts_per_chunk
assert
num_logical_experts
%
num_chunks
==
0
num_logical_experts_per_group
=
num_logical_experts
//
num_chunks
assert
num_redundancy_experts
%
num_chunks
==
0
num_redundancy_experts_per_group
=
num_redundancy_experts
//
num_chunks
arange_num_moe_layers_num_groups
=
torch
.
arange
(
num_moe_layers
*
num_chunks
,
dtype
=
torch
.
int
,
device
=
tokens_per_expert
.
device
)
arange_num_logical_experts
=
torch
.
arange
(
num_logical_experts
,
dtype
=
torch
.
int
,
device
=
tokens_per_expert
.
device
)
arange_num_logical_experts_per_group
=
torch
.
arange
(
num_logical_experts_per_group
,
dtype
=
torch
.
int
,
device
=
tokens_per_expert
.
device
)
arange_num_groups
=
torch
.
arange
(
num_chunks
,
dtype
=
torch
.
int
,
device
=
tokens_per_expert
.
device
)
physical_to_logical_map
.
view
(
num_moe_layers
,
num_chunks
,
num_physical_experts_per_chunk
)[:,
:,
:
num_logical_experts_per_group
]
=
arange_num_logical_experts
.
view
(
num_chunks
,
num_logical_experts_per_group
)
logical_to_physical_map
[:,
:,
0
]
=
(
arange_num_logical_experts_per_group
.
expand
(
num_chunks
,
num_logical_experts_per_group
)
+
arange_num_groups
[:,
None
]
*
num_physical_experts_per_chunk
).
view
(
num_logical_experts
)
tokens_per_expert_all_diff
=
tokens_per_expert
+
arange_num_logical_experts
*
1e-4
for
i
in
range
(
num_redundancy_experts_per_group
):
score
=
(
tokens_per_expert_all_diff
/
logical_count
)
# NOTE: Values in score must be different from each other
score1
=
tokens_per_expert
/
(
logical_count
+
1
)
score
=
score
.
view
(
num_steps
,
num_moe_layers
,
num_chunks
,
num_logical_experts_per_group
)
score1
=
score1
.
view_as
(
score
)
values
,
indices
=
score
.
max
(
-
1
,
keepdim
=
True
)
values
=
values
.
expand_as
(
score
).
contiguous
()
score
.
scatter_
(
-
1
,
indices
,
score1
.
gather
(
-
1
,
indices
))
values
.
scatter_
(
-
1
,
indices
,
score
.
max
(
-
1
,
keepdim
=
True
).
values
)
redundancy_indices
=
values
.
sum
(
0
).
argmin
(
-
1
)
physical_to_logical_map
.
view
(
num_moe_layers
,
num_chunks
,
num_physical_experts_per_chunk
)[:,
:,
num_logical_experts_per_group
+
i
]
=
(
redundancy_indices
+
arange_num_groups
*
num_logical_experts_per_group
)
redundancy_count
=
(
logical_count
.
view
(
num_moe_layers
*
num_chunks
,
num_logical_experts_per_group
)
.
gather
(
-
1
,
redundancy_indices
.
view
(
num_moe_layers
*
num_chunks
,
1
))
.
squeeze
(
1
)
)
physical_redundancy_indices
=
(
(
arange_num_groups
*
num_physical_experts_per_chunk
+
num_logical_experts_per_group
+
i
)
.
expand
(
num_moe_layers
,
num_chunks
)
.
flatten
()
)
logical_to_physical_map
.
view
(
num_moe_layers
*
num_chunks
,
num_logical_experts_per_group
,
num_redundancy_experts
+
1
,
)[
arange_num_moe_layers_num_groups
,
redundancy_indices
.
view
(
num_moe_layers
*
num_chunks
),
redundancy_count
,
]
=
physical_redundancy_indices
logical_count
.
view
(
num_moe_layers
*
num_chunks
,
num_logical_experts_per_group
)[
arange_num_moe_layers_num_groups
,
redundancy_indices
.
view
(
num_moe_layers
*
num_chunks
),
]
+=
1
if
num_local_physical_experts
>
1
:
# Load-balancing between GPUs
physical_to_logical_map_int64
=
physical_to_logical_map
.
to
(
torch
.
int64
)
counts
=
logical_count
.
gather
(
-
1
,
physical_to_logical_map_int64
)
score
=
tokens_per_expert
.
sum
(
0
).
gather
(
-
1
,
physical_to_logical_map_int64
)
score
=
score
/
counts
score
=
score
.
view
(
num_moe_layers
,
num_chunks
,
num_physical_experts_per_chunk
)
indices
=
score
.
argsort
(
-
1
,
descending
=
True
)
indices
+=
torch
.
arange
(
0
,
num_physical_experts
,
num_physical_experts_per_chunk
,
dtype
=
indices
.
dtype
,
device
=
indices
.
device
,
)[
None
,
:,
None
]
assert
num_physical_experts_per_chunk
%
num_local_physical_experts
==
0
num_local_groups
=
num_physical_experts_per_chunk
//
num_local_physical_experts
indices
=
indices
.
view
(
num_moe_layers
,
num_chunks
,
num_local_physical_experts
,
num_local_groups
)
indices
[:,
:,
1
::
2
,
:]
=
indices
[:,
:,
1
::
2
,
:].
flip
(
-
1
)
indices
=
indices
.
transpose
(
2
,
3
)
indices
=
indices
.
reshape
(
num_moe_layers
,
num_physical_experts
)
physical_to_logical_map
=
physical_to_logical_map
.
gather
(
-
1
,
indices
)
mask
=
logical_to_physical_map
==
-
1
logical_to_physical_map
[
mask
]
=
0
logical_to_physical_map
=
(
indices
.
argsort
(
-
1
)
.
gather
(
-
1
,
logical_to_physical_map
.
view
(
num_moe_layers
,
-
1
).
to
(
torch
.
int64
)
)
.
view_as
(
logical_to_physical_map
)
.
to
(
torch
.
int
)
)
logical_to_physical_map
[
mask
]
=
-
1
return
physical_to_logical_map
,
logical_to_physical_map
,
logical_count
def
decode_rebalance_experts
(
tokens_per_expert
:
torch
.
Tensor
,
num_physical_experts
:
int
,
num_local_physical_experts
:
int
,
):
return
make_redundant_experts_chunkwise
(
tokens_per_expert
,
num_physical_experts
,
num_local_physical_experts
,
num_physical_experts
,
)
def
prefill_rebalance_experts
(
tokens_per_expert
:
torch
.
Tensor
,
num_physical_experts
:
int
,
num_local_physical_experts
:
int
,
num_groups
:
int
,
num_nodes
:
int
,
):
tokens_per_expert
=
tokens_per_expert
.
float
().
cpu
()
num_steps
,
_
,
num_logical_experts
=
tokens_per_expert
.
shape
assert
num_logical_experts
%
num_groups
==
0
group_size
=
num_logical_experts
//
num_groups
assert
num_groups
%
num_nodes
==
0
,
f
"
{
num_groups
=
}
{
num_nodes
=
}
"
tokens_per_group
=
tokens_per_expert
.
sum
(
0
).
unflatten
(
-
1
,
(
num_groups
,
-
1
)).
sum
(
-
1
)
group_perm
=
pack_groups
(
tokens_per_group
,
num_nodes
)
# [num_moe_layers, num_groups] => [num_moe_layers, num_nodes]
# log2mlog [layers, #logexp] -> [layers, #logexp]
log2mlog
=
(
(
group_perm
*
group_size
).
unsqueeze
(
-
1
)
+
torch
.
arange
(
group_size
,
dtype
=
torch
.
int64
,
device
=
group_perm
.
device
)
).
flatten
(
-
2
)
# mlog2log [layers, #logexp] -> [layers, #logexp], inverse of log2mlog
mlog2log
=
torch
.
empty_like
(
log2mlog
)
arange
=
torch
.
arange
(
num_logical_experts
,
dtype
=
torch
.
int64
,
device
=
mlog2log
.
device
)
mlog2log
.
scatter_
(
1
,
log2mlog
,
arange
.
expand
(
log2mlog
.
size
(
0
),
-
1
))
# tokens_per_mlog[i][j][k] = tokens_per_expert[i][j][mlog2log[j][k]]
tokens_per_mlog
=
tokens_per_expert
.
gather
(
2
,
mlog2log
.
unsqueeze
(
0
).
expand
(
num_steps
,
-
1
,
-
1
)
)
phy2mlog
,
mlog2phy
,
mlog_count
=
make_redundant_experts_chunkwise
(
tokens_per_mlog
,
num_physical_experts
,
num_local_physical_experts
,
num_physical_experts
//
num_nodes
,
)
# phy2log[i][j] = mlog2log[i][phy2mlog[i][j]]
phy2log
=
mlog2log
.
gather
(
1
,
phy2mlog
.
to
(
torch
.
int64
))
# mlog2phy: [num_moe_layers, num_logical_experts, ...]
# log2phy[i][j][k] = mlog2phy[i][log2mlog[i][j]][k]
log2phy
=
mlog2phy
.
gather
(
1
,
log2mlog
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
mlog2phy
.
size
(
-
1
)).
to
(
torch
.
int64
)
)
# log_count[i][j] = mlog_count[i][log2mlog[i][j]]
log_count
=
mlog_count
.
gather
(
1
,
log2mlog
)
return
phy2log
,
log2phy
,
log_count
def
rebalance_experts
(
tokens_per_expert
:
torch
.
Tensor
,
num_physical_experts
:
int
,
num_local_physical_experts
:
int
,
num_groups
:
int
,
num_nodes
:
int
,
phase
:
Literal
[
"prefill"
,
"decode"
],
):
if
phase
==
"prefill"
:
return
prefill_rebalance_experts
(
tokens_per_expert
=
tokens_per_expert
,
num_physical_experts
=
num_physical_experts
,
num_local_physical_experts
=
num_local_physical_experts
,
num_groups
=
num_groups
,
num_nodes
=
num_nodes
,
)
if
phase
==
"decode"
:
return
decode_rebalance_experts
(
tokens_per_expert
=
tokens_per_expert
,
num_physical_experts
=
num_physical_experts
,
num_local_physical_experts
=
num_local_physical_experts
,
)
raise
NotImplementedError
python/sglang/srt/managers/expert_location.py
View file @
cba1cdbc
...
@@ -117,6 +117,41 @@ class ExpertLocationMetadata:
...
@@ -117,6 +117,41 @@ class ExpertLocationMetadata:
logical_to_all_physical_map
=
logical_to_all_physical_map
,
logical_to_all_physical_map
=
logical_to_all_physical_map
,
)
)
@
staticmethod
def
init_by_eplb
(
server_args
:
ServerArgs
,
model_config
:
ModelConfig
,
logical_count
:
torch
.
Tensor
):
if
not
isinstance
(
logical_count
,
torch
.
Tensor
):
logical_count
=
torch
.
tensor
(
logical_count
)
if
len
(
logical_count
.
shape
)
==
2
:
logical_count
=
logical_count
.
unsqueeze
(
0
)
logical_count
=
logical_count
.
to
(
server_args
.
device
)
common
=
ExpertLocationMetadata
.
_init_common
(
server_args
,
model_config
)
model_config_for_expert_location
=
common
[
"model_config_for_expert_location"
]
num_physical_experts
=
common
[
"num_physical_experts"
]
phase
=
server_args
.
disaggregation_mode
if
phase
==
"null"
:
phase
=
"decode"
physical_to_logical_map
,
logical_to_all_physical_map
,
expert_count
=
(
deepseek_eplb
.
rebalance_experts
(
tokens_per_expert
=
logical_count
,
num_physical_experts
=
num_physical_experts
,
num_local_physical_experts
=
num_physical_experts
//
common
[
"ep_size"
],
num_groups
=
model_config_for_expert_location
.
num_groups
,
num_nodes
=
server_args
.
nnodes
,
phase
=
phase
,
)
)
return
ExpertLocationMetadata
.
_init_raw
(
ep_size
=
common
[
"ep_size"
],
physical_to_logical_map
=
physical_to_logical_map
,
logical_to_all_physical_map
=
logical_to_all_physical_map
,
)
@
staticmethod
@
staticmethod
def
_init_common
(
server_args
:
ServerArgs
,
model_config
:
ModelConfig
):
def
_init_common
(
server_args
:
ServerArgs
,
model_config
:
ModelConfig
):
model_config_for_expert_location
=
(
model_config_for_expert_location
=
(
...
@@ -272,14 +307,12 @@ def compute_initial_expert_location_metadata(
...
@@ -272,14 +307,12 @@ def compute_initial_expert_location_metadata(
server_args
,
model_config
,
**
data_dict
server_args
,
model_config
,
**
data_dict
)
)
elif
"logical_count"
in
data_dict
:
elif
"logical_count"
in
data_dict
:
# TODO pr-chain: enable this later
logger
.
info
(
raise
NotImplementedError
"init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
# logger.info(
)
# "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
return
ExpertLocationMetadata
.
init_by_eplb
(
# )
server_args
,
model_config
,
logical_count
=
data_dict
[
"logical_count"
]
# return ExpertLocationMetadata.init_by_eplb(
)
# server_args, model_config, logical_count=data_dict["logical_count"]
# )
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
f
"Unknown init_expert_location format (
{
list
(
data_dict
.
keys
())
=
}
)"
f
"Unknown init_expert_location format (
{
list
(
data_dict
.
keys
())
=
}
)"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment