Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
44afde82
"git@developer.sourcefind.cn:change/sglang.git" did not exist on "b7361cc4441d7843d4799da4bf78c3654a39422e"
Unverified
Commit
44afde82
authored
Apr 14, 2025
by
Liangsheng Yin
Committed by
GitHub
Apr 14, 2025
Browse files
Fix PD disaggregation bugs (#5326)
parent
072df753
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
138 additions
and
104 deletions
+138
-104
python/sglang/srt/disaggregation/mooncake/conn.py
python/sglang/srt/disaggregation/mooncake/conn.py
+128
-99
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+10
-5
No files found.
python/sglang/srt/disaggregation/mooncake/conn.py
View file @
44afde82
from
__future__
import
annotations
import
asyncio
import
dataclasses
import
logging
import
queue
import
struct
import
threading
from
functools
import
cache
...
...
@@ -52,10 +54,38 @@ def group_concurrent_contiguous(
return
src_groups
,
dst_groups
RequestPoolType
=
Dict
[
int
,
Tuple
[
npt
.
NDArray
[
np
.
int64
],
Optional
[
int
]]]
WaitingPoolType
=
Dict
[
int
,
Tuple
[
str
,
list
[
int
],
npt
.
NDArray
[
np
.
int64
],
list
[
int
],
int
]
]
@
dataclasses
.
dataclass
class
TransferKVChunk
:
room
:
int
prefill_kv_indices
:
npt
.
NDArray
[
np
.
int64
]
index_slice
:
slice
is_last
:
bool
prefill_aux_index
:
Optional
[
int
]
@
dataclasses
.
dataclass
class
TransferInfo
:
room
:
int
endpoint
:
str
mooncake_session_id
:
str
dst_kv_ptrs
:
list
[
int
]
dst_kv_indices
:
npt
.
NDArray
[
np
.
int64
]
dst_aux_ptrs
:
list
[
int
]
dst_aux_index
:
int
@
classmethod
def
from_zmq
(
cls
,
msg
:
List
[
bytes
]):
return
cls
(
endpoint
=
msg
[
0
].
decode
(
"ascii"
),
mooncake_session_id
=
msg
[
1
].
decode
(
"ascii"
),
room
=
int
(
msg
[
2
].
decode
(
"ascii"
)),
dst_kv_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
3
])
//
8
}
Q"
,
msg
[
3
])),
dst_kv_indices
=
np
.
frombuffer
(
msg
[
4
],
dtype
=
np
.
int64
),
dst_aux_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
5
])
//
8
}
Q"
,
msg
[
5
])),
dst_aux_index
=
int
(
msg
[
6
].
decode
(
"ascii"
)),
)
KVSENDER_POLLING_PORT
=
17788
KVRECEIVER_POLLING_PORT
=
27788
...
...
@@ -65,13 +95,12 @@ class MooncakeKVManager(BaseKVManager):
self
.
engine
=
MooncakeTransferEngine
()
self
.
kv_args
=
args
self
.
disaggregation_mode
=
disaggregation_mode
self
.
request_pool
:
RequestPoolType
=
{}
self
.
request_status
:
Dict
[
int
,
KVPoll
]
=
{}
self
.
server_socket
=
zmq
.
Context
().
socket
(
zmq
.
PULL
)
self
.
register_buffer_to_engine
()
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
self
.
waiting_pool
:
WaitingPoolType
=
{}
self
.
transfer_
event
=
threading
.
Event
()
self
.
transfer_queue
=
queue
.
Queue
()
self
.
transfer_
infos
:
Dict
[
int
,
TransferInfo
]
=
{}
self
.
start_prefill_thread
()
elif
self
.
disaggregation_mode
==
DisaggregationMode
.
DECODE
:
self
.
start_decode_thread
()
...
...
@@ -101,7 +130,7 @@ class MooncakeKVManager(BaseKVManager):
self
,
mooncake_session_id
:
str
,
prefill_kv_indices
:
npt
.
NDArray
[
np
.
int64
],
dst_ptrs
:
list
[
int
],
dst_
kv_
ptrs
:
list
[
int
],
dst_kv_indices
:
npt
.
NDArray
[
np
.
int64
],
):
layer_num
=
int
(
len
(
self
.
kv_args
.
kv_data_ptrs
)
/
2
)
...
...
@@ -114,8 +143,8 @@ class MooncakeKVManager(BaseKVManager):
prefill_value_layer_ptr
=
self
.
kv_args
.
kv_data_ptrs
[
layer_num
+
layer_id
]
value_item_len
=
self
.
kv_args
.
kv_item_lens
[
layer_num
+
layer_id
]
decode_key_layer_ptr
=
dst_ptrs
[
layer_id
]
decode_value_layer_ptr
=
dst_ptrs
[
layer_num
+
layer_id
]
decode_key_layer_ptr
=
dst_
kv_
ptrs
[
layer_id
]
decode_value_layer_ptr
=
dst_
kv_
ptrs
[
layer_num
+
layer_id
]
for
prefill_index
,
decode_index
in
zip
(
prefill_kv_blocks
,
dst_kv_blocks
):
prefill_key_addr
=
(
...
...
@@ -192,87 +221,60 @@ class MooncakeKVManager(BaseKVManager):
sender_rank_port
=
KVSENDER_POLLING_PORT
+
self
.
kv_args
.
engine_rank
self
.
server_socket
.
bind
(
"tcp://*:"
+
str
(
sender_rank_port
))
def
prefill_thread
():
def
bootstrap_thread
():
"""This thread recvs pre-alloc notification from the decode engine"""
# KVPoll.Bootstrapping -> KVPoll.WaitingForInput
while
True
:
(
endpoint
,
mooncake_session_id
,
bootstrap_room
,
dst_ptrs
,
dst_kv_indices
,
dst_aux_ptrs
,
dst_aux_index
,
)
=
self
.
server_socket
.
recv_multipart
()
if
bootstrap_room
.
decode
(
"ascii"
)
==
"None"
:
waiting_req_bytes
=
self
.
server_socket
.
recv_multipart
()
room
=
waiting_req_bytes
[
2
].
decode
(
"ascii"
)
if
room
==
"None"
:
continue
endpoint
=
endpoint
.
decode
(
"ascii"
)
mooncake_session_id
=
mooncake_session_id
.
decode
(
"ascii"
)
bootstrap_room
=
int
(
bootstrap_room
.
decode
(
"ascii"
))
dst_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
dst_ptrs
)
//
8
}
Q"
,
dst_ptrs
))
dst_kv_indices
=
np
.
frombuffer
(
dst_kv_indices
,
dtype
=
np
.
int64
)
dst_aux_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
dst_aux_ptrs
)
//
8
}
Q"
,
dst_aux_ptrs
)
)
dst_aux_index
=
int
(
dst_aux_index
.
decode
(
"ascii"
))
self
.
waiting_pool
[
bootstrap_room
]
=
(
endpoint
,
mooncake_session_id
,
dst_ptrs
,
dst_kv_indices
,
dst_aux_ptrs
,
dst_aux_index
,
)
self
.
transfer_event
.
set
()
room
=
int
(
room
)
self
.
transfer_infos
[
room
]
=
TransferInfo
.
from_zmq
(
waiting_req_bytes
)
threading
.
Thread
(
target
=
prefill_thread
).
start
()
# NOTE: after bootstrapping we can mark the req as waiting for input
self
.
request_status
[
room
]
=
KVPoll
.
WaitingForInput
def
transfer_thread
():
# TODO: Shall we use KVPoll.Transferring state?
while
True
:
self
.
transfer_event
.
wait
()
self
.
transfer_event
.
clear
()
bootstrap_room_ready
=
self
.
request_pool
.
keys
()
bootstrap_room_request
=
self
.
waiting_pool
.
keys
()
for
room
in
list
(
bootstrap_room_request
):
if
room
not
in
list
(
bootstrap_room_ready
):
continue
status
=
KVPoll
.
Transferring
self
.
request_status
[
room
]
=
status
(
endpoint
,
mooncake_session_id
,
dst_ptrs
,
dst_kv_indices
,
dst_aux_ptrs
,
dst_aux_index
,
)
=
self
.
waiting_pool
.
pop
(
room
)
self
.
sync_status_to_decode_endpoint
(
endpoint
,
room
)
(
prefill_kv_indices
,
prefill_aux_index
,
)
=
self
.
request_pool
.
pop
(
room
)
try
:
kv_chunk
:
TransferKVChunk
=
self
.
transfer_queue
.
get
(
timeout
=
0.01
)
req
=
self
.
transfer_infos
[
kv_chunk
.
room
]
chunked_dst_kv_indice
=
req
.
dst_kv_indices
[
kv_chunk
.
index_slice
]
assert
len
(
chunked_dst_kv_indice
)
==
len
(
kv_chunk
.
prefill_kv_indices
)
ret
=
self
.
send_kvcache
(
mooncake_session_id
,
prefill_kv_indices
,
dst
_ptrs
,
dst_kv_indice
s
,
req
.
mooncake_session_id
,
kv_chunk
.
prefill_kv_indices
,
req
.
dst_kv
_ptrs
,
chunked_
dst_kv_indice
,
)
if
ret
!=
0
:
s
tatus
=
KVPoll
.
Failed
self
.
sync_status_to_decode_endpoint
(
endpoint
,
room
)
s
elf
.
request_status
[
kv_chunk
.
room
]
=
KVPoll
.
Failed
self
.
sync_status_to_decode_endpoint
(
req
.
endpoint
,
req
.
room
)
continue
ret
=
self
.
send_aux
(
mooncake_session_id
,
prefill_aux_index
,
dst_aux_ptrs
,
dst_aux_index
,
)
if
ret
!=
0
:
status
=
KVPoll
.
Failed
else
:
status
=
KVPoll
.
Success
self
.
request_status
[
room
]
=
status
self
.
sync_status_to_decode_endpoint
(
endpoint
,
room
)
if
kv_chunk
.
is_last
:
# The aux data only needs to be sent with the last chunk
ret
=
self
.
send_aux
(
req
.
mooncake_session_id
,
kv_chunk
.
prefill_aux_index
,
req
.
dst_aux_ptrs
,
req
.
dst_aux_index
,
)
self
.
request_status
[
req
.
room
]
=
(
KVPoll
.
Success
if
ret
==
0
else
KVPoll
.
Failed
)
self
.
sync_status_to_decode_endpoint
(
req
.
endpoint
,
req
.
room
)
self
.
transfer_infos
.
pop
(
req
.
room
)
except
queue
.
Empty
:
continue
threading
.
Thread
(
target
=
bootstrap_thread
).
start
()
threading
.
Thread
(
target
=
transfer_thread
).
start
()
def
start_decode_thread
(
self
):
...
...
@@ -288,29 +290,41 @@ class MooncakeKVManager(BaseKVManager):
threading
.
Thread
(
target
=
decode_thread
).
start
()
def
enqueue
_request
(
def
add_transfer
_request
(
self
,
bootstrap_room
:
int
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
aux_index
:
Optional
[
int
],
index_slice
:
slice
,
is_last
:
bool
,
aux_index
:
Optional
[
int
]
=
None
,
):
self
.
request_pool
[
bootstrap_room
]
=
(
kv_indices
,
aux_index
)
assert
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
assert
not
is_last
or
(
is_last
and
aux_index
is
not
None
)
self
.
transfer_queue
.
put
(
TransferKVChunk
(
room
=
bootstrap_room
,
prefill_kv_indices
=
kv_indices
,
index_slice
=
index_slice
,
is_last
=
is_last
,
prefill_aux_index
=
aux_index
,
)
)
self
.
request_status
[
bootstrap_room
]
=
KVPoll
.
WaitingForInput
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
self
.
transfer_event
.
set
()
def
check_status
(
self
,
bootstrap_room
:
int
):
if
(
self
.
disaggregation_mode
==
DisaggregationMode
.
DECODE
and
self
.
request_status
[
bootstrap_room
]
==
KVPoll
.
Success
):
if
bootstrap_room
in
self
.
request_pool
:
self
.
request_pool
.
pop
(
bootstrap_room
)
# TODO: do we really need the poll()?
return
self
.
request_status
[
bootstrap_room
]
def
set_status
(
self
,
bootstrap_room
:
int
,
status
:
KVPoll
):
self
.
request_status
[
bootstrap_room
]
=
status
def
update_status
(
self
,
bootstrap_room
:
int
,
status
:
KVPoll
):
if
bootstrap_room
not
in
self
.
request_status
:
self
.
request_status
[
bootstrap_room
]
=
status
else
:
# NOTE: The prefill engine could recv bootstrapping first
self
.
request_status
[
bootstrap_room
]
=
max
(
self
.
request_status
[
bootstrap_room
],
status
)
def
get_localhost
(
self
):
return
self
.
engine
.
get_localhost
()
...
...
@@ -326,15 +340,31 @@ class MooncakeKVSender(BaseKVSender):
):
self
.
kv_mgr
=
mgr
self
.
bootstrap_room
=
bootstrap_room
self
.
kv_mgr
.
set
_status
(
bootstrap_room
,
KVPoll
.
WaitingForInput
)
self
.
kv_mgr
.
update
_status
(
bootstrap_room
,
KVPoll
.
Bootstrapping
)
self
.
aux_index
=
None
def
init
(
self
,
num_kv_indices
:
int
,
aux_index
:
Optional
[
int
]
=
None
):
self
.
aux_index
=
aux_index
self
.
num_kv_indices
=
num_kv_indices
self
.
aux_index
=
aux_index
def
send
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
]):
self
.
kv_mgr
.
enqueue_request
(
self
.
bootstrap_room
,
kv_indices
,
self
.
aux_index
)
def
send
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
index_slice
:
slice
,
is_last
:
bool
,
):
if
not
is_last
:
self
.
kv_mgr
.
add_transfer_request
(
self
.
bootstrap_room
,
kv_indices
,
index_slice
,
False
)
else
:
self
.
kv_mgr
.
add_transfer_request
(
self
.
bootstrap_room
,
kv_indices
,
index_slice
,
True
,
aux_index
=
self
.
aux_index
,
)
def
poll
(
self
)
->
KVPoll
:
return
self
.
kv_mgr
.
check_status
(
self
.
bootstrap_room
)
...
...
@@ -361,7 +391,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
)
self
.
decode_ip
=
self
.
kv_mgr
.
get_localhost
()
self
.
session_id
=
self
.
kv_mgr
.
get_session_id
()
self
.
kv_mgr
.
set
_status
(
bootstrap_room
,
KVPoll
.
WaitingForInput
)
self
.
kv_mgr
.
update
_status
(
bootstrap_room
,
KVPoll
.
WaitingForInput
)
@
cache
def
_connect
(
self
,
endpoint
:
str
):
...
...
@@ -370,7 +400,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
return
socket
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
aux_index
:
Optional
[
int
]
=
None
):
self
.
kv_mgr
.
enqueue_request
(
self
.
bootstrap_room
,
kv_indices
,
aux_index
)
packed_kv_data_ptrs
=
b
""
.
join
(
struct
.
pack
(
"Q"
,
ptr
)
for
ptr
in
self
.
kv_mgr
.
kv_args
.
kv_data_ptrs
)
...
...
python/sglang/srt/disaggregation/prefill.py
View file @
44afde82
...
...
@@ -81,7 +81,7 @@ class PrefillBootstrapQueue:
self
.
gloo_group
=
gloo_group
self
.
bootstrap_port
=
bootstrap_port
def
allocate_token_id
(
self
,
idx
:
int
,
token_id
:
int
):
def
store_prefill_results
(
self
,
idx
:
int
,
token_id
:
int
):
assert
token_id
>=
0
,
f
"token_id:
{
token_id
}
is negative"
output_id_buffer
=
self
.
metadata_buffers
[
0
]
output_id_buffer
[
idx
]
=
token_id
...
...
@@ -146,7 +146,7 @@ class PrefillBootstrapQueue:
elif
poll
==
KVPoll
.
Failed
:
raise
Exception
(
"Bootstrap failed"
)
# KV.WaitingForInput
- init here
# KV.WaitingForInput
num_kv_indices
=
len
(
req
.
origin_input_ids
)
if
self
.
req_to_metadata_buffer_idx_allocator
.
available_size
()
==
0
:
break
...
...
@@ -222,6 +222,7 @@ class SchedulerDisaggregationPrefillMixin:
elif
poll
==
KVPoll
.
Success
:
# transfer done
self
.
tree_cache
.
cache_finished_req
(
req
)
# unlock the tree
req
.
finished_reason
=
FINISH_LENGTH
(
length
=
0
)
# FIXME: clean up req's data in transfer engine
done_reqs
.
append
(
req
)
elif
poll
==
KVPoll
.
Failed
:
raise
Exception
(
"Transferring failed"
)
...
...
@@ -256,14 +257,18 @@ class SchedulerDisaggregationPrefillMixin:
"""
start_idx
=
req
.
start_send_idx
end_idx
=
min
(
len
(
req
.
fill_ids
),
len
(
req
.
origin_input_ids
))
# Update next start_send_idx
req
.
start_send_idx
=
end_idx
kv_indices
=
(
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
][
start_idx
:
end_idx
]
.
cpu
()
.
numpy
()
)
req
.
start_send_idx
=
end_idx
if
token_id
is
not
None
:
self
.
disagg_prefill_pending_queue
.
allocate_token_id
(
self
.
disagg_prefill_pending_queue
.
store_prefill_results
(
req
.
metadata_buffer_index
,
token_id
)
req
.
disagg_kv_sender
.
send
(
kv_indices
)
is_last
=
token_id
is
not
None
req
.
disagg_kv_sender
.
send
(
kv_indices
,
slice
(
start_idx
,
end_idx
),
is_last
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment