Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
97ac42b6
Unverified
Commit
97ac42b6
authored
May 02, 2025
by
Yongtong Wu
Committed by
GitHub
May 02, 2025
Browse files
[PD] NIXL backend Prefill TP & Decode TP+DP (#5681)
parent
1acca3a2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
241 additions
and
71 deletions
+241
-71
python/sglang/srt/disaggregation/nixl/conn.py
python/sglang/srt/disaggregation/nixl/conn.py
+241
-71
No files found.
python/sglang/srt/disaggregation/nixl/conn.py
View file @
97ac42b6
...
...
@@ -10,7 +10,7 @@ import threading
import
uuid
from
collections
import
defaultdict
from
functools
import
cache
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
,
TypeAlias
,
Union
import
numpy
as
np
import
numpy.typing
as
npt
...
...
@@ -32,6 +32,38 @@ from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
logger
=
logging
.
getLogger
(
__name__
)
NixlEngineInfo
:
TypeAlias
=
Dict
[
str
,
Union
[
str
,
int
]]
# From Mooncake backend.
def
group_concurrent_contiguous
(
src_indices
:
npt
.
NDArray
[
np
.
int64
],
dst_indices
:
npt
.
NDArray
[
np
.
int64
]
)
->
Tuple
[
List
[
npt
.
NDArray
[
np
.
int64
]],
List
[
npt
.
NDArray
[
np
.
int64
]]]:
src_groups
=
[]
dst_groups
=
[]
current_src
=
[
src_indices
[
0
]]
current_dst
=
[
dst_indices
[
0
]]
for
i
in
range
(
1
,
len
(
src_indices
)):
src_contiguous
=
src_indices
[
i
]
==
src_indices
[
i
-
1
]
+
1
dst_contiguous
=
dst_indices
[
i
]
==
dst_indices
[
i
-
1
]
+
1
if
src_contiguous
and
dst_contiguous
:
current_src
.
append
(
src_indices
[
i
])
current_dst
.
append
(
dst_indices
[
i
])
else
:
src_groups
.
append
(
current_src
)
dst_groups
.
append
(
current_dst
)
current_src
=
[
src_indices
[
i
]]
current_dst
=
[
dst_indices
[
i
]]
src_groups
.
append
(
current_src
)
dst_groups
.
append
(
current_dst
)
return
src_groups
,
dst_groups
GUARD
=
"NixlMsgGuard"
.
encode
(
"ascii"
)
@
dataclasses
.
dataclass
class
TransferInfo
:
...
...
@@ -45,19 +77,36 @@ class TransferInfo:
dst_aux_index
:
int
dst_gpu_id
:
int
def
is_dummy
(
self
):
return
self
.
endpoint
==
""
@
classmethod
def
from_zmq
(
cls
,
msg
:
List
[
bytes
]):
return
cls
(
room
=
int
(
msg
[
0
].
decode
(
"ascii"
)),
endpoint
=
msg
[
1
].
decode
(
"ascii"
),
dst_port
=
int
(
msg
[
2
].
decode
(
"ascii"
)),
agent_metadata
=
msg
[
3
],
dst_kv_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
4
])
//
8
}
Q"
,
msg
[
4
])),
dst_kv_indices
=
np
.
frombuffer
(
msg
[
5
],
dtype
=
np
.
int64
),
dst_aux_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
6
])
//
8
}
Q"
,
msg
[
6
])),
dst_aux_index
=
int
(
msg
[
7
].
decode
(
"ascii"
)),
dst_gpu_id
=
int
(
msg
[
8
].
decode
(
"ascii"
)),
)
if
len
(
msg
)
==
1
:
# dummy msg
return
cls
(
room
=
int
(
msg
[
0
].
decode
(
"ascii"
)),
endpoint
=
""
,
dst_port
=
0
,
agent_metadata
=
b
""
,
dst_kv_ptrs
=
[],
dst_kv_indices
=
np
.
array
([],
dtype
=
np
.
int64
),
dst_aux_ptrs
=
[],
dst_aux_index
=
0
,
dst_gpu_id
=
0
,
)
else
:
return
cls
(
room
=
int
(
msg
[
0
].
decode
(
"ascii"
)),
endpoint
=
msg
[
1
].
decode
(
"ascii"
),
dst_port
=
int
(
msg
[
2
].
decode
(
"ascii"
)),
agent_metadata
=
msg
[
3
],
dst_kv_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
4
])
//
8
}
Q"
,
msg
[
4
])),
dst_kv_indices
=
np
.
frombuffer
(
msg
[
5
],
dtype
=
np
.
int64
),
dst_aux_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
6
])
//
8
}
Q"
,
msg
[
6
])),
dst_aux_index
=
int
(
msg
[
7
].
decode
(
"ascii"
)),
dst_gpu_id
=
int
(
msg
[
8
].
decode
(
"ascii"
)),
)
@
dataclasses
.
dataclass
...
...
@@ -98,6 +147,19 @@ class NixlKVManager(BaseKVManager):
# for p/d multi node infer
self
.
bootstrap_port
=
server_args
.
disaggregation_bootstrap_port
self
.
dist_init_addr
=
server_args
.
dist_init_addr
self
.
tp_size
=
server_args
.
tp_size
self
.
tp_rank
=
args
.
engine_rank
self
.
enable_dp_attention
=
server_args
.
enable_dp_attention
if
self
.
enable_dp_attention
:
assert
(
server_args
.
dp_size
>
1
),
"If dp_attention is enabled, dp size must be greater than 1 in disaggregation mode."
self
.
dp_size
=
server_args
.
dp_size
self
.
tp_size_of_dp
=
server_args
.
tp_size
//
server_args
.
dp_size
self
.
attn_tp_rank
=
args
.
engine_rank
%
self
.
tp_size_of_dp
self
.
dp_rank
=
args
.
engine_rank
//
self
.
tp_size_of_dp
self
.
rank_port
=
None
self
.
server_socket
=
zmq
.
Context
().
socket
(
zmq
.
PULL
)
self
.
register_buffer_to_engine
()
...
...
@@ -110,7 +172,8 @@ class NixlKVManager(BaseKVManager):
self
.
_start_bootstrap_thread
()
self
.
_register_to_bootstrap
()
elif
self
.
disaggregation_mode
==
DisaggregationMode
.
DECODE
:
self
.
connection_pool
:
Dict
[
str
,
Dict
[
str
,
Union
[
str
,
int
]]]
=
{}
# bootstrap key -> (remote_engine_rank -> possible remote source info)
self
.
prefill_peer_infos
:
Dict
[
str
,
list
[
Dict
[
int
,
NixlEngineInfo
]]]
=
{}
self
.
transfer_statuses
:
Dict
[
int
,
TransferStatus
]
=
defaultdict
(
TransferStatus
)
...
...
@@ -126,6 +189,7 @@ class NixlKVManager(BaseKVManager):
):
kv_addrs
.
append
((
kv_data_ptr
,
kv_data_len
,
self
.
kv_args
.
gpu_id
,
""
))
self
.
kv_descs
=
self
.
agent
.
register_memory
(
kv_addrs
,
"VRAM"
,
is_sorted
=
True
)
logger
.
debug
(
f
"Register kv tensors, len(kv_addr)=
{
len
(
kv_addrs
)
}
"
)
if
not
self
.
kv_descs
:
raise
Exception
(
"NIXL memory registration failed for kv tensors"
)
aux_addrs
=
[]
...
...
@@ -134,6 +198,7 @@ class NixlKVManager(BaseKVManager):
):
aux_addrs
.
append
((
aux_data_ptr
,
aux_data_len
,
0
,
""
))
self
.
aux_descs
=
self
.
agent
.
register_memory
(
aux_addrs
,
"DRAM"
,
is_sorted
=
True
)
logger
.
debug
(
f
"Register aux tensors, len(aux_addrs)=
{
len
(
aux_addrs
)
}
"
)
if
not
self
.
aux_descs
:
raise
Exception
(
"NIXL memory registration failed for aux tensors"
)
...
...
@@ -157,6 +222,12 @@ class NixlKVManager(BaseKVManager):
dst_gpu_id
:
int
,
notif
:
str
,
):
# group by indices
prefill_kv_blocks
,
dst_kv_blocks
=
group_concurrent_contiguous
(
prefill_kv_indices
,
dst_kv_indices
)
logger
.
debug
(
f
"sending kvcache to
{
peer_name
}
with notif
{
notif
}
"
)
# Make descs
num_layers
=
len
(
self
.
kv_args
.
kv_data_ptrs
)
src_addrs
=
[]
...
...
@@ -166,12 +237,16 @@ class NixlKVManager(BaseKVManager):
dst_ptr
=
dst_kv_ptrs
[
layer_id
]
item_len
=
self
.
kv_args
.
kv_item_lens
[
layer_id
]
for
prefill_index
,
decode_index
in
zip
(
prefill_kv_
indice
s
,
dst_kv_
indice
s
):
src_addr
=
src_ptr
+
int
(
prefill_index
)
*
item_len
dst_addr
=
dst_ptr
+
int
(
decode_index
)
*
item_len
length
=
item_len
for
prefill_index
,
decode_index
in
zip
(
prefill_kv_
block
s
,
dst_kv_
block
s
):
src_addr
=
src_ptr
+
int
(
prefill_index
[
0
]
)
*
item_len
dst_addr
=
dst_ptr
+
int
(
decode_index
[
0
]
)
*
item_len
length
=
item_len
*
len
(
prefill_index
)
src_addrs
.
append
((
src_addr
,
length
,
self
.
kv_args
.
gpu_id
))
dst_addrs
.
append
((
dst_addr
,
length
,
dst_gpu_id
))
logger
.
debug
(
f
"len(src_addrs): before group:
{
len
(
prefill_kv_indices
)
}
, after group:
{
len
(
src_addrs
)
}
"
)
src_descs
=
self
.
agent
.
get_xfer_descs
(
src_addrs
,
"VRAM"
,
is_sorted
=
True
)
dst_descs
=
self
.
agent
.
get_xfer_descs
(
dst_addrs
,
"VRAM"
,
is_sorted
=
True
)
# Transfer data
...
...
@@ -180,7 +255,7 @@ class NixlKVManager(BaseKVManager):
src_descs
,
dst_descs
,
peer_name
,
notif
.
encode
(
"ascii"
),
notif
.
encode
(
"ascii"
),
# type: ignore
)
if
not
xfer_handle
:
raise
Exception
(
"KVSender failed to create transfer"
)
...
...
@@ -213,7 +288,7 @@ class NixlKVManager(BaseKVManager):
src_descs
,
dst_descs
,
peer_name
,
notif
.
encode
(
"ascii"
),
notif
.
encode
(
"ascii"
),
# type: ignore
)
if
not
xfer_handle
:
raise
Exception
(
"KVSender failed to create transfer"
)
...
...
@@ -240,6 +315,9 @@ class NixlKVManager(BaseKVManager):
req
=
self
.
transfer_infos
[
bootstrap_room
]
assert
bootstrap_room
==
req
.
room
if
req
.
is_dummy
():
return
[]
peer_name
=
self
.
_add_remote
(
bootstrap_room
,
req
.
agent_metadata
)
chunked_dst_kv_indice
=
req
.
dst_kv_indices
[
index_slice
]
assert
len
(
chunked_dst_kv_indice
)
==
len
(
kv_indices
)
...
...
@@ -256,6 +334,7 @@ class NixlKVManager(BaseKVManager):
handles
=
[
kv_xfer_handle
]
# Only the last chunk we need to send the aux data.
if
is_last
:
assert
aux_index
is
not
None
aux_xfer_handle
=
self
.
send_aux
(
peer_name
,
aux_index
,
...
...
@@ -325,6 +404,13 @@ class NixlKVManager(BaseKVManager):
"""This thread recvs transfer info from the decode engine"""
while
True
:
waiting_req_bytes
=
self
.
server_socket
.
recv_multipart
()
logger
.
debug
(
f
"Received multipart with total byte size
{
sum
(
len
(
x
)
for
x
in
waiting_req_bytes
)
}
"
)
assert
(
waiting_req_bytes
[
0
]
==
GUARD
),
f
"First message should be
{
GUARD
}
. Foreign traffic?"
waiting_req_bytes
=
waiting_req_bytes
[
1
:]
room
=
waiting_req_bytes
[
0
].
decode
(
"ascii"
)
if
room
==
"None"
:
continue
...
...
@@ -372,14 +458,13 @@ class NixlKVSender(BaseKVSender):
def
poll
(
self
)
->
KVPoll
:
if
not
self
.
has_sent
:
return
KVPoll
.
WaitingForInput
return
KVPoll
.
WaitingForInput
# type: ignore
states
=
[
self
.
kv_mgr
.
agent
.
check_xfer_state
(
x
)
for
x
in
self
.
xfer_handles
]
if
all
([
x
==
"DONE"
for
x
in
states
]):
return
KVPoll
.
Success
return
KVPoll
.
Success
# type: ignore
if
any
([
x
==
"ERR"
for
x
in
states
]):
raise
Exception
(
"KVSender transfer encountered an error."
)
return
KVPoll
.
WaitingForInput
return
KVPoll
.
WaitingForInput
# type: ignore
def
failure_exception
(
self
):
raise
Exception
(
"Fake KVSender Exception"
)
...
...
@@ -401,7 +486,7 @@ class NixlKVReceiver(BaseKVReceiver):
# NOTE: key distinguished by bootstrap_addr and engine_rank
bootstrap_key
=
f
"
{
self
.
bootstrap_addr
}
_
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
"
if
bootstrap_key
not
in
self
.
kv_mgr
.
connection_pool
:
if
bootstrap_key
not
in
self
.
kv_mgr
.
prefill_peer_infos
:
self
.
bootstrap_info
=
self
.
_get_bootstrap_info_from_server
(
self
.
kv_mgr
.
kv_args
.
engine_rank
)
...
...
@@ -410,25 +495,79 @@ class NixlKVReceiver(BaseKVReceiver):
f
"Could not fetch bootstrap info for engine rank:
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
"
)
else
:
self
.
kv_mgr
.
connection_pool
[
bootstrap_key
]
=
self
.
bootstrap_info
self
.
kv_mgr
.
prefill_peer_infos
[
bootstrap_key
]
=
self
.
bootstrap_info
else
:
self
.
bootstrap_info
=
self
.
kv_mgr
.
connection_pool
[
bootstrap_key
]
self
.
bootstrap_info
=
self
.
kv_mgr
.
prefill_peer_infos
[
bootstrap_key
]
assert
self
.
bootstrap_info
is
not
None
def
_get_bootstrap_info_from_server
(
self
,
engine_rank
):
# return a list of remotes in a dict, [(remote_engine_rank -> NixlEngineInfo), ...]
# In each dict, there are multiple possible remotes named "equal sources".
# We only need to select one to split the traffic. i.e. we totally select len(list) remotes.
def
_get_bootstrap_info_from_server
(
self
,
engine_rank
)
->
Optional
[
List
[
Dict
[
int
,
NixlEngineInfo
]]]:
"""Fetch the bootstrap info from the bootstrap server."""
try
:
url
=
f
"http://
{
self
.
bootstrap_addr
}
/route?engine_rank=
{
engine_rank
}
"
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
if
self
.
kv_mgr
.
enable_dp_attention
:
url
=
f
"http://
{
self
.
bootstrap_addr
}
/route"
response
=
requests
.
get
(
url
)
if
response
.
status_code
!=
200
:
logger
.
error
(
f
"Failed to get prefill server info:
{
response
.
status_code
}
,
{
response
.
text
}
"
)
return
None
bootstrap_info
=
response
.
json
()
return
bootstrap_info
else
:
logger
.
error
(
f
"Failed to get prefill server info:
{
response
.
status_code
}
,
{
response
.
text
}
"
assert
isinstance
(
bootstrap_info
,
dict
)
bootstrap_info
=
{
int
(
k
):
v
for
k
,
v
in
bootstrap_info
.
items
()}
# split out who need to send to this rank.
# currently for dpsk mla model, those ranks share the same latent cache.
# pick one as the real source
prefill_tp_size
=
len
(
bootstrap_info
.
keys
())
assert
(
prefill_tp_size
>=
self
.
kv_mgr
.
tp_size_of_dp
),
f
"Only support Prefill TP size >= Decode TP size of DP, now we have
{
prefill_tp_size
}
vs
{
self
.
kv_mgr
.
tp_size_of_dp
}
"
num_remote_tp_rank_we_managed
=
(
prefill_tp_size
//
self
.
kv_mgr
.
tp_size_of_dp
)
# We handle [num * self.attn_tp_rank, num * self.attn_tp_rank + num)
remote_tp_ranks
=
list
(
range
(
0
,
prefill_tp_size
))
# split it into tp_size_of_dp parts and get our part
remote_tp_ranks_grouped
=
[
remote_tp_ranks
[
i
:
i
+
num_remote_tp_rank_we_managed
]
for
i
in
range
(
0
,
prefill_tp_size
,
self
.
kv_mgr
.
tp_size_of_dp
)
]
managed_ranks
=
remote_tp_ranks_grouped
[
self
.
kv_mgr
.
attn_tp_rank
]
assert
len
(
managed_ranks
)
==
num_remote_tp_rank_we_managed
logger
.
debug
(
f
"Rank
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
source can be
{
managed_ranks
}
"
)
return
None
return
[
{
rk
:
bootstrap_info
[
rk
]
for
rk
in
bootstrap_info
.
keys
()
if
rk
in
managed_ranks
}
]
else
:
url
=
f
"http://
{
self
.
bootstrap_addr
}
/route?engine_rank=
{
engine_rank
}
"
response
=
requests
.
get
(
url
)
if
response
.
status_code
==
200
:
bootstrap_info
=
response
.
json
()
return
[{
engine_rank
:
bootstrap_info
}]
else
:
logger
.
error
(
f
"Failed to get prefill server info:
{
response
.
status_code
}
,
{
response
.
text
}
"
)
return
None
except
Exception
as
e
:
logger
.
error
(
f
"Error fetching prefill info from bootstrap:
{
e
}
"
)
return
None
...
...
@@ -440,43 +579,67 @@ class NixlKVReceiver(BaseKVReceiver):
return
socket
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int64
],
aux_index
:
Optional
[
int
]
=
None
):
self
.
prefill_server_url
=
(
f
"
{
self
.
bootstrap_info
[
'rank_ip'
]
}
:
{
self
.
bootstrap_info
[
'rank_port'
]
}
"
)
logger
.
debug
(
f
"Fetched bootstrap info:
{
self
.
bootstrap_info
}
for engine rank:
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
"
)
packed_kv_data_ptrs
=
b
""
.
join
(
struct
.
pack
(
"Q"
,
ptr
)
for
ptr
in
self
.
kv_mgr
.
kv_args
.
kv_data_ptrs
)
packed_aux_data_ptrs
=
b
""
.
join
(
struct
.
pack
(
"Q"
,
ptr
)
for
ptr
in
self
.
kv_mgr
.
kv_args
.
aux_data_ptrs
)
self
.
_connect
(
"tcp://"
+
self
.
prefill_server_url
).
send_multipart
(
[
str
(
self
.
bootstrap_room
).
encode
(
"ascii"
),
get_local_ip_by_remote
().
encode
(
"ascii"
),
str
(
self
.
kv_mgr
.
rank_port
).
encode
(
"ascii"
),
self
.
kv_mgr
.
agent
.
get_agent_metadata
(),
packed_kv_data_ptrs
,
kv_indices
.
tobytes
(),
packed_aux_data_ptrs
,
str
(
aux_index
).
encode
(
"ascii"
),
str
(
self
.
kv_mgr
.
kv_args
.
gpu_id
).
encode
(
"ascii"
),
assert
self
.
bootstrap_info
is
not
None
assert
self
.
bootstrap_room
is
not
None
for
equal_sources
in
self
.
bootstrap_info
:
remote_rank
=
list
(
equal_sources
.
keys
())[
self
.
bootstrap_room
%
len
(
equal_sources
)
]
)
self
.
prefill_server_url
=
f
"
{
equal_sources
[
remote_rank
][
'rank_ip'
]
}
:
{
equal_sources
[
remote_rank
][
'rank_port'
]
}
"
logger
.
debug
(
f
"Fetched bootstrap info for engine rank:
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
, source:
{
remote_rank
}
, all:
{
list
(
equal_sources
.
keys
())
}
"
)
packed_kv_data_ptrs
=
b
""
.
join
(
struct
.
pack
(
"Q"
,
ptr
)
for
ptr
in
self
.
kv_mgr
.
kv_args
.
kv_data_ptrs
)
packed_aux_data_ptrs
=
b
""
.
join
(
struct
.
pack
(
"Q"
,
ptr
)
for
ptr
in
self
.
kv_mgr
.
kv_args
.
aux_data_ptrs
)
logger
.
debug
(
f
"Sending to
{
self
.
prefill_server_url
}
with bootstrap room
{
self
.
bootstrap_room
}
"
)
self
.
_connect
(
"tcp://"
+
self
.
prefill_server_url
).
send_multipart
(
[
GUARD
,
str
(
self
.
bootstrap_room
).
encode
(
"ascii"
),
get_local_ip_by_remote
().
encode
(
"ascii"
),
str
(
self
.
kv_mgr
.
rank_port
).
encode
(
"ascii"
),
self
.
kv_mgr
.
agent
.
get_agent_metadata
(),
packed_kv_data_ptrs
,
kv_indices
.
tobytes
(),
packed_aux_data_ptrs
,
str
(
aux_index
).
encode
(
"ascii"
),
str
(
self
.
kv_mgr
.
kv_args
.
gpu_id
).
encode
(
"ascii"
),
]
)
for
dummy_rank
in
equal_sources
.
keys
():
if
dummy_rank
==
remote_rank
:
continue
dummy_info
=
equal_sources
[
dummy_rank
]
dummy_url
=
f
"
{
dummy_info
[
'rank_ip'
]
}
:
{
dummy_info
[
'rank_port'
]
}
"
self
.
_connect
(
"tcp://"
+
dummy_url
).
send_multipart
(
[
GUARD
,
str
(
self
.
bootstrap_room
).
encode
(
"ascii"
),
]
)
self
.
started_transfer
=
True
def
poll
(
self
)
->
KVPoll
:
if
not
self
.
started_transfer
:
return
KVPoll
.
WaitingForInput
return
KVPoll
.
WaitingForInput
# type: ignore
self
.
kv_mgr
.
update_transfer_status
()
if
self
.
kv_mgr
.
check_transfer_done
(
self
.
bootstrap_room
):
return
KVPoll
.
Success
return
KVPoll
.
WaitingForInput
if
self
.
kv_mgr
.
check_transfer_done
(
self
.
bootstrap_room
):
# type: ignore
return
KVPoll
.
Success
# type: ignore
return
KVPoll
.
WaitingForInput
# type: ignore
def
failure_exception
(
self
):
raise
Exception
(
"Fake KVReceiver Exception"
)
...
...
@@ -484,6 +647,7 @@ class NixlKVReceiver(BaseKVReceiver):
class
NixlKVBootstrapServer
(
BaseKVBootstrapServer
):
def
__init__
(
self
,
port
:
int
):
logger
.
debug
(
f
"NixlKVBootstrapServer started on port
{
port
}
"
)
self
.
port
=
port
self
.
app
=
web
.
Application
()
self
.
store
=
dict
()
...
...
@@ -564,13 +728,13 @@ class NixlKVBootstrapServer(BaseKVBootstrapServer):
engine_rank
=
int
(
data
[
"engine_rank"
])
agent_name
=
data
[
"agent_name"
]
# Add lock to make sure thread-safe
if
role
==
"Prefill"
:
self
.
prefill_port_table
[
engine_rank
]
=
{
"rank_ip"
:
rank_ip
,
"rank_port"
:
rank_port
,
"agent_name"
:
agent_name
,
}
async
with
self
.
lock
:
self
.
prefill_port_table
[
engine_rank
]
=
{
"rank_ip"
:
rank_ip
,
"rank_port"
:
rank_port
,
"agent_name"
:
agent_name
,
}
logger
.
info
(
f
"Registered Prefill boostrap:
{
engine_rank
}
with rank_ip:
{
rank_ip
}
and rank_port:
{
rank_port
}
and name:
{
agent_name
}
"
)
...
...
@@ -580,7 +744,13 @@ class NixlKVBootstrapServer(BaseKVBootstrapServer):
async
def
_handle_route_get
(
self
,
request
:
web
.
Request
):
engine_rank
=
request
.
query
.
get
(
"engine_rank"
)
if
not
engine_rank
:
return
web
.
Response
(
text
=
"Missing rank"
,
status
=
400
)
logger
.
debug
(
f
"No engine_rank specified, return all
{
len
(
self
.
prefill_port_table
)
}
engine infos as a dict"
)
# Return a dict of all engine_rank
async
with
self
.
lock
:
bootstrap_info
=
self
.
prefill_port_table
return
web
.
json_response
(
bootstrap_info
,
status
=
200
)
# Find corresponding prefill info
async
with
self
.
lock
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment