Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5f527834
Unverified
Commit
5f527834
authored
Jun 24, 2025
by
Trevor Morris
Committed by
GitHub
Jun 24, 2025
Browse files
[PD] NIXL: Register kv args in advance and cleanup finished requests (#6717)
parent
9f1787fa
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
94 additions
and
46 deletions
+94
-46
python/sglang/srt/disaggregation/nixl/conn.py
python/sglang/srt/disaggregation/nixl/conn.py
+94
-46
No files found.
python/sglang/srt/disaggregation/nixl/conn.py
View file @
5f527834
...
@@ -31,23 +31,19 @@ from sglang.srt.utils import get_local_ip_by_remote
...
@@ -31,23 +31,19 @@ from sglang.srt.utils import get_local_ip_by_remote
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
NixlEngineInfo
:
TypeAlias
=
Dict
[
str
,
Union
[
str
,
int
]]
GUARD
=
"NixlMsgGuard"
.
encode
(
"ascii"
)
GUARD
=
"NixlMsgGuard"
.
encode
(
"ascii"
)
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
TransferInfo
:
class
TransferInfo
:
"""Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread."""
room
:
int
room
:
int
endpoint
:
str
endpoint
:
str
dst_port
:
int
dst_port
:
int
agent_metadata
:
bytes
agent_name
:
str
agent_name
:
str
dst_kv_ptrs
:
list
[
int
]
dst_kv_indices
:
npt
.
NDArray
[
np
.
int32
]
dst_kv_indices
:
npt
.
NDArray
[
np
.
int32
]
dst_aux_ptrs
:
list
[
int
]
dst_aux_index
:
int
dst_aux_index
:
int
dst_gpu_id
:
int
required_dst_info_num
:
int
required_dst_info_num
:
int
def
is_dummy
(
self
):
def
is_dummy
(
self
):
...
@@ -59,14 +55,37 @@ class TransferInfo:
...
@@ -59,14 +55,37 @@ class TransferInfo:
room
=
int
(
msg
[
0
].
decode
(
"ascii"
)),
room
=
int
(
msg
[
0
].
decode
(
"ascii"
)),
endpoint
=
msg
[
1
].
decode
(
"ascii"
),
endpoint
=
msg
[
1
].
decode
(
"ascii"
),
dst_port
=
int
(
msg
[
2
].
decode
(
"ascii"
)),
dst_port
=
int
(
msg
[
2
].
decode
(
"ascii"
)),
agent_metadata
=
msg
[
3
],
agent_name
=
msg
[
3
].
decode
(
"ascii"
),
agent_name
=
msg
[
4
].
decode
(
"ascii"
),
dst_kv_indices
=
np
.
frombuffer
(
msg
[
4
],
dtype
=
np
.
int32
),
dst_aux_index
=
int
(
msg
[
5
].
decode
(
"ascii"
)),
required_dst_info_num
=
int
(
msg
[
6
].
decode
(
"ascii"
)),
)
@
dataclasses
.
dataclass
class
KVArgsRegisterInfo
:
"""Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread."""
room
:
str
endpoint
:
str
dst_port
:
int
agent_name
:
str
agent_metadata
:
bytes
dst_kv_ptrs
:
list
[
int
]
dst_aux_ptrs
:
list
[
int
]
gpu_id
:
int
@
classmethod
def
from_zmq
(
cls
,
msg
:
List
[
bytes
]):
return
cls
(
room
=
str
(
msg
[
0
].
decode
(
"ascii"
)),
endpoint
=
msg
[
1
].
decode
(
"ascii"
),
dst_port
=
int
(
msg
[
2
].
decode
(
"ascii"
)),
agent_name
=
msg
[
3
].
decode
(
"ascii"
),
agent_metadata
=
msg
[
4
],
dst_kv_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
5
])
//
8
}
Q"
,
msg
[
5
])),
dst_kv_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
5
])
//
8
}
Q"
,
msg
[
5
])),
dst_kv_indices
=
np
.
frombuffer
(
msg
[
6
],
dtype
=
np
.
int32
),
dst_aux_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
6
])
//
8
}
Q"
,
msg
[
6
])),
dst_aux_ptrs
=
list
(
struct
.
unpack
(
f
"
{
len
(
msg
[
7
])
//
8
}
Q"
,
msg
[
7
])),
gpu_id
=
int
(
msg
[
7
].
decode
(
"ascii"
)),
dst_aux_index
=
int
(
msg
[
8
].
decode
(
"ascii"
)),
dst_gpu_id
=
int
(
msg
[
9
].
decode
(
"ascii"
)),
required_dst_info_num
=
int
(
msg
[
10
].
decode
(
"ascii"
)),
)
)
...
@@ -109,9 +128,9 @@ class NixlKVManager(CommonKVManager):
...
@@ -109,9 +128,9 @@ class NixlKVManager(CommonKVManager):
self
.
register_buffer_to_engine
()
self
.
register_buffer_to_engine
()
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
self
.
request_status
=
{}
self
.
request_status
:
Dict
[
int
,
KVPoll
]
=
{}
self
.
transfer_infos
:
Dict
[
int
,
TransferInfo
]
=
{}
self
.
transfer_infos
:
Dict
[
int
,
Dict
[
str
,
TransferInfo
]
]
=
{}
self
.
peer_names
:
Dict
[
str
,
str
]
=
{}
self
.
decode_kv_args_table
:
Dict
[
str
,
KVArgsRegisterInfo
]
=
{}
self
.
_start_bootstrap_thread
()
self
.
_start_bootstrap_thread
()
elif
self
.
disaggregation_mode
==
DisaggregationMode
.
DECODE
:
elif
self
.
disaggregation_mode
==
DisaggregationMode
.
DECODE
:
self
.
transfer_statuses
:
Dict
[
int
,
TransferStatus
]
=
defaultdict
(
self
.
transfer_statuses
:
Dict
[
int
,
TransferStatus
]
=
defaultdict
(
...
@@ -154,10 +173,13 @@ class NixlKVManager(CommonKVManager):
...
@@ -154,10 +173,13 @@ class NixlKVManager(CommonKVManager):
if
not
self
.
aux_descs
:
if
not
self
.
aux_descs
:
raise
Exception
(
"NIXL memory registration failed for aux tensors"
)
raise
Exception
(
"NIXL memory registration failed for aux tensors"
)
def
_add_remote
(
self
,
agent_name
:
str
,
agent_metadata
:
bytes
):
def
_add_remote_peer
(
self
,
decode_kv_args
:
KVArgsRegisterInfo
):
if
agent_name
not
in
self
.
peer_names
:
agent_name
=
decode_kv_args
.
agent_name
self
.
peer_names
[
agent_name
]
=
self
.
agent
.
add_remote_agent
(
agent_metadata
)
if
agent_name
in
self
.
decode_kv_args_table
:
return
self
.
peer_names
[
agent_name
]
logger
.
info
(
f
"Peer
{
agent_name
}
was already registered, ignoring."
)
return
self
.
decode_kv_args_table
[
agent_name
]
=
decode_kv_args
self
.
agent
.
add_remote_agent
(
decode_kv_args
.
agent_metadata
)
def
send_kvcache
(
def
send_kvcache
(
self
,
self
,
...
@@ -262,17 +284,17 @@ class NixlKVManager(CommonKVManager):
...
@@ -262,17 +284,17 @@ class NixlKVManager(CommonKVManager):
if
req
.
is_dummy
():
if
req
.
is_dummy
():
continue
continue
peer_name
=
self
.
_add_remote
(
req
.
agent_name
,
req
.
agent_metadata
)
chunked_dst_kv_indice
=
req
.
dst_kv_indices
[
index_slice
]
chunked_dst_kv_indice
=
req
.
dst_kv_indices
[
index_slice
]
assert
len
(
chunked_dst_kv_indice
)
==
len
(
kv_indices
)
assert
len
(
chunked_dst_kv_indice
)
==
len
(
kv_indices
)
assert
req
.
agent_name
in
self
.
decode_kv_args_table
notif
=
"_"
.
join
([
str
(
req
.
room
),
"kv"
,
str
(
chunk_id
),
str
(
int
(
is_last
))])
notif
=
"_"
.
join
([
str
(
req
.
room
),
"kv"
,
str
(
chunk_id
),
str
(
int
(
is_last
))])
kv_xfer_handle
=
self
.
send_kvcache
(
kv_xfer_handle
=
self
.
send_kvcache
(
peer
_name
,
req
.
agent
_name
,
kv_indices
,
kv_indices
,
req
.
dst_kv_ptrs
,
self
.
decode_kv_args_table
[
req
.
agent_name
]
.
dst_kv_ptrs
,
chunked_dst_kv_indice
,
chunked_dst_kv_indice
,
req
.
dst_
gpu_id
,
self
.
decode_kv_args_table
[
req
.
agent_name
].
gpu_id
,
notif
,
notif
,
)
)
handles
.
append
(
kv_xfer_handle
)
handles
.
append
(
kv_xfer_handle
)
...
@@ -280,13 +302,15 @@ class NixlKVManager(CommonKVManager):
...
@@ -280,13 +302,15 @@ class NixlKVManager(CommonKVManager):
if
is_last
:
if
is_last
:
assert
aux_index
is
not
None
assert
aux_index
is
not
None
aux_xfer_handle
=
self
.
send_aux
(
aux_xfer_handle
=
self
.
send_aux
(
peer
_name
,
req
.
agent
_name
,
aux_index
,
aux_index
,
req
.
dst_aux_ptrs
,
self
.
decode_kv_args_table
[
req
.
agent_name
]
.
dst_aux_ptrs
,
req
.
dst_aux_index
,
req
.
dst_aux_index
,
str
(
req
.
room
)
+
"_aux"
,
str
(
req
.
room
)
+
"_aux"
,
)
)
handles
.
append
(
aux_xfer_handle
)
handles
.
append
(
aux_xfer_handle
)
if
is_last
:
del
self
.
transfer_infos
[
bootstrap_room
]
return
handles
return
handles
def
update_transfer_status
(
self
):
def
update_transfer_status
(
self
):
...
@@ -328,16 +352,23 @@ class NixlKVManager(CommonKVManager):
...
@@ -328,16 +352,23 @@ class NixlKVManager(CommonKVManager):
),
f
"First message should be
{
GUARD
}
. Foreign traffic?"
),
f
"First message should be
{
GUARD
}
. Foreign traffic?"
waiting_req_bytes
=
waiting_req_bytes
[
1
:]
waiting_req_bytes
=
waiting_req_bytes
[
1
:]
room
=
waiting_req_bytes
[
0
].
decode
(
"ascii"
)
room
=
waiting_req_bytes
[
0
].
decode
(
"ascii"
)
agent_name
=
waiting_req_bytes
[
3
].
decode
(
"ascii"
)
required_dst_info_num
=
int
(
waiting_req_bytes
[
10
].
decode
(
"ascii"
))
if
room
==
"None"
:
# Register new peer and save KV base pointers.
self
.
_add_remote_peer
(
KVArgsRegisterInfo
.
from_zmq
(
waiting_req_bytes
)
)
logger
.
debug
(
f
"Register KVArgs from
{
agent_name
}
successfully"
)
continue
room
=
int
(
room
)
room
=
int
(
room
)
agent_name
=
waiting_req_bytes
[
4
].
decode
(
"ascii"
)
if
room
not
in
self
.
transfer_infos
:
if
room
not
in
self
.
transfer_infos
:
self
.
transfer_infos
[
room
]
=
{}
self
.
transfer_infos
[
room
]
=
{}
self
.
transfer_infos
[
room
][
agent_name
]
=
TransferInfo
.
from_zmq
(
self
.
transfer_infos
[
room
][
agent_name
]
=
TransferInfo
.
from_zmq
(
waiting_req_bytes
waiting_req_bytes
)
)
required_dst_info_num
=
self
.
transfer_infos
[
room
][
agent_name
].
required_dst_info_num
logger
.
debug
(
f
"got info
{
room
=
}
{
agent_name
=
}
{
required_dst_info_num
=
}
"
)
logger
.
debug
(
f
"got info
{
room
=
}
{
agent_name
=
}
{
required_dst_info_num
=
}
"
)
if
len
(
self
.
transfer_infos
[
room
])
==
required_dst_info_num
:
if
len
(
self
.
transfer_infos
[
room
])
==
required_dst_info_num
:
logger
.
debug
(
f
"
{
room
=
}
is bootstrapped"
)
logger
.
debug
(
f
"
{
room
=
}
is bootstrapped"
)
...
@@ -391,6 +422,7 @@ class NixlKVSender(BaseKVSender):
...
@@ -391,6 +422,7 @@ class NixlKVSender(BaseKVSender):
self
.
chunk_id
+=
1
self
.
chunk_id
+=
1
if
is_last
:
if
is_last
:
self
.
has_sent
=
True
self
.
has_sent
=
True
del
self
.
kv_mgr
.
request_status
[
self
.
bootstrap_room
]
def
poll
(
self
)
->
KVPoll
:
def
poll
(
self
)
->
KVPoll
:
if
not
self
.
has_sent
:
if
not
self
.
has_sent
:
...
@@ -415,6 +447,7 @@ class NixlKVReceiver(CommonKVReceiver):
...
@@ -415,6 +447,7 @@ class NixlKVReceiver(CommonKVReceiver):
data_parallel_rank
:
Optional
[
int
]
=
None
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
):
):
self
.
started_transfer
=
False
self
.
started_transfer
=
False
self
.
conclude_state
=
None
super
().
__init__
(
mgr
,
bootstrap_addr
,
bootstrap_room
,
data_parallel_rank
)
super
().
__init__
(
mgr
,
bootstrap_addr
,
bootstrap_room
,
data_parallel_rank
)
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int32
],
aux_index
:
Optional
[
int
]
=
None
):
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int32
],
aux_index
:
Optional
[
int
]
=
None
):
...
@@ -426,17 +459,8 @@ class NixlKVReceiver(CommonKVReceiver):
...
@@ -426,17 +459,8 @@ class NixlKVReceiver(CommonKVReceiver):
f
"Fetched bootstrap info:
{
bootstrap_info
}
for engine rank:
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
"
f
"Fetched bootstrap info:
{
bootstrap_info
}
for engine rank:
{
self
.
kv_mgr
.
kv_args
.
engine_rank
}
"
)
)
is_dummy
=
bootstrap_info
[
"is_dummy"
]
is_dummy
=
bootstrap_info
[
"is_dummy"
]
# TODO: send_kv_args earlier
packed_kv_data_ptrs
=
b
""
.
join
(
struct
.
pack
(
"Q"
,
ptr
)
for
ptr
in
self
.
kv_mgr
.
kv_args
.
kv_data_ptrs
)
packed_aux_data_ptrs
=
b
""
.
join
(
struct
.
pack
(
"Q"
,
ptr
)
for
ptr
in
self
.
kv_mgr
.
kv_args
.
aux_data_ptrs
)
logger
.
debug
(
logger
.
debug
(
f
"Sending to
{
self
.
prefill_server_url
}
with bootstrap room
{
self
.
bootstrap_room
}
"
f
"Sending to
{
self
.
prefill_server_url
}
with bootstrap room
{
self
.
bootstrap_room
}
{
is_dummy
=
}
"
)
)
sock
,
lock
=
self
.
_connect
(
"tcp://"
+
self
.
prefill_server_url
)
sock
,
lock
=
self
.
_connect
(
"tcp://"
+
self
.
prefill_server_url
)
with
lock
:
with
lock
:
...
@@ -446,13 +470,9 @@ class NixlKVReceiver(CommonKVReceiver):
...
@@ -446,13 +470,9 @@ class NixlKVReceiver(CommonKVReceiver):
str
(
self
.
bootstrap_room
).
encode
(
"ascii"
),
str
(
self
.
bootstrap_room
).
encode
(
"ascii"
),
get_local_ip_by_remote
().
encode
(
"ascii"
),
get_local_ip_by_remote
().
encode
(
"ascii"
),
str
(
self
.
kv_mgr
.
rank_port
).
encode
(
"ascii"
),
str
(
self
.
kv_mgr
.
rank_port
).
encode
(
"ascii"
),
self
.
kv_mgr
.
agent
.
get_agent_metadata
(),
self
.
kv_mgr
.
agent
.
name
.
encode
(
"ascii"
),
self
.
kv_mgr
.
agent
.
name
.
encode
(
"ascii"
),
packed_kv_data_ptrs
,
kv_indices
.
tobytes
()
if
not
is_dummy
else
b
""
,
kv_indices
.
tobytes
()
if
not
is_dummy
else
b
""
,
packed_aux_data_ptrs
,
str
(
aux_index
).
encode
(
"ascii"
),
str
(
aux_index
).
encode
(
"ascii"
),
str
(
self
.
kv_mgr
.
kv_args
.
gpu_id
).
encode
(
"ascii"
),
str
(
self
.
required_dst_info_num
).
encode
(
"ascii"
),
str
(
self
.
required_dst_info_num
).
encode
(
"ascii"
),
]
]
)
)
...
@@ -460,17 +480,45 @@ class NixlKVReceiver(CommonKVReceiver):
...
@@ -460,17 +480,45 @@ class NixlKVReceiver(CommonKVReceiver):
self
.
started_transfer
=
True
self
.
started_transfer
=
True
def
poll
(
self
)
->
KVPoll
:
def
poll
(
self
)
->
KVPoll
:
if
self
.
conclude_state
is
not
None
:
return
self
.
conclude_state
if
not
self
.
started_transfer
:
if
not
self
.
started_transfer
:
return
KVPoll
.
WaitingForInput
# type: ignore
return
KVPoll
.
WaitingForInput
# type: ignore
self
.
kv_mgr
.
update_transfer_status
()
self
.
kv_mgr
.
update_transfer_status
()
if
self
.
kv_mgr
.
check_transfer_done
(
self
.
bootstrap_room
):
# type: ignore
if
self
.
kv_mgr
.
check_transfer_done
(
self
.
bootstrap_room
):
# type: ignore
self
.
conclude_state
=
KVPoll
.
Success
del
self
.
kv_mgr
.
transfer_statuses
[
self
.
bootstrap_room
]
return
KVPoll
.
Success
# type: ignore
return
KVPoll
.
Success
# type: ignore
return
KVPoll
.
WaitingForInput
# type: ignore
return
KVPoll
.
WaitingForInput
# type: ignore
def
_register_kv_args
(
self
):
def
_register_kv_args
(
self
):
pass
for
bootstrap_info
in
self
.
bootstrap_infos
:
self
.
prefill_server_url
=
(
f
"
{
bootstrap_info
[
'rank_ip'
]
}
:
{
bootstrap_info
[
'rank_port'
]
}
"
)
packed_kv_data_ptrs
=
b
""
.
join
(
struct
.
pack
(
"Q"
,
ptr
)
for
ptr
in
self
.
kv_mgr
.
kv_args
.
kv_data_ptrs
)
packed_aux_data_ptrs
=
b
""
.
join
(
struct
.
pack
(
"Q"
,
ptr
)
for
ptr
in
self
.
kv_mgr
.
kv_args
.
aux_data_ptrs
)
sock
,
lock
=
self
.
_connect
(
"tcp://"
+
self
.
prefill_server_url
)
with
lock
:
sock
.
send_multipart
(
[
GUARD
,
"None"
.
encode
(
"ascii"
),
get_local_ip_by_remote
().
encode
(
"ascii"
),
str
(
self
.
kv_mgr
.
rank_port
).
encode
(
"ascii"
),
self
.
kv_mgr
.
agent
.
name
.
encode
(
"ascii"
),
self
.
kv_mgr
.
agent
.
get_agent_metadata
(),
packed_kv_data_ptrs
,
packed_aux_data_ptrs
,
str
(
self
.
kv_mgr
.
kv_args
.
gpu_id
).
encode
(
"ascii"
),
]
)
def
failure_exception
(
self
):
def
failure_exception
(
self
):
raise
Exception
(
"Fake KVReceiver Exception"
)
raise
Exception
(
"Fake KVReceiver Exception"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment