Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e23e280e
Unverified
Commit
e23e280e
authored
Sep 28, 2025
by
Shangming Cai
Committed by
GitHub
Sep 28, 2025
Browse files
Add support for topk metadata transferring for PD (#10616)
Signed-off-by:
Shangming Cai
<
csmthu@gmail.com
>
parent
51f7c6bd
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
60 additions
and
20 deletions
+60
-20
python/sglang/srt/disaggregation/decode.py
python/sglang/srt/disaggregation/decode.py
+4
-0
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
.../sglang/srt/disaggregation/decode_schedule_batch_mixin.py
+23
-15
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+2
-0
python/sglang/srt/disaggregation/utils.py
python/sglang/srt/disaggregation/utils.py
+27
-3
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+2
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+2
-2
No files found.
python/sglang/srt/disaggregation/decode.py
View file @
e23e280e
...
@@ -614,12 +614,16 @@ class DecodeTransferQueue:
...
@@ -614,12 +614,16 @@ class DecodeTransferQueue:
output_token_logprobs_idx
,
output_token_logprobs_idx
,
output_top_logprobs_val
,
output_top_logprobs_val
,
output_top_logprobs_idx
,
output_top_logprobs_idx
,
output_topk_p
,
output_topk_index
,
output_hidden_states
,
output_hidden_states
,
)
=
self
.
metadata_buffers
.
get_buf
(
idx
)
)
=
self
.
metadata_buffers
.
get_buf
(
idx
)
decode_req
.
req
.
output_ids
.
append
(
output_id
[
0
].
item
())
decode_req
.
req
.
output_ids
.
append
(
output_id
[
0
].
item
())
decode_req
.
req
.
cached_tokens
=
cached_tokens
[
0
].
item
()
decode_req
.
req
.
cached_tokens
=
cached_tokens
[
0
].
item
()
if
not
self
.
spec_algorithm
.
is_none
():
if
not
self
.
spec_algorithm
.
is_none
():
decode_req
.
req
.
output_topk_p
=
output_topk_p
decode_req
.
req
.
output_topk_index
=
output_topk_index
decode_req
.
req
.
hidden_states_tensor
=
output_hidden_states
decode_req
.
req
.
hidden_states_tensor
=
output_hidden_states
if
decode_req
.
req
.
return_logprob
:
if
decode_req
.
req
.
return_logprob
:
decode_req
.
req
.
output_token_logprobs_val
.
append
(
decode_req
.
req
.
output_token_logprobs_val
.
append
(
...
...
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
View file @
e23e280e
...
@@ -125,25 +125,33 @@ class ScheduleBatchDisaggregationDecodeMixin:
...
@@ -125,25 +125,33 @@ class ScheduleBatchDisaggregationDecodeMixin:
req
.
grammar
.
finished
=
req
.
finished
()
req
.
grammar
.
finished
=
req
.
finished
()
self
.
output_ids
=
torch
.
tensor
(
self
.
output_ids
,
device
=
self
.
device
)
self
.
output_ids
=
torch
.
tensor
(
self
.
output_ids
,
device
=
self
.
device
)
# Simulate the eagle run. We add mock data to hidden states for the
# Simulate the eagle run.
# ease of implementation now meaning the first token will have acc rate
if
self
.
spec_algorithm
.
is_eagle
():
# of 0.
if
not
self
.
spec_algorithm
.
is_none
():
b
=
len
(
self
.
reqs
)
b
=
len
(
self
.
reqs
)
topk_p
=
torch
.
arange
(
topk
=
server_args
.
speculative_eagle_topk
b
*
server_args
.
speculative_eagle_topk
,
topk_p
=
torch
.
stack
(
0
,
[
-
1
,
torch
.
as_tensor
(
device
=
self
.
device
,
req
.
output_topk_p
[:
topk
],
dtype
=
torch
.
float32
,
device
=
self
.
device
,
dtype
=
torch
.
float32
,
)
for
req
in
self
.
reqs
],
dim
=
0
,
)
)
topk_p
=
topk_p
.
reshape
(
b
,
server_args
.
speculative_eagle_topk
)
topk_index
=
torch
.
stack
(
topk_p
/=
b
*
server_args
.
speculative_eagle_topk
[
topk_index
=
torch
.
arange
(
torch
.
as_tensor
(
b
*
server_args
.
speculative_eagle_topk
,
device
=
self
.
device
req
.
output_topk_index
[:
topk
],
device
=
self
.
device
,
dtype
=
torch
.
int64
,
)
for
req
in
self
.
reqs
],
dim
=
0
,
)
)
topk_index
=
topk_index
.
reshape
(
b
,
server_args
.
speculative_eagle_topk
)
hidden_states_list
=
[
req
.
hidden_states_tensor
for
req
in
self
.
reqs
]
hidden_states_list
=
[
req
.
hidden_states_tensor
for
req
in
self
.
reqs
]
hidden_states
=
torch
.
stack
(
hidden_states_list
,
dim
=
0
).
to
(
self
.
device
)
hidden_states
=
torch
.
stack
(
hidden_states_list
,
dim
=
0
).
to
(
self
.
device
)
...
...
python/sglang/srt/disaggregation/prefill.py
View file @
e23e280e
...
@@ -421,6 +421,8 @@ class SchedulerDisaggregationPrefillMixin:
...
@@ -421,6 +421,8 @@ class SchedulerDisaggregationPrefillMixin:
last_hidden_index
=
(
last_hidden_index
=
(
hidden_state_offset
+
extend_input_len_per_req
[
i
]
-
1
hidden_state_offset
+
extend_input_len_per_req
[
i
]
-
1
)
)
req
.
output_topk_p
=
batch
.
spec_info
.
topk_p
[
i
]
req
.
output_topk_index
=
batch
.
spec_info
.
topk_index
[
i
]
if
self
.
spec_algorithm
.
is_eagle3
():
if
self
.
spec_algorithm
.
is_eagle3
():
req
.
hidden_states_tensor
=
(
req
.
hidden_states_tensor
=
(
batch
.
spec_info
.
hidden_states
[
i
].
cpu
().
clone
()
batch
.
spec_info
.
hidden_states
[
i
].
cpu
().
clone
()
...
...
python/sglang/srt/disaggregation/utils.py
View file @
e23e280e
...
@@ -85,7 +85,7 @@ class MetadataBuffers:
...
@@ -85,7 +85,7 @@ class MetadataBuffers:
self
,
self
,
size
:
int
,
size
:
int
,
hidden_size
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
hidden_states_
dtype
:
torch
.
dtype
,
max_top_logprobs_num
:
int
=
128
,
max_top_logprobs_num
:
int
=
128
,
custom_mem_pool
:
torch
.
cuda
.
MemPool
=
None
,
custom_mem_pool
:
torch
.
cuda
.
MemPool
=
None
,
):
):
...
@@ -122,8 +122,15 @@ class MetadataBuffers:
...
@@ -122,8 +122,15 @@ class MetadataBuffers:
self
.
output_top_logprobs_idx
=
torch
.
zeros
(
self
.
output_top_logprobs_idx
=
torch
.
zeros
(
(
size
,
max_top_logprobs_num
),
dtype
=
torch
.
int32
,
device
=
device
(
size
,
max_top_logprobs_num
),
dtype
=
torch
.
int32
,
device
=
device
)
)
# For PD + spec decode
self
.
output_topk_p
=
torch
.
zeros
(
(
size
,
16
),
dtype
=
torch
.
float32
,
device
=
device
)
self
.
output_topk_index
=
torch
.
zeros
(
(
size
,
16
),
dtype
=
torch
.
int64
,
device
=
device
)
self
.
output_hidden_states
=
torch
.
zeros
(
self
.
output_hidden_states
=
torch
.
zeros
(
(
size
,
hidden_size
),
dtype
=
dtype
,
device
=
device
(
size
,
hidden_size
),
dtype
=
hidden_states_
dtype
,
device
=
device
)
)
def
get_buf_infos
(
self
):
def
get_buf_infos
(
self
):
...
@@ -134,6 +141,8 @@ class MetadataBuffers:
...
@@ -134,6 +141,8 @@ class MetadataBuffers:
self
.
output_token_logprobs_idx
.
data_ptr
(),
self
.
output_token_logprobs_idx
.
data_ptr
(),
self
.
output_top_logprobs_val
.
data_ptr
(),
self
.
output_top_logprobs_val
.
data_ptr
(),
self
.
output_top_logprobs_idx
.
data_ptr
(),
self
.
output_top_logprobs_idx
.
data_ptr
(),
self
.
output_topk_p
.
data_ptr
(),
self
.
output_topk_index
.
data_ptr
(),
self
.
output_hidden_states
.
data_ptr
(),
self
.
output_hidden_states
.
data_ptr
(),
]
]
data_lens
=
[
data_lens
=
[
...
@@ -143,6 +152,8 @@ class MetadataBuffers:
...
@@ -143,6 +152,8 @@ class MetadataBuffers:
self
.
output_token_logprobs_idx
.
nbytes
,
self
.
output_token_logprobs_idx
.
nbytes
,
self
.
output_top_logprobs_val
.
nbytes
,
self
.
output_top_logprobs_val
.
nbytes
,
self
.
output_top_logprobs_idx
.
nbytes
,
self
.
output_top_logprobs_idx
.
nbytes
,
self
.
output_topk_p
.
nbytes
,
self
.
output_topk_index
.
nbytes
,
self
.
output_hidden_states
.
nbytes
,
self
.
output_hidden_states
.
nbytes
,
]
]
item_lens
=
[
item_lens
=
[
...
@@ -152,6 +163,8 @@ class MetadataBuffers:
...
@@ -152,6 +163,8 @@ class MetadataBuffers:
self
.
output_token_logprobs_idx
[
0
].
nbytes
,
self
.
output_token_logprobs_idx
[
0
].
nbytes
,
self
.
output_top_logprobs_val
[
0
].
nbytes
,
self
.
output_top_logprobs_val
[
0
].
nbytes
,
self
.
output_top_logprobs_idx
[
0
].
nbytes
,
self
.
output_top_logprobs_idx
[
0
].
nbytes
,
self
.
output_topk_p
[
0
].
nbytes
,
self
.
output_topk_index
[
0
].
nbytes
,
self
.
output_hidden_states
[
0
].
nbytes
,
self
.
output_hidden_states
[
0
].
nbytes
,
]
]
return
ptrs
,
data_lens
,
item_lens
return
ptrs
,
data_lens
,
item_lens
...
@@ -164,6 +177,8 @@ class MetadataBuffers:
...
@@ -164,6 +177,8 @@ class MetadataBuffers:
self
.
output_token_logprobs_idx
[
idx
],
self
.
output_token_logprobs_idx
[
idx
],
self
.
output_top_logprobs_val
[
idx
],
self
.
output_top_logprobs_val
[
idx
],
self
.
output_top_logprobs_idx
[
idx
],
self
.
output_top_logprobs_idx
[
idx
],
self
.
output_topk_p
[
idx
],
self
.
output_topk_index
[
idx
],
self
.
output_hidden_states
[
idx
],
self
.
output_hidden_states
[
idx
],
)
)
...
@@ -193,8 +208,17 @@ class MetadataBuffers:
...
@@ -193,8 +208,17 @@ class MetadataBuffers:
]
=
torch
.
tensor
(
]
=
torch
.
tensor
(
req
.
output_top_logprobs_idx
[
0
],
dtype
=
torch
.
int32
,
device
=
"cpu"
req
.
output_top_logprobs_idx
[
0
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
)
#
f
or PD + spec decode
#
F
or PD + spec decode
if
req
.
hidden_states_tensor
is
not
None
:
if
req
.
hidden_states_tensor
is
not
None
:
# speculative_eagle_topk should not be greater than 16 currently
topk
=
req
.
output_topk_p
.
size
(
0
)
self
.
output_topk_p
[
req
.
metadata_buffer_index
,
:
topk
].
copy_
(
req
.
output_topk_p
)
self
.
output_topk_index
[
req
.
metadata_buffer_index
,
:
topk
].
copy_
(
req
.
output_topk_index
)
self
.
output_hidden_states
[
req
.
metadata_buffer_index
].
copy_
(
self
.
output_hidden_states
[
req
.
metadata_buffer_index
].
copy_
(
req
.
hidden_states_tensor
req
.
hidden_states_tensor
)
)
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
e23e280e
...
@@ -607,6 +607,8 @@ class Req:
...
@@ -607,6 +607,8 @@ class Req:
)
=
None
)
=
None
self
.
hidden_states
:
List
[
List
[
float
]]
=
[]
self
.
hidden_states
:
List
[
List
[
float
]]
=
[]
self
.
hidden_states_tensor
=
None
# Note: use tensor instead of list to transfer hidden_states when PD + MTP
self
.
hidden_states_tensor
=
None
# Note: use tensor instead of list to transfer hidden_states when PD + MTP
self
.
output_topk_p
=
None
self
.
output_topk_index
=
None
# Embedding (return values)
# Embedding (return values)
self
.
embedding
=
None
self
.
embedding
=
None
...
...
python/sglang/srt/managers/scheduler.py
View file @
e23e280e
...
@@ -806,7 +806,7 @@ class Scheduler(
...
@@ -806,7 +806,7 @@ class Scheduler(
self
.
disagg_metadata_buffers
=
MetadataBuffers
(
self
.
disagg_metadata_buffers
=
MetadataBuffers
(
buffer_size
,
buffer_size
,
hidden_size
=
self
.
model_config
.
hf_text_config
.
hidden_size
,
hidden_size
=
self
.
model_config
.
hf_text_config
.
hidden_size
,
dtype
=
self
.
model_config
.
dtype
,
hidden_states_
dtype
=
self
.
model_config
.
dtype
,
custom_mem_pool
=
self
.
token_to_kv_pool_allocator
.
get_kvcache
().
maybe_get_custom_mem_pool
(),
custom_mem_pool
=
self
.
token_to_kv_pool_allocator
.
get_kvcache
().
maybe_get_custom_mem_pool
(),
)
)
...
@@ -855,7 +855,7 @@ class Scheduler(
...
@@ -855,7 +855,7 @@ class Scheduler(
self
.
disagg_metadata_buffers
=
MetadataBuffers
(
self
.
disagg_metadata_buffers
=
MetadataBuffers
(
buffer_size
,
buffer_size
,
hidden_size
=
self
.
model_config
.
hf_text_config
.
hidden_size
,
hidden_size
=
self
.
model_config
.
hf_text_config
.
hidden_size
,
dtype
=
self
.
model_config
.
dtype
,
hidden_states_
dtype
=
self
.
model_config
.
dtype
,
custom_mem_pool
=
self
.
token_to_kv_pool_allocator
.
get_kvcache
().
maybe_get_custom_mem_pool
(),
custom_mem_pool
=
self
.
token_to_kv_pool_allocator
.
get_kvcache
().
maybe_get_custom_mem_pool
(),
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment