Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d06a83fb
Unverified
Commit
d06a83fb
authored
Apr 15, 2025
by
Yuan Luo
Committed by
GitHub
Apr 15, 2025
Browse files
Support dynamic connection and TP 16 (#5351)
Co-authored-by:
luoyuan.luo
<
luoyuan.luo@antgroup.com
>
parent
5d134401
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
173 additions
and
37 deletions
+173
-37
python/sglang/srt/disaggregation/mooncake/conn.py
python/sglang/srt/disaggregation/mooncake/conn.py
+173
-37
No files found.
python/sglang/srt/disaggregation/mooncake/conn.py
View file @
d06a83fb
...
...
@@ -2,15 +2,18 @@ from __future__ import annotations
import
asyncio
import
dataclasses
import
json
import
logging
import
queue
import
random
import
struct
import
threading
from
functools
import
cache
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy.typing
as
npt
import
requests
import
zmq
from
aiohttp
import
web
...
...
@@ -24,9 +27,21 @@ from sglang.srt.disaggregation.base.conn import (
)
from
sglang.srt.disaggregation.mooncake.transfer_engine
import
MooncakeTransferEngine
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.utils
import
is_port_available
logger
=
logging
.
getLogger
(
__name__
)
def find_available_ports(base_port: int, count: int) -> List[int]:
    """Find ``count`` available ports, probing upward from ``base_port``.

    The returned ports are NOT consecutive: after every probe (successful or
    not) the candidate advances by a random stride of 100-1000 so that
    several ranks starting from the same base port are unlikely to race for
    the same candidate.

    Args:
        base_port: First port to probe.
        count: Number of available ports to collect; ``count <= 0`` yields
            an empty list.

    Returns:
        The available ports found, in ascending order.

    Raises:
        RuntimeError: If the candidate leaves the valid port range before
            ``count`` ports were found. Without this guard the loop would
            spin forever, since no port above 65535 can ever be available.
    """
    max_port = 65535
    available_ports: List[int] = []
    current_port = base_port

    while len(available_ports) < count:
        if current_port > max_port:
            raise RuntimeError(
                f"Could not find {count} available port(s) starting from "
                f"{base_port}: exhausted the valid port range."
            )
        if is_port_available(current_port):
            available_ports.append(current_port)
        current_port += random.randint(100, 1000)

    return available_ports
def
group_concurrent_contiguous
(
src_indices
:
npt
.
NDArray
[
np
.
int64
],
dst_indices
:
npt
.
NDArray
[
np
.
int64
]
...
...
@@ -65,9 +80,10 @@ class TransferKVChunk:
@
dataclasses
.
dataclass
class
TransferInfo
:
room
:
int
endpoint
:
str
decode_port
:
int
mooncake_session_id
:
str
room
:
int
dst_kv_ptrs
:
list
[
int
]
dst_kv_indices
:
npt
.
NDArray
[
np
.
int64
]
dst_aux_ptrs
:
list
[
int
]
...
...
@@ -77,25 +93,24 @@ class TransferInfo:
def from_zmq(cls, msg: List[bytes]):
    """Build an instance from the multipart ZMQ handshake message.

    Frame layout (the decode side sends its ip, port, session id and room
    first, followed by the packed pointers):

        msg[0]  endpoint (ASCII)
        msg[1]  decode port (ASCII integer)
        msg[2]  mooncake session id (ASCII)
        msg[3]  room (ASCII integer)
        msg[4]  destination KV pointers, packed unsigned 64-bit
        msg[5]  destination KV indices, raw int64 buffer
        msg[6]  destination aux pointers, packed unsigned 64-bit
        msg[7]  destination aux index (ASCII integer)

    NOTE(review): other code in this file reads ``req.dst_port`` while this
    constructor populates ``decode_port`` -- confirm the field name agrees
    with those callers.
    """
    return cls(
        endpoint=msg[0].decode("ascii"),
        decode_port=int(msg[1].decode("ascii")),
        mooncake_session_id=msg[2].decode("ascii"),
        room=int(msg[3].decode("ascii")),
        # Each pointer is 8 bytes, so the frame length fixes the count.
        dst_kv_ptrs=list(struct.unpack(f"{len(msg[4]) // 8}Q", msg[4])),
        dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
        dst_aux_ptrs=list(struct.unpack(f"{len(msg[6]) // 8}Q", msg[6])),
        dst_aux_index=int(msg[7].decode("ascii")),
    )
# Base port for the prefill (sender) side. The receiver adds the engine rank
# to it when building its default prefill server URL.
KVSENDER_POLLING_PORT = 17788
# Base port for the decode (receiver) side.
# NOTE(review): no remaining use is visible in this view -- confirm whether
# it is still needed now that rank ports are discovered dynamically.
KVRECEIVER_POLLING_PORT = 27788
class
MooncakeKVManager
(
BaseKVManager
):
def
__init__
(
self
,
args
:
KVArgs
,
disaggregation_mode
:
DisaggregationMode
):
self
.
engine
=
MooncakeTransferEngine
()
self
.
kv_args
=
args
self
.
disaggregation_mode
=
disaggregation_mode
self
.
request_status
:
Dict
[
int
,
KVPoll
]
=
{}
self
.
connection_pool
:
Dict
[
int
,
Dict
[
str
,
Union
[
str
,
int
]]]
=
{}
self
.
rank_port
=
None
self
.
server_socket
=
zmq
.
Context
().
socket
(
zmq
.
PULL
)
self
.
register_buffer_to_engine
()
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
...
...
@@ -202,15 +217,10 @@ class MooncakeKVManager(BaseKVManager):
)
return
status
def
sync_status_to_decode_endpoint
(
self
,
remote
:
str
,
room
:
int
):
def
sync_status_to_decode_endpoint
(
self
,
remote
:
str
,
dst_port
:
int
,
room
:
int
):
if
":"
in
remote
:
remote
=
remote
.
split
(
":"
)[
0
]
self
.
_connect
(
"tcp://"
+
remote
+
":"
+
str
(
KVRECEIVER_POLLING_PORT
+
self
.
kv_args
.
engine_rank
)
).
send_multipart
(
self
.
_connect
(
"tcp://"
+
remote
+
":"
+
str
(
dst_port
)).
send_multipart
(
[
str
(
room
).
encode
(
"ascii"
),
str
(
self
.
request_status
[
room
]).
encode
(
"ascii"
),
...
...
@@ -218,15 +228,16 @@ class MooncakeKVManager(BaseKVManager):
)
def
start_prefill_thread
(
self
):
sender_rank_port
=
KVSENDER_POLLING_PORT
+
self
.
kv_args
.
engine_rank
self
.
server_socket
.
bind
(
"tcp://*:"
+
str
(
sender_rank_port
))
# Find available port for prefill tp
self
.
rank_port
=
find_available_ports
(
20000
,
1
)[
0
]
self
.
server_socket
.
bind
(
"tcp://*:"
+
str
(
self
.
rank_port
))
def
bootstrap_thread
():
"""This thread recvs pre-alloc notification from the decode engine"""
# KVPoll.Bootstrapping -> KVPoll.WaitingForInput
while
True
:
waiting_req_bytes
=
self
.
server_socket
.
recv_multipart
()
room
=
waiting_req_bytes
[
2
].
decode
(
"ascii"
)
room
=
waiting_req_bytes
[
3
].
decode
(
"ascii"
)
if
room
==
"None"
:
continue
room
=
int
(
room
)
...
...
@@ -254,7 +265,7 @@ class MooncakeKVManager(BaseKVManager):
)
if
ret
!=
0
:
self
.
request_status
[
kv_chunk
.
room
]
=
KVPoll
.
Failed
self
.
sync_status_to_decode_endpoint
(
req
.
endpoint
,
req
.
room
)
self
.
sync_status_to_decode_endpoint
(
req
.
endpoint
,
req
.
dst_port
,
req
.
room
)
continue
if
kv_chunk
.
is_last
:
...
...
@@ -268,7 +279,7 @@ class MooncakeKVManager(BaseKVManager):
self
.
request_status
[
req
.
room
]
=
(
KVPoll
.
Success
if
ret
==
0
else
KVPoll
.
Failed
)
self
.
sync_status_to_decode_endpoint
(
req
.
endpoint
,
req
.
room
)
self
.
sync_status_to_decode_endpoint
(
req
.
endpoint
,
req
.
dst_port
,
req
.
room
)
self
.
transfer_infos
.
pop
(
req
.
room
)
except
queue
.
Empty
:
...
...
@@ -278,8 +289,8 @@ class MooncakeKVManager(BaseKVManager):
threading
.
Thread
(
target
=
transfer_thread
).
start
()
def
start_decode_thread
(
self
):
receiver_
rank_port
=
KVRECEIVER_POLLING_PORT
+
self
.
kv_args
.
engine_rank
self
.
server_socket
.
bind
(
"tcp://*:"
+
str
(
receiver_
rank_port
))
self
.
rank_port
=
find_available_ports
(
25000
,
1
)[
0
]
self
.
server_socket
.
bind
(
"tcp://*:"
+
str
(
self
.
rank_port
))
def
decode_thread
():
while
True
:
...
...
@@ -342,6 +353,38 @@ class MooncakeKVSender(BaseKVSender):
self
.
bootstrap_room
=
bootstrap_room
self
.
kv_mgr
.
update_status
(
bootstrap_room
,
KVPoll
.
Bootstrapping
)
self
.
aux_index
=
None
self
.
bootstrap_server_url
=
bootstrap_addr
self
.
session_id
=
self
.
kv_mgr
.
get_session_id
()
# Register to bootstrap server
self
.
_register_to_bootstrap
()
def _register_to_bootstrap(self):
    """Register this prefill rank's serving address with the bootstrap server.

    Sends an HTTP PUT to ``/kv_route`` carrying the session id, the role
    ("Prefill"), the local ip, the rank's serving port and the TP rank.
    Registration is best-effort: every failure is logged and swallowed, and
    nothing is raised to the caller.
    """
    # Bound the request so an unresponsive bootstrap server cannot block
    # sender construction forever.
    timeout_seconds = 10

    url = f"http://{self.bootstrap_server_url}/kv_route"
    payload = {
        "identity": self.session_id,
        "role": "Prefill",
        "serve_ip": self.kv_mgr.get_localhost(),
        "serve_port": self.kv_mgr.rank_port,
        "tp_rank": self.kv_mgr.kv_args.engine_rank,
    }
    logger.info(
        f"Register prefill server port {self.kv_mgr.rank_port} for tp_rank {self.kv_mgr.kv_args.engine_rank}"
    )

    try:
        response = requests.put(url, json=payload, timeout=timeout_seconds)
    except Exception as e:
        # Covers connection errors and timeouts alike; keep going so a
        # bootstrap outage does not crash the prefill worker.
        logger.error(f"Prefill Failed to register to bootstrap server: {e}")
        return

    if response.status_code == 200:
        logger.info("Prefill successfully registered to bootstrap server.")
    else:
        logger.error(
            f"Prefill Failed to register to bootstrap server: {response.status_code}, {response.text}"
        )
def
init
(
self
,
num_kv_indices
:
int
,
aux_index
:
Optional
[
int
]
=
None
):
self
.
num_kv_indices
=
num_kv_indices
...
...
@@ -384,14 +427,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
self
.
bootstrap_room
=
bootstrap_room
self
.
bootstrap_addr
=
bootstrap_addr
self
.
kv_mgr
=
mgr
self
.
prefill_server_url
=
(
bootstrap_addr
.
split
(
":"
)[
0
]
+
":"
+
str
(
KVSENDER_POLLING_PORT
+
self
.
kv_mgr
.
kv_args
.
engine_rank
)
)
self
.
decode_ip
=
self
.
kv_mgr
.
get_localhost
()
self
.
session_id
=
self
.
kv_mgr
.
get_session_id
()
self
.
kv_mgr
.
update_status
(
bootstrap_room
,
KVPoll
.
WaitingForInput
)
self
.
prefill_engine_rank
=
None
self
.
decode_port
=
self
.
kv_mgr
.
rank_port
self
.
dealer_socket
=
None
def _get_prefill_info_from_bootstrap(self, tp_rank: int):
    """Fetch the prefill server info registered for ``tp_rank``.

    Issues ``GET /kv_route?tp_rank=<tp_rank>`` against the bootstrap server.

    Args:
        tp_rank: Tensor-parallel rank whose prefill endpoint is wanted.

    Returns:
        The decoded JSON body on HTTP 200, otherwise ``None``. Non-200
        responses, timeouts, connection errors and undecodable bodies are
        logged and reported as ``None``.
    """
    # Bound the request so a hung bootstrap server cannot stall the caller.
    timeout_seconds = 10

    try:
        url = f"http://{self.bootstrap_addr}/kv_route?tp_rank={tp_rank}"
        response = requests.get(url, timeout=timeout_seconds)
        if response.status_code != 200:
            logger.error(
                f"Failed to get prefill server info: {response.status_code}, {response.text}"
            )
            return None
        return response.json()
    except Exception as e:
        logger.error(f"Error fetching prefill info from bootstrap: {e}")
        return None
@
cache
def
_connect
(
self
,
endpoint
:
str
):
...
...
@@ -400,6 +457,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
return
socket
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
    """Resolve the prefill endpoint for this rank, then start the handshake.

    The prefill info for this engine rank is looked up in the manager's
    ``connection_pool``; on a miss it is fetched from the bootstrap server
    and cached for later requests. When info is available,
    ``prefill_server_url`` is rebuilt from its ``serve_ip``/``serve_port``.

    Args:
        kv_indices: KV indices forwarded to ``handshake_prefill_server``.
        aux_index: Optional aux index forwarded to the handshake.
    """
    tp_rank = self.kv_mgr.kv_args.engine_rank
    pool = self.kv_mgr.connection_pool

    prefill_info = None
    logger.info(f"Decode bootstrap addr {self.bootstrap_addr}.")

    if tp_rank not in pool:
        prefill_info = self._get_prefill_info_from_bootstrap(tp_rank)
        if prefill_info is None:
            # Log once; the previous nested call also logged a stray "None".
            logger.error(
                f"Could not fetch prefill server info for tp_rank {tp_rank}"
            )
        else:
            pool[tp_rank] = prefill_info
    else:
        prefill_info = pool[tp_rank]

    if prefill_info:
        self.prefill_server_url = (
            f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
        )
        logger.info(
            f"Fetched prefill server info: {prefill_info} for tp_rank {tp_rank}"
        )

    # NOTE(review): when the lookup failed, prefill_server_url keeps its
    # previous value and the handshake is still sent there -- confirm that
    # this fallback address is still served by the prefill side.
    self.handshake_prefill_server(kv_indices, aux_index)
def
handshake_prefill_server
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
aux_index
:
Optional
[
int
]
=
None
):
packed_kv_data_ptrs
=
b
""
.
join
(
struct
.
pack
(
"Q"
,
ptr
)
for
ptr
in
self
.
kv_mgr
.
kv_args
.
kv_data_ptrs
)
...
...
@@ -409,6 +491,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
self
.
_connect
(
"tcp://"
+
self
.
prefill_server_url
).
send_multipart
(
[
self
.
decode_ip
.
encode
(
"ascii"
),
str
(
self
.
decode_port
).
encode
(
"ascii"
),
self
.
session_id
.
encode
(
"ascii"
),
str
(
self
.
bootstrap_room
).
encode
(
"ascii"
),
packed_kv_data_ptrs
,
...
...
@@ -432,6 +515,12 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
self
.
store
=
dict
()
self
.
lock
=
asyncio
.
Lock
()
self
.
_setup_routes
()
# prefill_engine_rank -> prefill_info
self
.
prefill_port_table
:
Dict
[
int
,
Dict
[
str
,
Union
[
str
,
int
]]]
=
{}
self
.
context
=
zmq
.
Context
()
self
.
prefill_engine_rank
=
None
# Start bootstrap server
self
.
thread
=
threading
.
Thread
(
target
=
self
.
_run_server
,
daemon
=
True
)
...
...
@@ -442,21 +531,22 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
def
_setup_routes
(
self
):
self
.
app
.
router
.
add_route
(
"*"
,
"/metadata"
,
self
.
_handle_metadata
)
self
.
app
.
router
.
add_route
(
"*"
,
"/kv_route"
,
self
.
_handle_kv_route
)
async
def
_handle_metadata
(
self
,
request
:
web
.
Request
):
key
=
request
.
query
.
get
(
"key"
,
""
)
if
request
.
method
==
"GET"
:
return
await
self
.
_handle_get
(
key
)
return
await
self
.
_handle_
metadata_
get
(
key
)
elif
request
.
method
==
"PUT"
:
return
await
self
.
_handle_put
(
key
,
request
)
return
await
self
.
_handle_
metadata_
put
(
key
,
request
)
elif
request
.
method
==
"DELETE"
:
return
await
self
.
_handle_delete
(
key
)
return
await
self
.
_handle_
metadata_
delete
(
key
)
return
web
.
Response
(
text
=
"Method not allowed"
,
status
=
405
,
content_type
=
"application/json"
)
async
def
_handle_get
(
self
,
key
):
async
def
_handle_
metadata_
get
(
self
,
key
):
async
with
self
.
lock
:
value
=
self
.
store
.
get
(
key
)
if
value
is
None
:
...
...
@@ -465,7 +555,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
)
return
web
.
Response
(
body
=
value
,
status
=
200
,
content_type
=
"application/json"
)
async
def
_handle_put
(
self
,
key
,
request
):
async
def
_handle_
metadata_
put
(
self
,
key
,
request
):
data
=
await
request
.
read
()
async
with
self
.
lock
:
self
.
store
[
key
]
=
data
...
...
@@ -473,7 +563,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
text
=
"metadata updated"
,
status
=
200
,
content_type
=
"application/json"
)
async
def
_handle_delete
(
self
,
key
):
async
def
_handle_
metadata_
delete
(
self
,
key
):
async
with
self
.
lock
:
if
key
not
in
self
.
store
:
return
web
.
Response
(
...
...
@@ -486,6 +576,52 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
text
=
"metadata deleted"
,
status
=
200
,
content_type
=
"application/json"
)
async
def
_handle_kv_route
(
self
,
request
:
web
.
Request
):
method
=
request
.
method
if
method
==
"PUT"
:
return
await
self
.
_handle_kv_route_put
(
request
)
elif
method
==
"GET"
:
return
await
self
.
_handle_kv_route_get
(
request
)
else
:
return
web
.
Response
(
text
=
"Method not allowed"
,
status
=
405
,
content_type
=
"application/json"
)
async def _handle_kv_route_put(self, request: web.Request):
    """Record a prefill rank's serving address sent via ``PUT /kv_route``.

    The JSON body must provide ``role``, ``serve_ip``, ``serve_port`` and
    ``tp_rank`` (the last two convertible to int). Only ``role == "Prefill"``
    updates ``prefill_port_table``; other roles are acknowledged without
    being stored.

    Returns:
        HTTP 200 "OK" on success, HTTP 400 when the body is not valid JSON
        or a required field is missing or malformed.
    """
    try:
        data = await request.json()
        role = data["role"]
        serve_ip = data["serve_ip"]
        serve_port = int(data["serve_port"])
        tp_rank = int(data["tp_rank"])
    except (KeyError, TypeError, ValueError) as e:
        # Report a client error instead of letting the exception surface
        # as an opaque HTTP 500.
        return web.Response(text=f"Invalid registration payload: {e!r}", status=400)

    if role == "Prefill":
        # Serialize table updates with the lookups done by the GET handler.
        async with self.lock:
            self.prefill_port_table[tp_rank] = {
                "serve_ip": serve_ip,
                "serve_port": serve_port,
            }
        logger.info(
            f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}"
        )

    return web.Response(text="OK", status=200)
async def _handle_kv_route_get(self, request: web.Request):
    """Return the prefill info registered for the ``tp_rank`` query parameter.

    Responds with HTTP 400 when ``tp_rank`` is missing/empty or not an
    integer, HTTP 404 when no entry exists for that rank, and otherwise a
    JSON body containing the stored entry.
    """
    raw_rank = request.query.get("tp_rank")
    if not raw_rank:
        return web.Response(text="Missing tp_rank", status=400)

    try:
        rank = int(raw_rank)
    except ValueError:
        return web.Response(text="tp_rank must be int", status=400)

    # Read the table under the same lock the PUT handler writes with.
    async with self.lock:
        entry = self.prefill_port_table.get(rank)

    if entry is None:
        return web.Response(text="Not Found", status=404)
    return web.json_response(entry, status=200)
def
_run_server
(
self
):
try
:
# Event Loop
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment