Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
87068b5c
Unverified
Commit
87068b5c
authored
May 28, 2025
by
fzyzcjy
Committed by
GitHub
May 27, 2025
Browse files
Support gathering expert distribution details (#6665)
parent
a564e001
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
135 additions
and
7 deletions
+135
-7
python/sglang/srt/managers/expert_distribution.py
python/sglang/srt/managers/expert_distribution.py
+133
-4
test/srt/test_expert_distribution.py
test/srt/test_expert_distribution.py
+2
-3
No files found.
python/sglang/srt/managers/expert_distribution.py
View file @
87068b5c
...
...
@@ -18,7 +18,7 @@ from abc import ABC
from
collections
import
deque
from
contextlib
import
contextmanager
from
pathlib
import
Path
from
typing
import
Dict
,
List
,
Literal
,
Optional
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Tuple
,
Type
import
einops
import
torch
...
...
@@ -293,6 +293,79 @@ class _SinglePassGatherer(ABC):
raise
NotImplementedError
class _DetailSinglePassGatherer(_SinglePassGatherer):
    """Gather detailed expert-selection data for one forward pass.

    For each pass this records the batch metadata captured in
    `on_forward_pass_start`, the `topk_ids` handed to `on_select_experts`
    for every layer, and the token-count objects handed to
    `on_deepep_dispatch_normal`. `collect` packages them into a dict and
    `reset` prepares the gatherer for the next pass.
    """

    # DeepSeek V3 has this value; should generalize later
    _TOP_K_NUM = 8

    # Value stored in buffer slots that no `on_select_experts` call has written.
    _UNFILLED_TOPK_ID = -1

    def __init__(
        self,
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ):
        """Allocate the per-layer top-k buffer on `server_args.device`.

        Raises:
            AssertionError: if `server_args.enable_two_batch_overlap` is set.
        """
        super().__init__(expert_location_metadata, rank)

        # Check the unsupported configuration before allocating the buffer.
        assert (
            not server_args.enable_two_batch_overlap
        ), "DetailSinglePassGatherer does not support TBO yet"
        # TODO assert shared experts fusion is disabled, o/w data is wrong

        self._metadata: Optional[Dict[str, Any]] = None
        # Start from the same "unfilled" value that `reset` writes, so the very
        # first pass cannot report expert id 0 for slots that were never written.
        self._topk_ids_of_layer = torch.full(
            (
                expert_location_metadata.num_layers,
                # TODO determine the max number
                server_args.chunked_prefill_size * 8,
                self._TOP_K_NUM,
            ),
            self._UNFILLED_TOPK_ID,
            dtype=torch.int32,
            device=server_args.device,
        )
        self._misc_objects: List[Dict[str, Any]] = []

    def on_forward_pass_start(self, forward_batch: ForwardBatch):
        """Capture the batch metadata for the pass that is starting.

        Raises:
            AssertionError: if metadata from a previous pass has not been
                cleared by `reset`.
        """
        assert self._metadata is None
        self._metadata = dict(
            # TODO pr-chain
            # rids=forward_batch.rids,
            input_ids=forward_batch.input_ids.cpu().tolist(),
            positions=forward_batch.positions.cpu().tolist(),
            extend_seq_lens=forward_batch.extend_seq_lens_cpu,
            forward_mode=forward_batch.forward_mode.value,
        )

    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
        """Store `topk_ids` in the leading rows of the buffer for `layer_idx`."""
        # Rows beyond topk_ids.shape[0] keep their previous (unfilled) value.
        self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], :] = topk_ids

    def on_deepep_dispatch_normal(
        self,
        layer_idx: int,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        """Record the dispatch token counts for `layer_idx` as Python lists.

        `local_physical_count_of_layer` is accepted but not recorded here.
        """
        self._misc_objects.append(
            dict(
                layer_id=layer_idx,
                num_tokens_per_rank=num_tokens_per_rank.cpu().tolist(),
                num_tokens_per_rdma_rank=num_tokens_per_rdma_rank.cpu().tolist(),
                num_tokens_per_expert=num_tokens_per_expert.cpu().tolist(),
            )
        )

    def reset(self):
        """Clear all per-pass state so the gatherer can be reused."""
        self._topk_ids_of_layer[...] = self._UNFILLED_TOPK_ID
        self._misc_objects.clear()
        self._metadata = None

    def collect(self) -> Dict:
        """Return the data gathered for the current pass.

        The result holds the metadata entries, a CPU copy of the top-k buffer
        truncated to `len(input_ids)` rows per layer, and a snapshot of the
        recorded misc objects. Requires `on_forward_pass_start` to have been
        called since the last `reset`.
        """
        num_tokens = len(self._metadata["input_ids"])
        return dict(
            **self._metadata,
            topk_ids_of_layer=self._topk_ids_of_layer[:, :num_tokens, :]
            .clone()
            .cpu(),
            # Hand out a copy: `reset` clears `self._misc_objects` in place, which
            # would otherwise empty the list inside an already-collected result.
            misc_objects=list(self._misc_objects),
        )
class
_LayerBasedSinglePassGatherer
(
_SinglePassGatherer
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
...
...
@@ -438,9 +511,8 @@ class _Accumulator(ABC):
def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
    """Return the accumulator class for the configured recorder mode.

    The mode is read from `server_args.expert_distribution_recorder_mode`;
    a mode that is not a key of the table below raises `KeyError`.
    """
    accumulator_cls_of_mode = {
        "stat": _StatAccumulator,
        "per_pass": _DetailAccumulator,
        "per_token": _DetailAccumulator,
    }
    mode = server_args.expert_distribution_recorder_mode
    return accumulator_cls_of_mode[mode]
def
__init__
(
...
...
@@ -547,6 +619,63 @@ class _DequeCollection:
return
{
d
.
maxlen
:
sum
(
d
)
/
len
(
d
)
for
d
in
self
.
_dequeues
}
class _DetailAccumulator(_UtilizationRateAccumulatorMixin):
    """Accumulator that keeps one record per appended single-pass result.

    Records are held in memory until `reset` and are written out by `dump`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._records = []

    def get_single_pass_gatherer_keys(self):
        """Return the gatherer keys; currently always defers to the parent."""
        # TODO `server_args.enable_two_batch_overlap`
        tbo_enabled = False
        if tbo_enabled:
            return [_SINGLE_PASS_GATHERER_KEY_PRIMARY, "child_a", "child_b"]
        return super().get_single_pass_gatherer_keys()

    def get_single_pass_gatherer_key(self, debug_name: Optional[str]):
        """Map `debug_name` to a gatherer key; currently always defers to the parent."""
        # TODO `server_args.enable_two_batch_overlap`
        tbo_enabled = False
        if tbo_enabled:
            return debug_name or _SINGLE_PASS_GATHERER_KEY_PRIMARY
        return super().get_single_pass_gatherer_key(debug_name)

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        """Forward the data to the parent, then store it as a new record.

        Top-level tensor values are copied to CPU before being stored; every
        other value is stored as-is.
        """
        super().append(forward_pass_id, gatherer_key, single_pass_data)

        detached_data = {}
        for name, value in single_pass_data.items():
            if isinstance(value, torch.Tensor):
                value = value.cpu().clone()
            detached_data[name] = value

        record = dict(
            forward_pass_id=forward_pass_id,
            rank=self._rank,
            gatherer_key=gatherer_key,
            **detached_data,
        )
        self._records.append(record)

    def reset(self):
        """Reset the parent state and drop every stored record."""
        super().reset()
        self._records.clear()

    def dump(self, output_mode: _OutputMode):
        """Write all stored records to a timestamped, per-rank `.pt` file.

        Only the "file" output mode is accepted.
        """
        assert output_mode == "file"
        # NOTE: This may change during recording, so here we say it is the "last" one
        last_map = self._expert_location_metadata.physical_to_logical_map
        output = dict(
            records=self._records,
            last_physical_to_logical_map=last_map,
        )
        file_name = f"expert_distribution_recorder_{time.time()}_{self._rank}.pt"
        _dump_to_file(file_name, output)
class
_StatAccumulator
(
_UtilizationRateAccumulatorMixin
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
...
...
test/srt/test_expert_distribution.py
View file @
87068b5c
...
...
@@ -23,9 +23,8 @@ class TestExpertDistribution(CustomTestCase):
dict
(
model_path
=
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
),
dict
(
model_path
=
"Qwen/Qwen1.5-MoE-A2.7B"
),
dict
(
model_path
=
"Qwen/Qwen1.5-MoE-A2.7B"
,
tp_size
=
2
),
# TODO enable in next PR
# dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"),
# dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"),
dict
(
model_path
=
"Qwen/Qwen1.5-MoE-A2.7B"
,
mode
=
"per_pass"
),
dict
(
model_path
=
"Qwen/Qwen1.5-MoE-A2.7B"
,
mode
=
"per_token"
),
]:
with
self
.
subTest
(
info
=
info
):
self
.
_execute_core
(
**
info
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment