Commit e23e280e (unverified)
Authored by Shangming Cai on Sep 28, 2025; committed via GitHub on Sep 28, 2025
Parent: 51f7c6bd

Add support for topk metadata transferring for PD (#10616)

Signed-off-by: Shangming Cai <csmthu@gmail.com>
Showing 6 changed files with 60 additions and 20 deletions (+60 -20)
python/sglang/srt/disaggregation/decode.py                       +4  -0
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py  +23 -15
python/sglang/srt/disaggregation/prefill.py                      +2  -0
python/sglang/srt/disaggregation/utils.py                        +27 -3
python/sglang/srt/managers/schedule_batch.py                     +2  -0
python/sglang/srt/managers/scheduler.py                          +2  -2
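Taken together, the six diffs below implement one pipeline: the prefill worker captures each request's top-k draft probabilities and indices, `set_buf` stages them into fixed-size metadata buffers, the PD transfer engine ships the raw bytes, and the decode worker's `get_buf` / schedule-batch path rebuilds batched tensors to seed EAGLE/MTP verification. Here is a minimal, runnable sketch of that flow; `stage`, `fetch`, `SIZE`, and `MAX_TOPK` are illustrative names, not sglang APIs.

```python
import torch

SIZE, MAX_TOPK = 4, 16  # 16 mirrors the cap the real buffers are sized for

# Prefill side: fixed-shape staging buffers, one row per request slot.
topk_p_buf = torch.zeros((SIZE, MAX_TOPK), dtype=torch.float32)
topk_index_buf = torch.zeros((SIZE, MAX_TOPK), dtype=torch.int64)

def stage(slot: int, topk_p: torch.Tensor, topk_index: torch.Tensor) -> None:
    # Write the request's k values into the first k columns of its row.
    k = topk_p.size(0)  # k = speculative_eagle_topk, must be <= MAX_TOPK
    topk_p_buf[slot, :k].copy_(topk_p)
    topk_index_buf[slot, :k].copy_(topk_index)

def fetch(slot: int):
    # Decode side: read the slot back (the transfer engine moves raw bytes).
    return topk_p_buf[slot], topk_index_buf[slot]

stage(0, torch.tensor([0.5, 0.3]), torch.tensor([11, 42]))
p, idx = fetch(0)
print(p[:2], idx[:2])  # tensor([0.5000, 0.3000]) tensor([11, 42])
```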
python/sglang/srt/disaggregation/decode.py

@@ -614,12 +614,16 @@ class DecodeTransferQueue:
                 output_token_logprobs_idx,
                 output_top_logprobs_val,
                 output_top_logprobs_idx,
+                output_topk_p,
+                output_topk_index,
                 output_hidden_states,
             ) = self.metadata_buffers.get_buf(idx)
             decode_req.req.output_ids.append(output_id[0].item())
             decode_req.req.cached_tokens = cached_tokens[0].item()
             if not self.spec_algorithm.is_none():
+                decode_req.req.output_topk_p = output_topk_p
+                decode_req.req.output_topk_index = output_topk_index
                 decode_req.req.hidden_states_tensor = output_hidden_states
             if decode_req.req.return_logprob:
                 decode_req.req.output_token_logprobs_val.append(
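A hedged sketch of the consumer pattern in this hunk, using stand-in objects rather than sglang's real `Req` and `get_buf`: the transferred top-k metadata and hidden states are attached to the request only when a speculative algorithm is active, so non-speculative runs keep ignoring those buffer fields.

```python
import torch
from types import SimpleNamespace

# Stand-ins for sglang's Req and for what get_buf(idx) returns per slot.
req = SimpleNamespace(
    output_topk_p=None, output_topk_index=None, hidden_states_tensor=None
)
output_topk_p = torch.rand(16)
output_topk_index = torch.randint(0, 32000, (16,))
output_hidden_states = torch.zeros(4096)

spec_enabled = True  # stands in for `not self.spec_algorithm.is_none()`
if spec_enabled:
    # Same attach pattern as the hunk above: only speculative runs keep
    # the transferred top-k metadata and hidden states on the request.
    req.output_topk_p = output_topk_p
    req.output_topk_index = output_topk_index
    req.hidden_states_tensor = output_hidden_states
```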
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

@@ -125,25 +125,33 @@ class ScheduleBatchDisaggregationDecodeMixin:
                 req.grammar.finished = req.finished()
         self.output_ids = torch.tensor(self.output_ids, device=self.device)

-        # Simulate the eagle run. We add mock data to hidden states for the
-        # ease of implementation now meaning the first token will have acc rate
-        # of 0.
-        if not self.spec_algorithm.is_none():
+        # Simulate the eagle run.
+        if self.spec_algorithm.is_eagle():
             b = len(self.reqs)
-            topk_p = torch.arange(
-                b * server_args.speculative_eagle_topk,
-                0,
-                -1,
-                device=self.device,
-                dtype=torch.float32,
+            topk = server_args.speculative_eagle_topk
+            topk_p = torch.stack(
+                [
+                    torch.as_tensor(
+                        req.output_topk_p[:topk],
+                        device=self.device,
+                        dtype=torch.float32,
+                    )
+                    for req in self.reqs
+                ],
+                dim=0,
             )
-            topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
-            topk_p /= b * server_args.speculative_eagle_topk
-            topk_index = torch.arange(
-                b * server_args.speculative_eagle_topk, device=self.device
+            topk_index = torch.stack(
+                [
+                    torch.as_tensor(
+                        req.output_topk_index[:topk],
+                        device=self.device,
+                        dtype=torch.int64,
+                    )
+                    for req in self.reqs
+                ],
+                dim=0,
            )
-            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
             hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
             hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
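The key change here is that the decode side no longer fabricates mock `torch.arange` data (which gave the first token an acceptance rate of 0); it stacks the real per-request top-k values transferred from prefill into `(batch, topk)` tensors. A runnable sketch of that batching step, with an illustrative `reqs` list standing in for the schedule batch:

```python
import torch
from types import SimpleNamespace

topk = 2  # stands in for server_args.speculative_eagle_topk
reqs = [
    SimpleNamespace(output_topk_p=torch.tensor([0.6, 0.2]),
                    output_topk_index=torch.tensor([5, 9])),
    SimpleNamespace(output_topk_p=torch.tensor([0.4, 0.3]),
                    output_topk_index=torch.tensor([7, 1])),
]

# Stack each request's row into (batch, topk) tensors, as the hunk does.
topk_p = torch.stack(
    [torch.as_tensor(r.output_topk_p[:topk], dtype=torch.float32) for r in reqs],
    dim=0,
)
topk_index = torch.stack(
    [torch.as_tensor(r.output_topk_index[:topk], dtype=torch.int64) for r in reqs],
    dim=0,
)
print(topk_p.shape, topk_index.shape)  # torch.Size([2, 2]) torch.Size([2, 2])
```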
python/sglang/srt/disaggregation/prefill.py

@@ -421,6 +421,8 @@ class SchedulerDisaggregationPrefillMixin:
             last_hidden_index = (
                 hidden_state_offset + extend_input_len_per_req[i] - 1
             )
+            req.output_topk_p = batch.spec_info.topk_p[i]
+            req.output_topk_index = batch.spec_info.topk_index[i]
             if self.spec_algorithm.is_eagle3():
                 req.hidden_states_tensor = (
                     batch.spec_info.hidden_states[i].cpu().clone()
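Illustrative sketch of this prefill-side capture: `batch.spec_info` holds batched `(num_reqs, topk)` tensors, and each request keeps its own row so it can later be staged into the transfer buffers. The `spec_info` object below is a simplified stand-in for sglang's EagleDraftInput-style structure, not the real class.

```python
import torch
from types import SimpleNamespace

spec_info = SimpleNamespace(
    topk_p=torch.rand(2, 2),                     # (num_reqs, topk)
    topk_index=torch.randint(0, 32000, (2, 2)),  # (num_reqs, topk)
)
reqs = [SimpleNamespace(output_topk_p=None, output_topk_index=None)
        for _ in range(2)]

# Each request keeps its own row of the batched speculative output.
for i, req in enumerate(reqs):
    req.output_topk_p = spec_info.topk_p[i]
    req.output_topk_index = spec_info.topk_index[i]
```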
python/sglang/srt/disaggregation/utils.py

@@ -85,7 +85,7 @@ class MetadataBuffers:
         self,
         size: int,
         hidden_size: int,
-        dtype: torch.dtype,
+        hidden_states_dtype: torch.dtype,
         max_top_logprobs_num: int = 128,
         custom_mem_pool: torch.cuda.MemPool = None,
     ):
@@ -122,8 +122,15 @@ class MetadataBuffers:
         self.output_top_logprobs_idx = torch.zeros(
             (size, max_top_logprobs_num), dtype=torch.int32, device=device
         )
+        # For PD + spec decode
+        self.output_topk_p = torch.zeros(
+            (size, 16), dtype=torch.float32, device=device
+        )
+        self.output_topk_index = torch.zeros(
+            (size, 16), dtype=torch.int64, device=device
+        )
         self.output_hidden_states = torch.zeros(
-            (size, hidden_size), dtype=dtype, device=device
+            (size, hidden_size), dtype=hidden_states_dtype, device=device
         )

     def get_buf_infos(self):
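The new buffers are sized `(size, 16)` because `speculative_eagle_topk` is capped at 16, which keeps every row a fixed byte length for the transfer engine. A minimal sketch of what `get_buf_infos` reports for them (base pointer, total byte length, and per-slot byte length); the two-buffer setup here is illustrative, not the full buffer set:

```python
import torch

size = 4  # request slots
output_topk_p = torch.zeros((size, 16), dtype=torch.float32)
output_topk_index = torch.zeros((size, 16), dtype=torch.int64)

buffers = [output_topk_p, output_topk_index]
ptrs = [b.data_ptr() for b in buffers]      # base addresses for the engine
data_lens = [b.nbytes for b in buffers]     # whole-buffer byte counts
item_lens = [b[0].nbytes for b in buffers]  # one request slot each

print(data_lens, item_lens)  # [256, 512] [64, 128]
```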
@@ -134,6 +141,8 @@ class MetadataBuffers:
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
             self.output_top_logprobs_idx.data_ptr(),
+            self.output_topk_p.data_ptr(),
+            self.output_topk_index.data_ptr(),
             self.output_hidden_states.data_ptr(),
         ]
         data_lens = [
@@ -143,6 +152,8 @@ class MetadataBuffers:
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
             self.output_top_logprobs_idx.nbytes,
+            self.output_topk_p.nbytes,
+            self.output_topk_index.nbytes,
             self.output_hidden_states.nbytes,
         ]
         item_lens = [
@@ -152,6 +163,8 @@ class MetadataBuffers:
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
             self.output_top_logprobs_idx[0].nbytes,
+            self.output_topk_p[0].nbytes,
+            self.output_topk_index[0].nbytes,
             self.output_hidden_states[0].nbytes,
         ]
         return ptrs, data_lens, item_lens
@@ -164,6 +177,8 @@ class MetadataBuffers:
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
             self.output_top_logprobs_idx[idx],
+            self.output_topk_p[idx],
+            self.output_topk_index[idx],
             self.output_hidden_states[idx],
         )
@@ -193,8 +208,17 @@ class MetadataBuffers:
             ] = torch.tensor(
                 req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
             )
-        # for PD + spec decode
+        # For PD + spec decode
         if req.hidden_states_tensor is not None:
+            # speculative_eagle_topk should not be greater than 16 currently
+            topk = req.output_topk_p.size(0)
+            self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
+                req.output_topk_p
+            )
+            self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
+                req.output_topk_index
+            )
             self.output_hidden_states[req.metadata_buffer_index].copy_(
                 req.hidden_states_tensor
             )
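A hedged round-trip sketch of this `set_buf` path: the producer writes a request's k top-k values into the first k columns of its buffer row (k = `speculative_eagle_topk`, capped at 16), and the consumer later slices the same row back out. Stand-in objects below, not the real `MetadataBuffers`.

```python
import torch
from types import SimpleNamespace

output_topk_p = torch.zeros((4, 16), dtype=torch.float32)
req = SimpleNamespace(metadata_buffer_index=2,
                      output_topk_p=torch.tensor([0.7, 0.2, 0.1]))

topk = req.output_topk_p.size(0)                 # 3 <= 16
output_topk_p[req.metadata_buffer_index, :topk].copy_(req.output_topk_p)

row = output_topk_p[req.metadata_buffer_index]   # what get_buf would return
print(row[:topk])                                # tensor([0.7000, 0.2000, 0.1000])
```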
python/sglang/srt/managers/schedule_batch.py

@@ -607,6 +607,8 @@ class Req:
         ) = None
         self.hidden_states: List[List[float]] = []
         self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
+        self.output_topk_p = None
+        self.output_topk_index = None

         # Embedding (return values)
         self.embedding = None
python/sglang/srt/managers/scheduler.py

@@ -806,7 +806,7 @@ class Scheduler(
         self.disagg_metadata_buffers = MetadataBuffers(
             buffer_size,
             hidden_size=self.model_config.hf_text_config.hidden_size,
-            dtype=self.model_config.dtype,
+            hidden_states_dtype=self.model_config.dtype,
             custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
         )

@@ -855,7 +855,7 @@ class Scheduler(
         self.disagg_metadata_buffers = MetadataBuffers(
             buffer_size,
             hidden_size=self.model_config.hf_text_config.hidden_size,
-            dtype=self.model_config.dtype,
+            hidden_states_dtype=self.model_config.dtype,
             custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
         )
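These two hunks only track the rename: `MetadataBuffers` now takes `hidden_states_dtype` instead of `dtype`, making explicit that this dtype governs only the hidden-states buffer (the logprob and topk buffers have fixed dtypes). A minimal construction sketch, assuming an sglang checkout that includes this commit; the argument values are illustrative, not a working server setup:

```python
import torch
from sglang.srt.disaggregation.utils import MetadataBuffers

buffers = MetadataBuffers(
    size=64,                             # request slots (buffer_size above)
    hidden_size=4096,                    # model hidden size
    hidden_states_dtype=torch.bfloat16,  # was: dtype=...
)
```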