sglang commit 4f838c09 (unverified)
Authored Jun 20, 2025 by Atream; committed by GitHub on Jun 19, 2025

[PD] Transfer hidden states for mtp when disaggregation (#7242)

Parent: d20a073b
Showing 6 changed files with 43 additions and 6 deletions (+43 -6):

  python/sglang/srt/disaggregation/decode.py                        +4  -1
  python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py   +4  -3
  python/sglang/srt/disaggregation/prefill.py                       +12 -0
  python/sglang/srt/disaggregation/utils.py                         +14 -0
  python/sglang/srt/managers/schedule_batch.py                      +1  -0
  python/sglang/srt/managers/scheduler.py                           +8  -2
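Taken together, the six files thread a per-request hidden-state tensor from the prefill worker to the decode worker, so MTP (multi-token prediction) speculative decoding can draft from the real last hidden state instead of a placeholder. Below is a minimal, self-contained sketch of that hand-off; `SimpleMetadataBuffers` and the shapes are illustrative stand-ins, not the actual sglang classes.

```python
# Minimal sketch of the hand-off this commit enables: the prefill worker
# writes the first output token *and* its last hidden state into a
# per-request metadata slot, and the decode worker reads both back after
# the transfer. SimpleMetadataBuffers is a hypothetical stand-in.
import torch


class SimpleMetadataBuffers:
    def __init__(self, size: int, hidden_size: int, dtype: torch.dtype):
        # One row per in-flight request, mirroring MetadataBuffers in utils.py.
        self.output_ids = torch.zeros((size, 16), dtype=torch.int32)
        self.output_hidden_states = torch.zeros((size, hidden_size), dtype=dtype)

    def set_buf(self, idx: int, output_id: int, hidden_state: torch.Tensor) -> None:
        # Prefill side: stash the first output token and its hidden state.
        self.output_ids[idx][0] = output_id
        self.output_hidden_states[idx].copy_(hidden_state)

    def get_buf(self, idx: int):
        # Decode side: read both back once the transfer has completed.
        return self.output_ids[idx], self.output_hidden_states[idx]


bufs = SimpleMetadataBuffers(size=4, hidden_size=8, dtype=torch.float16)
bufs.set_buf(0, output_id=42, hidden_state=torch.randn(8, dtype=torch.float16))
output_id, hidden = bufs.get_buf(0)
print(output_id[0].item(), hidden.shape)  # 42 torch.Size([8])
```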
python/sglang/srt/disaggregation/decode.py

@@ -541,6 +541,7 @@ class DecodeTransferQueue:
         self.metadata_buffers = metadata_buffers
         self.scheduler = scheduler
         self.tree_cache = tree_cache
+        self.spec_algorithm = scheduler.spec_algorithm

     def add(self, decode_req: DecodeRequest) -> None:
         self.queue.append(decode_req)
@@ -582,6 +583,7 @@ class DecodeTransferQueue:
             idx = decode_req.metadata_buffer_index
             (
                 output_id,
+                output_hidden_states,
                 output_token_logprobs_val,
                 output_token_logprobs_idx,
                 output_top_logprobs_val,
@@ -589,7 +591,8 @@ class DecodeTransferQueue:
             ) = self.metadata_buffers.get_buf(idx)

             decode_req.req.output_ids.append(output_id[0].item())
-
+            if not self.spec_algorithm.is_none():
+                decode_req.req.hidden_states_tensor = output_hidden_states
             if decode_req.req.return_logprob:
                 decode_req.req.output_token_logprobs_val.append(
                     output_token_logprobs_val[0].item()
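The decode-side transfer queue now unpacks an extra `output_hidden_states` slot from `get_buf` and attaches it to the request only when a speculative algorithm is configured, so plain decoding pays nothing extra. A small stand-alone sketch of that guard; the `SpecAlgorithm` and `Req` classes below are simplified stand-ins for illustration, not the sglang types.

```python
# Hedged sketch of the new decode-side guard: the hidden state is attached to
# the request only when speculative decoding (e.g. MTP/EAGLE) is active.
import torch


class SpecAlgorithm:
    def __init__(self, name=None):
        self.name = name

    def is_none(self) -> bool:
        return self.name is None


class Req:
    def __init__(self):
        self.hidden_states_tensor = None


def attach(req: Req, output_hidden_states: torch.Tensor, spec: SpecAlgorithm) -> None:
    # Mirrors the added branch in DecodeTransferQueue: copy only when needed.
    if not spec.is_none():
        req.hidden_states_tensor = output_hidden_states


req = Req()
attach(req, torch.randn(8), SpecAlgorithm("EAGLE"))
print(req.hidden_states_tensor is not None)  # True

req2 = Req()
attach(req2, torch.randn(8), SpecAlgorithm())
print(req2.hidden_states_tensor is None)     # True
```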
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

@@ -126,15 +126,16 @@ class ScheduleBatchDisaggregationDecodeMixin:
             )
             topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
+
+            hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
+            hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)

             # local import to avoid circular import
             from sglang.srt.speculative.eagle_utils import EagleDraftInput

             spec_info = EagleDraftInput(
                 topk_p=topk_p,
                 topk_index=topk_index,
-                hidden_states=torch.ones(
-                    (b, model_config.hidden_size), device=self.device
-                ),
+                hidden_states=hidden_states,
                 verified_id=self.output_ids,
             )
             spec_info.prepare_for_extend(self)
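Previously the draft input was seeded with a `torch.ones` placeholder; now the per-request hidden states received from prefill are stacked into the `(b, hidden_size)` batch tensor that `EagleDraftInput` expects. A shape-level sketch, with illustrative sizes:

```python
# Sketch: stack one hidden-state tensor per request into a (b, hidden_size)
# batch, as the mixin now does before building EagleDraftInput. In the real
# code the result is also moved to the scheduler's device with .to(device).
import torch

hidden_size = 8
reqs_hidden = [torch.randn(hidden_size) for _ in range(3)]  # one tensor per request
hidden_states = torch.stack(reqs_hidden, dim=0)

assert hidden_states.shape == (3, hidden_size)
print(hidden_states.shape)  # torch.Size([3, 8])
```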
python/sglang/srt/disaggregation/prefill.py

@@ -393,6 +393,8 @@ class SchedulerDisaggregationPrefillMixin:
             logits_output.input_token_logprobs = tuple(
                 logits_output.input_token_logprobs.tolist()
             )
+
+        hidden_state_offset = 0
         for i, (req, next_token_id) in enumerate(
             zip(batch.reqs, next_token_ids, strict=True)
         ):
@@ -402,6 +404,16 @@ class SchedulerDisaggregationPrefillMixin:
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
                 self.disagg_prefill_inflight_queue.append(req)
+                if logits_output.hidden_states is not None:
+                    last_hidden_index = (
+                        hidden_state_offset + extend_input_len_per_req[i] - 1
+                    )
+                    req.hidden_states_tensor = (
+                        logits_output.hidden_states[last_hidden_index].cpu().clone()
+                    )
+                    hidden_state_offset += extend_input_len_per_req[i]
+                else:
+                    req.hidden_states_tensor = None
                 if req.return_logprob:
                     assert extend_logprob_start_len_per_req is not None
                     assert extend_input_len_per_req is not None
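`logits_output.hidden_states` is flattened across the whole prefill batch, so the last hidden state of request `i` sits at a running offset plus that request's extend length minus one. A small worked example of the indexing; the lengths below are made up:

```python
# Worked example of the running-offset indexing added in prefill.py: each
# request owns `extend_input_len_per_req[i]` consecutive rows of the flattened
# hidden-state tensor, and only the last row per request is kept.
import torch

extend_input_len_per_req = [3, 5, 2]                          # tokens prefilled per request
hidden = torch.arange(10, dtype=torch.float32).unsqueeze(1)   # 3 + 5 + 2 = 10 rows

hidden_state_offset = 0
last_rows = []
for length in extend_input_len_per_req:
    last_hidden_index = hidden_state_offset + length - 1
    last_rows.append(hidden[last_hidden_index].cpu().clone())
    hidden_state_offset += length

print([row.item() for row in last_rows])  # [2.0, 7.0, 9.0] -> rows 2, 7, 9
```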
python/sglang/srt/disaggregation/utils.py

@@ -88,6 +88,8 @@ class MetadataBuffers:
     def __init__(
         self,
         size: int,
+        hidden_size: int,
+        dtype: torch.dtype,
         max_top_logprobs_num: int = 128,
         custom_mem_pool: torch.cuda.MemPool = None,
     ):
@@ -104,6 +106,10 @@ class MetadataBuffers:
         # We transfer the metadata of first output token to decode
         # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
         self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
+        self.output_hidden_states = torch.zeros(
+            (size, hidden_size), dtype=dtype, device=device
+        )
         self.output_token_logprobs_val = torch.zeros(
             (size, 16), dtype=torch.float32, device=device
         )
@@ -120,6 +126,7 @@ class MetadataBuffers:
     def get_buf_infos(self):
         ptrs = [
             self.output_ids.data_ptr(),
+            self.output_hidden_states.data_ptr(),  # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
             self.output_token_logprobs_val.data_ptr(),
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
@@ -127,6 +134,7 @@ class MetadataBuffers:
         ]
         data_lens = [
             self.output_ids.nbytes,
+            self.output_hidden_states.nbytes,
             self.output_token_logprobs_val.nbytes,
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
@@ -134,6 +142,7 @@ class MetadataBuffers:
         ]
         item_lens = [
             self.output_ids[0].nbytes,
+            self.output_hidden_states[0].nbytes,
             self.output_token_logprobs_val[0].nbytes,
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
@@ -144,6 +153,7 @@ class MetadataBuffers:
     def get_buf(self, idx: int):
         return (
             self.output_ids[idx],
+            self.output_hidden_states[idx],
             self.output_token_logprobs_val[idx],
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
@@ -153,6 +163,10 @@ class MetadataBuffers:
     def set_buf(self, req: Req):
         self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.hidden_states_tensor is not None:
+            self.output_hidden_states[req.metadata_buffer_index].copy_(
+                req.hidden_states_tensor
+            )
         if req.return_logprob:
             if req.output_token_logprobs_val:  # not none or empty list
                 self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
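The existing metadata tensors are padded to at least 64 bytes per item because of the RDMA minimum noted in the comment; the new hidden-state slot is instead sized by the model's hidden dimension and dtype, which dominates the per-request payload. A quick illustration of the resulting item sizes, with illustrative dimensions:

```python
# Illustration of per-item buffer sizes in a MetadataBuffers-like layout:
# output_ids is padded to 16 int32 values (64 bytes) to clear the RDMA floor,
# while the hidden-state slot scales with hidden_size and dtype.
import torch

size, hidden_size = 4, 4096
output_ids = torch.zeros((size, 16), dtype=torch.int32)
output_hidden_states = torch.zeros((size, hidden_size), dtype=torch.bfloat16)

print(output_ids[0].nbytes)            # 64 bytes per request slot
print(output_hidden_states[0].nbytes)  # 8192 bytes per request slot (bf16 * 4096)
```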
python/sglang/srt/managers/schedule_batch.py

@@ -584,6 +584,7 @@ class Req:
             self.output_token_ids_logprobs_idx
         ) = None
         self.hidden_states: List[List[float]] = []
+        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP

         # Embedding (return values)
         self.embedding = None
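Per the added note, the new field holds a tensor rather than reusing the list-of-floats `hidden_states` field, so the prefill side can write it into the preallocated transfer buffer without a Python-list round trip. A brief illustration of that difference; the buffer row below is hypothetical:

```python
# Contrast between the existing list field and the new tensor field: a tensor
# can be written into a preallocated buffer row with an in-place copy_,
# preserving dtype and avoiding per-element Python conversion.
import torch

hidden_states: list = []                        # existing per-step list field
hidden_states_tensor = torch.randn(8)           # new field: one tensor per request

buffer_row = torch.zeros(8)
buffer_row.copy_(hidden_states_tensor)          # direct in-place copy into the buffer
print(torch.equal(buffer_row, hidden_states_tensor))  # True
```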
python/sglang/srt/managers/scheduler.py

@@ -627,6 +627,8 @@ class Scheduler(
             )
             self.disagg_metadata_buffers = MetadataBuffers(
                 buffer_size,
+                hidden_size=self.model_config.hf_text_config.hidden_size,
+                dtype=self.model_config.dtype,
                 custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
             )

@@ -677,6 +679,8 @@ class Scheduler(
             )
             self.disagg_metadata_buffers = MetadataBuffers(
                 buffer_size,
+                hidden_size=self.model_config.hf_text_config.hidden_size,
+                dtype=self.model_config.dtype,
                 custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
             )

@@ -1681,13 +1685,15 @@ class Scheduler(
         # These 2 values are needed for processing the output, but the values can be
         # modified by overlap schedule. So we have to copy them here so that
         # we can use the correct values in output processing.
-        if batch.return_logprob:
+        if batch.return_logprob or self.spec_algorithm.is_eagle():
             extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
+        else:
+            extend_input_len_per_req = None
+        if batch.return_logprob:
             extend_logprob_start_len_per_req = [
                 req.extend_logprob_start_len for req in batch.reqs
             ]
         else:
-            extend_input_len_per_req = None
             extend_logprob_start_len_per_req = None

         ret = GenerationBatchResult(
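`extend_input_len_per_req` is now computed whenever EAGLE-style speculation is enabled, not only when logprobs are requested, because the prefill-side hidden-state indexing above depends on it; the logprob start lengths still follow `return_logprob` alone. A compact sketch of the restructured condition, with a hypothetical helper name:

```python
# Sketch of the restructured scheduler condition: extend lengths are needed by
# both logprob processing and the new hidden-state indexing, while logprob
# start lengths remain tied to return_logprob only.
def pick_lengths(return_logprob, spec_is_eagle, extend_lens, logprob_start_lens):
    extend_input_len_per_req = (
        extend_lens if (return_logprob or spec_is_eagle) else None
    )
    extend_logprob_start_len_per_req = logprob_start_lens if return_logprob else None
    return extend_input_len_per_req, extend_logprob_start_len_per_req


# Speculation on, logprobs off: lengths are kept, start lengths are not.
print(pick_lengths(False, True, [3, 5], [0, 0]))  # ([3, 5], None)
```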