Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
44afde82
Unverified
Commit
44afde82
authored
Apr 14, 2025
by
Liangsheng Yin
Committed by
GitHub
Apr 14, 2025
Browse files
Fix PD disaggregation bugs (#5326)
parent
072df753
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
138 additions
and
104 deletions
+138
-104
python/sglang/srt/disaggregation/mooncake/conn.py
python/sglang/srt/disaggregation/mooncake/conn.py
+128
-99
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+10
-5
No files found.
python/sglang/srt/disaggregation/mooncake/conn.py
View file @
44afde82
from
__future__
import
annotations
import
asyncio
import
dataclasses
import
logging
import
queue
import
struct
import
threading
from
functools
import
cache
...
...
@@ -52,10 +54,38 @@ def group_concurrent_contiguous(
return
src_groups
,
dst_groups
# Maps a bootstrap room id to the prefill-side data queued for it:
# (kv indices, aux index or None).
RequestPoolType = Dict[int, Tuple[npt.NDArray[np.int64], Optional[int]]]

# Maps a bootstrap room id to the destination announced for that room:
# (endpoint, mooncake session id, destination kv pointers,
#  destination kv indices, destination aux pointers, destination aux index).
# The tuple has six members to match what is stored in the waiting pool.
WaitingPoolType = Dict[
    int,
    Tuple[str, str, list[int], npt.NDArray[np.int64], list[int], int],
]
@dataclasses.dataclass
class TransferKVChunk:
    """One chunk of a request's KV cache queued for transfer."""

    # Bootstrap room id of the request this chunk belongs to.
    room: int
    # Prefill-side KV indices covered by this chunk.
    prefill_kv_indices: npt.NDArray[np.int64]
    # Slice applied to the destination KV indices to pick this chunk's targets.
    index_slice: slice
    # True when this is the final chunk of the request.
    is_last: bool
    # Prefill-side aux index; None is allowed for non-final chunks.
    prefill_aux_index: Optional[int]
@dataclasses.dataclass
class TransferInfo:
    """Destination details for one room, parsed from a multipart ZMQ message."""

    room: int
    endpoint: str
    mooncake_session_id: str
    dst_kv_ptrs: list[int]
    dst_kv_indices: npt.NDArray[np.int64]
    dst_aux_ptrs: list[int]
    dst_aux_index: int

    @staticmethod
    def _unpack_ptrs(raw: bytes) -> list[int]:
        """Decode a frame of packed unsigned 64-bit values (struct code "Q")."""
        return list(struct.unpack(f"{len(raw) // 8}Q", raw))

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Build a TransferInfo from the frames of a multipart message.

        Frame layout, by index:
            0: endpoint (ASCII)
            1: mooncake session id (ASCII)
            2: room id (ASCII decimal)
            3: destination KV pointers (packed "Q" values)
            4: destination KV indices (raw int64 buffer)
            5: destination aux pointers (packed "Q" values)
            6: destination aux index (ASCII decimal)

        Raises:
            IndexError: if fewer than seven frames are supplied.
            ValueError: if the room or aux index frame is not a decimal integer.
            struct.error: if a pointer frame's length is not a multiple of 8.
        """
        return cls(
            endpoint=msg[0].decode("ascii"),
            mooncake_session_id=msg[1].decode("ascii"),
            room=int(msg[2].decode("ascii")),
            dst_kv_ptrs=cls._unpack_ptrs(msg[3]),
            # frombuffer creates a view over the frame's bytes, not a copy.
            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
            dst_aux_ptrs=cls._unpack_ptrs(msg[5]),
            dst_aux_index=int(msg[6].decode("ascii")),
        )
# Base polling ports. The prefill side binds its socket to
# KVSENDER_POLLING_PORT + engine_rank; the receiver port is presumably
# offset the same way on the decode side (not visible here -- confirm).
KVSENDER_POLLING_PORT = 17788
KVRECEIVER_POLLING_PORT = 27788
...
...
@@ -65,13 +95,12 @@ class MooncakeKVManager(BaseKVManager):
self
.
engine
=
MooncakeTransferEngine
()
self
.
kv_args
=
args
self
.
disaggregation_mode
=
disaggregation_mode
self
.
request_pool
:
RequestPoolType
=
{}
self
.
request_status
:
Dict
[
int
,
KVPoll
]
=
{}
self
.
server_socket
=
zmq
.
Context
().
socket
(
zmq
.
PULL
)
self
.
register_buffer_to_engine
()
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
self
.
waiting_pool
:
WaitingPoolType
=
{}
self
.
transfer_
event
=
threading
.
Event
()
self
.
transfer_queue
=
queue
.
Queue
()
self
.
transfer_
infos
:
Dict
[
int
,
TransferInfo
]
=
{}
self
.
start_prefill_thread
()
elif
self
.
disaggregation_mode
==
DisaggregationMode
.
DECODE
:
self
.
start_decode_thread
()
...
...
@@ -101,7 +130,7 @@ class MooncakeKVManager(BaseKVManager):
self
,
mooncake_session_id
:
str
,
prefill_kv_indices
:
npt
.
NDArray
[
np
.
int64
],
dst_ptrs
:
list
[
int
],
dst_
kv_
ptrs
:
list
[
int
],
dst_kv_indices
:
npt
.
NDArray
[
np
.
int64
],
):
layer_num
=
int
(
len
(
self
.
kv_args
.
kv_data_ptrs
)
/
2
)
...
...
@@ -114,8 +143,8 @@ class MooncakeKVManager(BaseKVManager):
prefill_value_layer_ptr
=
self
.
kv_args
.
kv_data_ptrs
[
layer_num
+
layer_id
]
value_item_len
=
self
.
kv_args
.
kv_item_lens
[
layer_num
+
layer_id
]
decode_key_layer_ptr
=
dst_ptrs
[
layer_id
]
decode_value_layer_ptr
=
dst_ptrs
[
layer_num
+
layer_id
]
decode_key_layer_ptr
=
dst_
kv_
ptrs
[
layer_id
]
decode_value_layer_ptr
=
dst_
kv_
ptrs
[
layer_num
+
layer_id
]
for
prefill_index
,
decode_index
in
zip
(
prefill_kv_blocks
,
dst_kv_blocks
):
prefill_key_addr
=
(
...
...
@@ -192,87 +221,60 @@ class MooncakeKVManager(BaseKVManager):
sender_rank_port
=
KVSENDER_POLLING_PORT
+
self
.
kv_args
.
engine_rank
self
.
server_socket
.
bind
(
"tcp://*:"
+
str
(
sender_rank_port
))
def
prefill_thread
():
def
bootstrap_thread
():
"""This thread recvs pre-alloc notification from the decode engine"""
# KVPoll.Bootstrapping -> KVPoll.WaitingForInput
while
True
:
(
endpoint
,
mooncake_session_id
,
bootstrap_room
,
dst_ptrs
,
dst_kv_indices
,
dst_aux_ptrs
,
dst_aux_index
,
)
=
self
.
server_socket
.
recv_multipart
()
if
bootstrap_room
.
decode
(
"ascii"
)
==
"None"
:
waiting_req_bytes
=
self
.
server_socket
.
recv_multipart
()
room
=
waiting_req_bytes
[
2
].
decode
(
"ascii"
)
if
room
==
"None"
:
continue
endpoint
=
endpoint
.
decode
(
"ascii"
)
mooncake_session_id
=
mooncake_session_id
.
decode
(
"ascii"
)
bootstrap_room
=
int
(
bootstrap_room
.
decode
(
"ascii"
))
dst_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
dst_ptrs
)
//
8
}
Q"
,
dst_ptrs
))
dst_kv_indices
=
np
.
frombuffer
(
dst_kv_indices
,
dtype
=
np
.
int64
)
dst_aux_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
dst_aux_ptrs
)
//
8
}
Q"
,
dst_aux_ptrs
)
)
dst_aux_index
=
int
(
dst_aux_index
.
decode
(
"ascii"
))
self
.
waiting_pool
[
bootstrap_room
]
=
(
endpoint
,
mooncake_session_id
,
dst_ptrs
,
dst_kv_indices
,
dst_aux_ptrs
,
dst_aux_index
,
)
self
.
transfer_event
.
set
()
room
=
int
(
room
)
self
.
transfer_infos
[
room
]
=
TransferInfo
.
from_zmq
(
waiting_req_bytes
)
threading
.
Thread
(
target
=
prefill_thread
).
start
()
# NOTE: after bootstrapping we can mark the req as waiting for input
self
.
request_status
[
room
]
=
KVPoll
.
WaitingForInput
def
transfer_thread
():
# TODO: Shall we use KVPoll.Transferring state?
while
True
:
self
.
transfer_event
.
wait
()
self
.
transfer_event
.
clear
()
bootstrap_room_ready
=
self
.
request_pool
.
keys
()
bootstrap_room_request
=
self
.
waiting_pool
.
keys
()
for
room
in
list
(
bootstrap_room_request
):
if
room
not
in
list
(
bootstrap_room_ready
):
continue
status
=
KVPoll
.
Transferring
self
.
request_status
[
room
]
=
status
(
endpoint
,
mooncake_session_id
,
dst_ptrs
,
dst_kv_indices
,
dst_aux_ptrs
,
dst_aux_index
,
)
=
self
.
waiting_pool
.
pop
(
room
)
self
.
sync_status_to_decode_endpoint
(
endpoint
,
room
)
(
prefill_kv_indices
,
prefill_aux_index
,
)
=
self
.
request_pool
.
pop
(
room
)
try
:
kv_chunk
:
TransferKVChunk
=
self
.
transfer_queue
.
get
(
timeout
=
0.01
)
req
=
self
.
transfer_infos
[
kv_chunk
.
room
]
chunked_dst_kv_indice
=
req
.
dst_kv_indices
[
kv_chunk
.
index_slice
]
assert
len
(
chunked_dst_kv_indice
)
==
len
(
kv_chunk
.
prefill_kv_indices
)
ret
=
self
.
send_kvcache
(
mooncake_session_id
,
prefill_kv_indices
,
dst
_ptrs
,
dst_kv_indice
s
,
req
.
mooncake_session_id
,
kv_chunk
.
prefill_kv_indices
,
req
.
dst_kv
_ptrs
,
chunked_
dst_kv_indice
,
)
if
ret
!=
0
:
s
tatus
=
KVPoll
.
Failed
self
.
sync_status_to_decode_endpoint
(
endpoint
,
room
)
s
elf
.
request_status
[
kv_chunk
.
room
]
=
KVPoll
.
Failed
self
.
sync_status_to_decode_endpoint
(
req
.
endpoint
,
req
.
room
)
continue
if
kv_chunk
.
is_last
:
# The aux data only needs to be sent with the last chunk
ret
=
self
.
send_aux
(
mooncake_session_id
,
prefill_aux_index
,
dst_aux_ptrs
,
dst_aux_index
,
req
.
mooncake_session_id
,
kv_chunk
.
prefill_aux_index
,
req
.
dst_aux_ptrs
,
req
.
dst_aux_index
,
)
if
ret
!=
0
:
status
=
KVPoll
.
Failed
else
:
status
=
KVPoll
.
Success
self
.
request_status
[
room
]
=
status
self
.
sync_status_to_decode_endpoint
(
endpoint
,
room
)
self
.
request_status
[
req
.
room
]
=
(
KVPoll
.
Success
if
ret
==
0
else
KVPoll
.
Failed
)
self
.
sync_status_to_decode_endpoint
(
req
.
endpoint
,
req
.
room
)
self
.
transfer_infos
.
pop
(
req
.
room
)
except
queue
.
Empty
:
continue
threading
.
Thread
(
target
=
bootstrap_thread
).
start
()
threading
.
Thread
(
target
=
transfer_thread
).
start
()
def
start_decode_thread
(
self
):
...
...
@@ -288,29 +290,41 @@ class MooncakeKVManager(BaseKVManager):
threading
.
Thread
(
target
=
decode_thread
).
start
()
def
enqueue
_request
(
def
add_transfer
_request
(
self
,
bootstrap_room
:
int
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
aux_index
:
Optional
[
int
],
index_slice
:
slice
,
is_last
:
bool
,
aux_index
:
Optional
[
int
]
=
None
,
):
self
.
request_pool
[
bootstrap_room
]
=
(
kv_indices
,
aux_index
)
assert
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
assert
not
is_last
or
(
is_last
and
aux_index
is
not
None
)
self
.
transfer_queue
.
put
(
TransferKVChunk
(
room
=
bootstrap_room
,
prefill_kv_indices
=
kv_indices
,
index_slice
=
index_slice
,
is_last
=
is_last
,
prefill_aux_index
=
aux_index
,
)
)
self
.
request_status
[
bootstrap_room
]
=
KVPoll
.
WaitingForInput
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
self
.
transfer_event
.
set
()
def
check_status
(
self
,
bootstrap_room
:
int
):
if
(
self
.
disaggregation_mode
==
DisaggregationMode
.
DECODE
and
self
.
request_status
[
bootstrap_room
]
==
KVPoll
.
Success
):
if
bootstrap_room
in
self
.
request_pool
:
self
.
request_pool
.
pop
(
bootstrap_room
)
# TODO: do we really need the poll()?
return
self
.
request_status
[
bootstrap_room
]
def
set_status
(
self
,
bootstrap_room
:
int
,
status
:
KVPoll
):
def
update_status
(
self
,
bootstrap_room
:
int
,
status
:
KVPoll
):
if
bootstrap_room
not
in
self
.
request_status
:
self
.
request_status
[
bootstrap_room
]
=
status
else
:
# NOTE: The prefill engine could recv bootstrapping first
self
.
request_status
[
bootstrap_room
]
=
max
(
self
.
request_status
[
bootstrap_room
],
status
)
def get_localhost(self):
    """Return the local host as reported by the transfer engine."""
    return self.engine.get_localhost()
...
...
@@ -326,15 +340,31 @@ class MooncakeKVSender(BaseKVSender):
):
self
.
kv_mgr
=
mgr
self
.
bootstrap_room
=
bootstrap_room
self
.
kv_mgr
.
set
_status
(
bootstrap_room
,
KVPoll
.
WaitingForInput
)
self
.
kv_mgr
.
update
_status
(
bootstrap_room
,
KVPoll
.
Bootstrapping
)
self
.
aux_index
=
None
def
init
(
self
,
num_kv_indices
:
int
,
aux_index
:
Optional
[
int
]
=
None
):
self
.
aux_index
=
aux_index
self
.
num_kv_indices
=
num_kv_indices
self
.
aux_index
=
aux_index
def
send
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
]):
self
.
kv_mgr
.
enqueue_request
(
self
.
bootstrap_room
,
kv_indices
,
self
.
aux_index
)
def
send
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
index_slice
:
slice
,
is_last
:
bool
,
):
if
not
is_last
:
self
.
kv_mgr
.
add_transfer_request
(
self
.
bootstrap_room
,
kv_indices
,
index_slice
,
False
)
else
:
self
.
kv_mgr
.
add_transfer_request
(
self
.
bootstrap_room
,
kv_indices
,
index_slice
,
True
,
aux_index
=
self
.
aux_index
,
)
def poll(self) -> KVPoll:
    """Return the KV manager's current status for this sender's bootstrap room."""
    return self.kv_mgr.check_status(self.bootstrap_room)
...
...
@@ -361,7 +391,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
)
self
.
decode_ip
=
self
.
kv_mgr
.
get_localhost
()
self
.
session_id
=
self
.
kv_mgr
.
get_session_id
()
self
.
kv_mgr
.
set
_status
(
bootstrap_room
,
KVPoll
.
WaitingForInput
)
self
.
kv_mgr
.
update
_status
(
bootstrap_room
,
KVPoll
.
WaitingForInput
)
@
cache
def
_connect
(
self
,
endpoint
:
str
):
...
...
@@ -370,7 +400,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
return
socket
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
aux_index
:
Optional
[
int
]
=
None
):
self
.
kv_mgr
.
enqueue_request
(
self
.
bootstrap_room
,
kv_indices
,
aux_index
)
packed_kv_data_ptrs
=
b
""
.
join
(
struct
.
pack
(
"Q"
,
ptr
)
for
ptr
in
self
.
kv_mgr
.
kv_args
.
kv_data_ptrs
)
...
...
python/sglang/srt/disaggregation/prefill.py
View file @
44afde82
...
...
@@ -81,7 +81,7 @@ class PrefillBootstrapQueue:
self
.
gloo_group
=
gloo_group
self
.
bootstrap_port
=
bootstrap_port
def
allocate_token_id
(
self
,
idx
:
int
,
token_id
:
int
):
def
store_prefill_results
(
self
,
idx
:
int
,
token_id
:
int
):
assert
token_id
>=
0
,
f
"token_id:
{
token_id
}
is negative"
output_id_buffer
=
self
.
metadata_buffers
[
0
]
output_id_buffer
[
idx
]
=
token_id
...
...
@@ -146,7 +146,7 @@ class PrefillBootstrapQueue:
elif
poll
==
KVPoll
.
Failed
:
raise
Exception
(
"Bootstrap failed"
)
# KV.WaitingForInput
- init here
# KV.WaitingForInput
num_kv_indices
=
len
(
req
.
origin_input_ids
)
if
self
.
req_to_metadata_buffer_idx_allocator
.
available_size
()
==
0
:
break
...
...
@@ -222,6 +222,7 @@ class SchedulerDisaggregationPrefillMixin:
elif
poll
==
KVPoll
.
Success
:
# transfer done
self
.
tree_cache
.
cache_finished_req
(
req
)
# unlock the tree
req
.
finished_reason
=
FINISH_LENGTH
(
length
=
0
)
# FIXME: clean up req's data in transfer engine
done_reqs
.
append
(
req
)
elif
poll
==
KVPoll
.
Failed
:
raise
Exception
(
"Transferring failed"
)
...
...
@@ -256,14 +257,18 @@ class SchedulerDisaggregationPrefillMixin:
"""
start_idx
=
req
.
start_send_idx
end_idx
=
min
(
len
(
req
.
fill_ids
),
len
(
req
.
origin_input_ids
))
# Update next start_send_idx
req
.
start_send_idx
=
end_idx
kv_indices
=
(
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
][
start_idx
:
end_idx
]
.
cpu
()
.
numpy
()
)
req
.
start_send_idx
=
end_idx
if
token_id
is
not
None
:
self
.
disagg_prefill_pending_queue
.
allocate_token_id
(
self
.
disagg_prefill_pending_queue
.
store_prefill_results
(
req
.
metadata_buffer_index
,
token_id
)
req
.
disagg_kv_sender
.
send
(
kv_indices
)
is_last
=
token_id
is
not
None
req
.
disagg_kv_sender
.
send
(
kv_indices
,
slice
(
start_idx
,
end_idx
),
is_last
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment