Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d06a83fb
Unverified
Commit
d06a83fb
authored
Apr 15, 2025
by
Yuan Luo
Committed by
GitHub
Apr 15, 2025
Browse files
Support dynamic connection and TP 16 (#5351)
Co-authored-by:
luoyuan.luo
<
luoyuan.luo@antgroup.com
>
parent
5d134401
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
173 additions
and
37 deletions
+173
-37
python/sglang/srt/disaggregation/mooncake/conn.py
python/sglang/srt/disaggregation/mooncake/conn.py
+173
-37
No files found.
python/sglang/srt/disaggregation/mooncake/conn.py
View file @
d06a83fb
...
@@ -2,15 +2,18 @@ from __future__ import annotations
...
@@ -2,15 +2,18 @@ from __future__ import annotations
import
asyncio
import
asyncio
import
dataclasses
import
dataclasses
import
json
import
logging
import
logging
import
queue
import
queue
import
random
import
struct
import
struct
import
threading
import
threading
from
functools
import
cache
from
functools
import
cache
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
numpy.typing
as
npt
import
numpy.typing
as
npt
import
requests
import
zmq
import
zmq
from
aiohttp
import
web
from
aiohttp
import
web
...
@@ -24,9 +27,21 @@ from sglang.srt.disaggregation.base.conn import (
...
@@ -24,9 +27,21 @@ from sglang.srt.disaggregation.base.conn import (
)
)
from
sglang.srt.disaggregation.mooncake.transfer_engine
import
MooncakeTransferEngine
from
sglang.srt.disaggregation.mooncake.transfer_engine
import
MooncakeTransferEngine
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.utils
import
is_port_available
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
def
find_available_ports
(
base_port
:
int
,
count
:
int
)
->
List
[
int
]:
"""Find consecutive available ports starting from base_port."""
available_ports
=
[]
current_port
=
base_port
while
len
(
available_ports
)
<
count
:
if
is_port_available
(
current_port
):
available_ports
.
append
(
current_port
)
current_port
+=
random
.
randint
(
100
,
1000
)
return
available_ports
def
group_concurrent_contiguous
(
def
group_concurrent_contiguous
(
src_indices
:
npt
.
NDArray
[
np
.
int64
],
dst_indices
:
npt
.
NDArray
[
np
.
int64
]
src_indices
:
npt
.
NDArray
[
np
.
int64
],
dst_indices
:
npt
.
NDArray
[
np
.
int64
]
...
@@ -65,9 +80,10 @@ class TransferKVChunk:
...
@@ -65,9 +80,10 @@ class TransferKVChunk:
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
TransferInfo
:
class
TransferInfo
:
room
:
int
endpoint
:
str
endpoint
:
str
decode_port
:
int
mooncake_session_id
:
str
mooncake_session_id
:
str
room
:
int
dst_kv_ptrs
:
list
[
int
]
dst_kv_ptrs
:
list
[
int
]
dst_kv_indices
:
npt
.
NDArray
[
np
.
int64
]
dst_kv_indices
:
npt
.
NDArray
[
np
.
int64
]
dst_aux_ptrs
:
list
[
int
]
dst_aux_ptrs
:
list
[
int
]
...
@@ -77,25 +93,24 @@ class TransferInfo:
...
@@ -77,25 +93,24 @@ class TransferInfo:
def
from_zmq
(
cls
,
msg
:
List
[
bytes
]):
def
from_zmq
(
cls
,
msg
:
List
[
bytes
]):
return
cls
(
return
cls
(
endpoint
=
msg
[
0
].
decode
(
"ascii"
),
endpoint
=
msg
[
0
].
decode
(
"ascii"
),
mooncake_session_id
=
msg
[
1
].
decode
(
"ascii"
),
decode_port
=
int
(
msg
[
1
].
decode
(
"ascii"
)),
room
=
int
(
msg
[
2
].
decode
(
"ascii"
)),
mooncake_session_id
=
msg
[
2
].
decode
(
"ascii"
),
dst_kv_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
3
])
//
8
}
Q"
,
msg
[
3
])),
room
=
int
(
msg
[
3
].
decode
(
"ascii"
)),
dst_kv_indices
=
np
.
frombuffer
(
msg
[
4
],
dtype
=
np
.
int64
),
dst_kv_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
4
])
//
8
}
Q"
,
msg
[
4
])),
dst_aux_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
5
])
//
8
}
Q"
,
msg
[
5
])),
dst_kv_indices
=
np
.
frombuffer
(
msg
[
5
],
dtype
=
np
.
int64
),
dst_aux_index
=
int
(
msg
[
6
].
decode
(
"ascii"
)),
dst_aux_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
6
])
//
8
}
Q"
,
msg
[
6
])),
dst_aux_index
=
int
(
msg
[
7
].
decode
(
"ascii"
)),
)
)
KVSENDER_POLLING_PORT
=
17788
KVRECEIVER_POLLING_PORT
=
27788
class
MooncakeKVManager
(
BaseKVManager
):
class
MooncakeKVManager
(
BaseKVManager
):
def
__init__
(
self
,
args
:
KVArgs
,
disaggregation_mode
:
DisaggregationMode
):
def
__init__
(
self
,
args
:
KVArgs
,
disaggregation_mode
:
DisaggregationMode
):
self
.
engine
=
MooncakeTransferEngine
()
self
.
engine
=
MooncakeTransferEngine
()
self
.
kv_args
=
args
self
.
kv_args
=
args
self
.
disaggregation_mode
=
disaggregation_mode
self
.
disaggregation_mode
=
disaggregation_mode
self
.
request_status
:
Dict
[
int
,
KVPoll
]
=
{}
self
.
request_status
:
Dict
[
int
,
KVPoll
]
=
{}
self
.
connection_pool
:
Dict
[
int
,
Dict
[
str
,
Union
[
str
,
int
]]]
=
{}
self
.
rank_port
=
None
self
.
server_socket
=
zmq
.
Context
().
socket
(
zmq
.
PULL
)
self
.
server_socket
=
zmq
.
Context
().
socket
(
zmq
.
PULL
)
self
.
register_buffer_to_engine
()
self
.
register_buffer_to_engine
()
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
...
@@ -202,15 +217,10 @@ class MooncakeKVManager(BaseKVManager):
...
@@ -202,15 +217,10 @@ class MooncakeKVManager(BaseKVManager):
)
)
return
status
return
status
def
sync_status_to_decode_endpoint
(
self
,
remote
:
str
,
room
:
int
):
def
sync_status_to_decode_endpoint
(
self
,
remote
:
str
,
dst_port
:
int
,
room
:
int
):
if
":"
in
remote
:
if
":"
in
remote
:
remote
=
remote
.
split
(
":"
)[
0
]
remote
=
remote
.
split
(
":"
)[
0
]
self
.
_connect
(
self
.
_connect
(
"tcp://"
+
remote
+
":"
+
str
(
dst_port
)).
send_multipart
(
"tcp://"
+
remote
+
":"
+
str
(
KVRECEIVER_POLLING_PORT
+
self
.
kv_args
.
engine_rank
)
).
send_multipart
(
[
[
str
(
room
).
encode
(
"ascii"
),
str
(
room
).
encode
(
"ascii"
),
str
(
self
.
request_status
[
room
]).
encode
(
"ascii"
),
str
(
self
.
request_status
[
room
]).
encode
(
"ascii"
),
...
@@ -218,15 +228,16 @@ class MooncakeKVManager(BaseKVManager):
...
@@ -218,15 +228,16 @@ class MooncakeKVManager(BaseKVManager):
)
)
def
start_prefill_thread
(
self
):
def
start_prefill_thread
(
self
):
sender_rank_port
=
KVSENDER_POLLING_PORT
+
self
.
kv_args
.
engine_rank
# Find available port for prefill tp
self
.
server_socket
.
bind
(
"tcp://*:"
+
str
(
sender_rank_port
))
self
.
rank_port
=
find_available_ports
(
20000
,
1
)[
0
]
self
.
server_socket
.
bind
(
"tcp://*:"
+
str
(
self
.
rank_port
))
def
bootstrap_thread
():
def
bootstrap_thread
():
"""This thread recvs pre-alloc notification from the decode engine"""
"""This thread recvs pre-alloc notification from the decode engine"""
# KVPoll.Bootstrapping -> KVPoll.WaitingForInput
# KVPoll.Bootstrapping -> KVPoll.WaitingForInput
while
True
:
while
True
:
waiting_req_bytes
=
self
.
server_socket
.
recv_multipart
()
waiting_req_bytes
=
self
.
server_socket
.
recv_multipart
()
room
=
waiting_req_bytes
[
2
].
decode
(
"ascii"
)
room
=
waiting_req_bytes
[
3
].
decode
(
"ascii"
)
if
room
==
"None"
:
if
room
==
"None"
:
continue
continue
room
=
int
(
room
)
room
=
int
(
room
)
...
@@ -254,7 +265,7 @@ class MooncakeKVManager(BaseKVManager):
...
@@ -254,7 +265,7 @@ class MooncakeKVManager(BaseKVManager):
)
)
if
ret
!=
0
:
if
ret
!=
0
:
self
.
request_status
[
kv_chunk
.
room
]
=
KVPoll
.
Failed
self
.
request_status
[
kv_chunk
.
room
]
=
KVPoll
.
Failed
self
.
sync_status_to_decode_endpoint
(
req
.
endpoint
,
req
.
room
)
self
.
sync_status_to_decode_endpoint
(
req
.
endpoint
,
req
.
dst_port
,
req
.
room
)
continue
continue
if
kv_chunk
.
is_last
:
if
kv_chunk
.
is_last
:
...
@@ -268,7 +279,7 @@ class MooncakeKVManager(BaseKVManager):
...
@@ -268,7 +279,7 @@ class MooncakeKVManager(BaseKVManager):
self
.
request_status
[
req
.
room
]
=
(
self
.
request_status
[
req
.
room
]
=
(
KVPoll
.
Success
if
ret
==
0
else
KVPoll
.
Failed
KVPoll
.
Success
if
ret
==
0
else
KVPoll
.
Failed
)
)
self
.
sync_status_to_decode_endpoint
(
req
.
endpoint
,
req
.
room
)
self
.
sync_status_to_decode_endpoint
(
req
.
endpoint
,
req
.
dst_port
,
req
.
room
)
self
.
transfer_infos
.
pop
(
req
.
room
)
self
.
transfer_infos
.
pop
(
req
.
room
)
except
queue
.
Empty
:
except
queue
.
Empty
:
...
@@ -278,8 +289,8 @@ class MooncakeKVManager(BaseKVManager):
...
@@ -278,8 +289,8 @@ class MooncakeKVManager(BaseKVManager):
threading
.
Thread
(
target
=
transfer_thread
).
start
()
threading
.
Thread
(
target
=
transfer_thread
).
start
()
def
start_decode_thread
(
self
):
def
start_decode_thread
(
self
):
receiver_
rank_port
=
KVRECEIVER_POLLING_PORT
+
self
.
kv_args
.
engine_rank
self
.
rank_port
=
find_available_ports
(
25000
,
1
)[
0
]
self
.
server_socket
.
bind
(
"tcp://*:"
+
str
(
receiver_
rank_port
))
self
.
server_socket
.
bind
(
"tcp://*:"
+
str
(
self
.
rank_port
))
def
decode_thread
():
def
decode_thread
():
while
True
:
while
True
:
...
@@ -342,6 +353,38 @@ class MooncakeKVSender(BaseKVSender):
...
@@ -342,6 +353,38 @@ class MooncakeKVSender(BaseKVSender):
self
.
bootstrap_room
=
bootstrap_room
self
.
bootstrap_room
=
bootstrap_room
self
.
kv_mgr
.
update_status
(
bootstrap_room
,
KVPoll
.
Bootstrapping
)
self
.
kv_mgr
.
update_status
(
bootstrap_room
,
KVPoll
.
Bootstrapping
)
self
.
aux_index
=
None
self
.
aux_index
=
None
self
.
bootstrap_server_url
=
bootstrap_addr
self
.
session_id
=
self
.
kv_mgr
.
get_session_id
()
# Register to bootstrap server
self
.
_register_to_bootstrap
()
def
_register_to_bootstrap
(
self
):
"""Register KVSender to bootstrap server via HTTP POST."""
url
=
f
"http://
{
self
.
bootstrap_server_url
}
/kv_route"
payload
=
{
"identity"
:
self
.
session_id
,
"role"
:
"Prefill"
,
"serve_ip"
:
self
.
kv_mgr
.
get_localhost
(),
"serve_port"
:
self
.
kv_mgr
.
rank_port
,
"tp_rank"
:
self
.
kv_mgr
.
kv_args
.
engine_rank
,
}
logger
.
info
(
f
"Register prefill server port
{
self
.
kv_mgr
.
rank_port
}
for tp_rank
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
"
)
try
:
response
=
requests
.
put
(
url
,
json
=
payload
)
if
response
.
status_code
==
200
:
logger
.
info
(
f
"Prefill successfully registered to bootstrap server."
)
else
:
logger
.
info
(
f
"Prefill Failed to register to bootstrap server:
{
response
.
status_code
}
,
{
response
.
text
}
"
)
except
Exception
as
e
:
logger
.
info
(
f
"Prefill Failed to register to bootstrap server:
{
e
}
"
)
def
init
(
self
,
num_kv_indices
:
int
,
aux_index
:
Optional
[
int
]
=
None
):
def
init
(
self
,
num_kv_indices
:
int
,
aux_index
:
Optional
[
int
]
=
None
):
self
.
num_kv_indices
=
num_kv_indices
self
.
num_kv_indices
=
num_kv_indices
...
@@ -384,14 +427,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
...
@@ -384,14 +427,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
self
.
bootstrap_room
=
bootstrap_room
self
.
bootstrap_room
=
bootstrap_room
self
.
bootstrap_addr
=
bootstrap_addr
self
.
bootstrap_addr
=
bootstrap_addr
self
.
kv_mgr
=
mgr
self
.
kv_mgr
=
mgr
self
.
prefill_server_url
=
(
bootstrap_addr
.
split
(
":"
)[
0
]
+
":"
+
str
(
KVSENDER_POLLING_PORT
+
self
.
kv_mgr
.
kv_args
.
engine_rank
)
)
self
.
decode_ip
=
self
.
kv_mgr
.
get_localhost
()
self
.
decode_ip
=
self
.
kv_mgr
.
get_localhost
()
self
.
session_id
=
self
.
kv_mgr
.
get_session_id
()
self
.
session_id
=
self
.
kv_mgr
.
get_session_id
()
self
.
kv_mgr
.
update_status
(
bootstrap_room
,
KVPoll
.
WaitingForInput
)
self
.
kv_mgr
.
update_status
(
bootstrap_room
,
KVPoll
.
WaitingForInput
)
self
.
prefill_engine_rank
=
None
self
.
decode_port
=
self
.
kv_mgr
.
rank_port
self
.
dealer_socket
=
None
def
_get_prefill_info_from_bootstrap
(
self
,
tp_rank
:
int
):
"""Fetch the prefill server port corresponding to tp_rank from the bootstrap server."""
try
:
url
=
f
"http://
{
self
.
bootstrap_addr
}
/kv_route?tp_rank=
{
tp_rank
}
"
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
prefill_info
=
response
.
json
()
return
prefill_info
else
:
logger
.
error
(
f
"Failed to get prefill server info:
{
response
.
status_code
}
,
{
response
.
text
}
"
)
return
None
except
Exception
as
e
:
logger
.
error
(
f
"Error fetching prefill info from bootstrap:
{
e
}
"
)
return
None
@
cache
@
cache
def
_connect
(
self
,
endpoint
:
str
):
def
_connect
(
self
,
endpoint
:
str
):
...
@@ -400,6 +457,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
...
@@ -400,6 +457,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
return
socket
return
socket
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
aux_index
:
Optional
[
int
]
=
None
):
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
aux_index
:
Optional
[
int
]
=
None
):
prefill_info
=
None
logger
.
info
(
f
"Decode bootstrap addr
{
self
.
bootstrap_addr
}
."
)
if
self
.
kv_mgr
.
kv_args
.
engine_rank
not
in
self
.
kv_mgr
.
connection_pool
:
prefill_info
=
self
.
_get_prefill_info_from_bootstrap
(
self
.
kv_mgr
.
kv_args
.
engine_rank
)
if
prefill_info
is
None
:
logger
.
error
(
logger
.
error
(
f
"Could not fetch prefill server info for tp_rank
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
"
)
)
else
:
self
.
kv_mgr
.
connection_pool
[
self
.
kv_mgr
.
kv_args
.
engine_rank
]
=
prefill_info
else
:
prefill_info
=
self
.
kv_mgr
.
connection_pool
[
self
.
kv_mgr
.
kv_args
.
engine_rank
]
if
prefill_info
:
self
.
prefill_server_url
=
f
"
{
prefill_info
[
'serve_ip'
]
}
:
{
prefill_info
[
'serve_port'
]
}
"
logger
.
info
(
f
"Fetched prefill server info:
{
prefill_info
}
for tp_rank
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
"
)
self
.
handshake_prefill_server
(
kv_indices
,
aux_index
)
def
handshake_prefill_server
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
aux_index
:
Optional
[
int
]
=
None
):
packed_kv_data_ptrs
=
b
""
.
join
(
packed_kv_data_ptrs
=
b
""
.
join
(
struct
.
pack
(
"Q"
,
ptr
)
for
ptr
in
self
.
kv_mgr
.
kv_args
.
kv_data_ptrs
struct
.
pack
(
"Q"
,
ptr
)
for
ptr
in
self
.
kv_mgr
.
kv_args
.
kv_data_ptrs
)
)
...
@@ -409,6 +491,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
...
@@ -409,6 +491,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
self
.
_connect
(
"tcp://"
+
self
.
prefill_server_url
).
send_multipart
(
self
.
_connect
(
"tcp://"
+
self
.
prefill_server_url
).
send_multipart
(
[
[
self
.
decode_ip
.
encode
(
"ascii"
),
self
.
decode_ip
.
encode
(
"ascii"
),
str
(
self
.
decode_port
).
encode
(
"ascii"
),
self
.
session_id
.
encode
(
"ascii"
),
self
.
session_id
.
encode
(
"ascii"
),
str
(
self
.
bootstrap_room
).
encode
(
"ascii"
),
str
(
self
.
bootstrap_room
).
encode
(
"ascii"
),
packed_kv_data_ptrs
,
packed_kv_data_ptrs
,
...
@@ -432,6 +515,12 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
...
@@ -432,6 +515,12 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
self
.
store
=
dict
()
self
.
store
=
dict
()
self
.
lock
=
asyncio
.
Lock
()
self
.
lock
=
asyncio
.
Lock
()
self
.
_setup_routes
()
self
.
_setup_routes
()
# prefill_engine_rank -> prefill_info
self
.
prefill_port_table
:
Dict
[
int
,
Dict
[
str
,
Union
[
str
,
int
]]]
=
{}
self
.
context
=
zmq
.
Context
()
self
.
prefill_engine_rank
=
None
# Start bootstrap server
# Start bootstrap server
self
.
thread
=
threading
.
Thread
(
target
=
self
.
_run_server
,
daemon
=
True
)
self
.
thread
=
threading
.
Thread
(
target
=
self
.
_run_server
,
daemon
=
True
)
...
@@ -442,21 +531,22 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
...
@@ -442,21 +531,22 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
def
_setup_routes
(
self
):
def
_setup_routes
(
self
):
self
.
app
.
router
.
add_route
(
"*"
,
"/metadata"
,
self
.
_handle_metadata
)
self
.
app
.
router
.
add_route
(
"*"
,
"/metadata"
,
self
.
_handle_metadata
)
self
.
app
.
router
.
add_route
(
"*"
,
"/kv_route"
,
self
.
_handle_kv_route
)
async
def
_handle_metadata
(
self
,
request
:
web
.
Request
):
async
def
_handle_metadata
(
self
,
request
:
web
.
Request
):
key
=
request
.
query
.
get
(
"key"
,
""
)
key
=
request
.
query
.
get
(
"key"
,
""
)
if
request
.
method
==
"GET"
:
if
request
.
method
==
"GET"
:
return
await
self
.
_handle_get
(
key
)
return
await
self
.
_handle_
metadata_
get
(
key
)
elif
request
.
method
==
"PUT"
:
elif
request
.
method
==
"PUT"
:
return
await
self
.
_handle_put
(
key
,
request
)
return
await
self
.
_handle_
metadata_
put
(
key
,
request
)
elif
request
.
method
==
"DELETE"
:
elif
request
.
method
==
"DELETE"
:
return
await
self
.
_handle_delete
(
key
)
return
await
self
.
_handle_
metadata_
delete
(
key
)
return
web
.
Response
(
return
web
.
Response
(
text
=
"Method not allowed"
,
status
=
405
,
content_type
=
"application/json"
text
=
"Method not allowed"
,
status
=
405
,
content_type
=
"application/json"
)
)
async
def
_handle_get
(
self
,
key
):
async
def
_handle_
metadata_
get
(
self
,
key
):
async
with
self
.
lock
:
async
with
self
.
lock
:
value
=
self
.
store
.
get
(
key
)
value
=
self
.
store
.
get
(
key
)
if
value
is
None
:
if
value
is
None
:
...
@@ -465,7 +555,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
...
@@ -465,7 +555,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
)
)
return
web
.
Response
(
body
=
value
,
status
=
200
,
content_type
=
"application/json"
)
return
web
.
Response
(
body
=
value
,
status
=
200
,
content_type
=
"application/json"
)
async
def
_handle_put
(
self
,
key
,
request
):
async
def
_handle_
metadata_
put
(
self
,
key
,
request
):
data
=
await
request
.
read
()
data
=
await
request
.
read
()
async
with
self
.
lock
:
async
with
self
.
lock
:
self
.
store
[
key
]
=
data
self
.
store
[
key
]
=
data
...
@@ -473,7 +563,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
...
@@ -473,7 +563,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
text
=
"metadata updated"
,
status
=
200
,
content_type
=
"application/json"
text
=
"metadata updated"
,
status
=
200
,
content_type
=
"application/json"
)
)
async
def
_handle_delete
(
self
,
key
):
async
def
_handle_
metadata_
delete
(
self
,
key
):
async
with
self
.
lock
:
async
with
self
.
lock
:
if
key
not
in
self
.
store
:
if
key
not
in
self
.
store
:
return
web
.
Response
(
return
web
.
Response
(
...
@@ -486,6 +576,52 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
...
@@ -486,6 +576,52 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
text
=
"metadata deleted"
,
status
=
200
,
content_type
=
"application/json"
text
=
"metadata deleted"
,
status
=
200
,
content_type
=
"application/json"
)
)
async
def
_handle_kv_route
(
self
,
request
:
web
.
Request
):
method
=
request
.
method
if
method
==
"PUT"
:
return
await
self
.
_handle_kv_route_put
(
request
)
elif
method
==
"GET"
:
return
await
self
.
_handle_kv_route_get
(
request
)
else
:
return
web
.
Response
(
text
=
"Method not allowed"
,
status
=
405
,
content_type
=
"application/json"
)
async
def
_handle_kv_route_put
(
self
,
request
:
web
.
Request
):
data
=
await
request
.
json
()
identity
=
data
[
"identity"
]
role
=
data
[
"role"
]
serve_ip
=
data
[
"serve_ip"
]
serve_port
=
int
(
data
[
"serve_port"
])
# Assuming serve_port is an integer
tp_rank
=
int
(
data
[
"tp_rank"
])
# Add lock to make sure thread-safe
if
role
==
"Prefill"
:
async
with
self
.
lock
:
self
.
prefill_port_table
[
tp_rank
]
=
{
"serve_ip"
:
serve_ip
,
"serve_port"
:
serve_port
}
logger
.
info
(
f
"Registered Prefill tp_rank:
{
tp_rank
}
with serve_ip:
{
serve_ip
}
and serve_port:
{
serve_port
}
"
)
return
web
.
Response
(
text
=
"OK"
,
status
=
200
)
async
def
_handle_kv_route_get
(
self
,
request
:
web
.
Request
):
tp_rank
=
request
.
query
.
get
(
"tp_rank"
)
if
not
tp_rank
:
return
web
.
Response
(
text
=
"Missing tp_rank"
,
status
=
400
)
try
:
tp_rank
=
int
(
tp_rank
)
except
ValueError
:
return
web
.
Response
(
text
=
"tp_rank must be int"
,
status
=
400
)
# Find corresponding prefill info
async
with
self
.
lock
:
prefill_info
=
self
.
prefill_port_table
.
get
(
tp_rank
)
if
prefill_info
is
not
None
:
return
web
.
json_response
(
prefill_info
,
status
=
200
)
else
:
return
web
.
Response
(
text
=
"Not Found"
,
status
=
404
)
def
_run_server
(
self
):
def
_run_server
(
self
):
try
:
try
:
# Event Loop
# Event Loop
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment