Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
25ef53f0
Unverified
Commit
25ef53f0
authored
Aug 21, 2025
by
Shangming Cai
Committed by
GitHub
Aug 20, 2025
Browse files
[PD] Fix nvlink transport accuracy through transferring metadata with tcp (#9261)
Signed-off-by:
Shangming Cai
<
csmthu@gmail.com
>
parent
c674bf9c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
112 additions
and
25 deletions
+112
-25
python/sglang/srt/disaggregation/mooncake/conn.py
python/sglang/srt/disaggregation/mooncake/conn.py
+110
-24
python/sglang/srt/disaggregation/utils.py
python/sglang/srt/disaggregation/utils.py
+2
-1
No files found.
python/sglang/srt/disaggregation/mooncake/conn.py
View file @
25ef53f0
...
@@ -2,6 +2,7 @@ from __future__ import annotations
...
@@ -2,6 +2,7 @@ from __future__ import annotations
import
asyncio
import
asyncio
import
concurrent.futures
import
concurrent.futures
import
ctypes
import
dataclasses
import
dataclasses
import
logging
import
logging
import
os
import
os
...
@@ -138,7 +139,29 @@ class KVArgsRegisterInfo:
...
@@ -138,7 +139,29 @@ class KVArgsRegisterInfo:
)
)
class AuxDataCodec:
    """Copy auxiliary-data items between raw memory addresses and ``bytes``.

    Both helpers go through ``ctypes`` and work on plain integer addresses, so
    nothing here can verify that an address is valid. The only protection that
    can be offered is keeping each write inside the slot it is addressed to.
    """

    @staticmethod
    def serialize_data_from_buffer(src_addr, data_length):
        """Return a ``bytes`` copy of the memory at ``src_addr``.

        Args:
            src_addr: Integer address of the first byte to read.
            data_length: Number of bytes to copy.

        Returns:
            A new ``bytes`` object holding ``data_length`` bytes.
        """
        view = (ctypes.c_byte * data_length).from_address(src_addr)
        return bytes(view)

    @staticmethod
    def deserialize_data_to_buffer(kv_args, buffer_index, aux_index, data):
        """Write ``data`` into slot ``aux_index`` of an auxiliary buffer.

        The destination address is
        ``kv_args.aux_data_ptrs[buffer_index]
        + kv_args.aux_item_lens[buffer_index] * aux_index``.

        Args:
            kv_args: Object exposing ``aux_data_ptrs`` (base addresses) and
                ``aux_item_lens`` (per-slot byte sizes), indexed by buffer.
            buffer_index: Which auxiliary buffer to write to.
            aux_index: Slot number inside that buffer; must be non-negative.
            data: Payload to copy; must not be longer than the slot.

        Raises:
            ValueError: If ``aux_index`` is negative or ``data`` is larger than
                one slot. Either would write outside the addressed slot.
        """
        base_addr = kv_args.aux_data_ptrs[buffer_index]
        item_len = kv_args.aux_item_lens[buffer_index]
        payload_len = len(data)

        # This is a raw memory write: refuse anything that would land outside
        # the addressed slot instead of silently corrupting adjacent memory.
        if aux_index < 0:
            raise ValueError(f"aux_index must be non-negative, got {aux_index}")
        if payload_len > item_len:
            raise ValueError(
                f"aux data of {payload_len} bytes does not fit in a slot of "
                f"{item_len} bytes (buffer_index={buffer_index})"
            )
        # NOTE(review): the upper bound of aux_index cannot be checked here
        # because the total buffer size is not visible in this class.

        dst_addr = base_addr + item_len * aux_index
        target = (ctypes.c_byte * payload_len).from_address(dst_addr)
        target[:] = data
class
MooncakeKVManager
(
BaseKVManager
):
class
MooncakeKVManager
(
BaseKVManager
):
AUX_DATA_HEADER
=
b
"AUX_DATA"
def
__init__
(
def
__init__
(
self
,
self
,
args
:
KVArgs
,
args
:
KVArgs
,
...
@@ -283,17 +306,6 @@ class MooncakeKVManager(BaseKVManager):
...
@@ -283,17 +306,6 @@ class MooncakeKVManager(BaseKVManager):
if
not
transfer_blocks
:
if
not
transfer_blocks
:
return
0
return
0
# TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
if
self
.
enable_custom_mem_pool
:
# batch_transfer_sync has a higher chance to trigger an accuracy drop for MNNVL, fallback to transfer_sync temporarily
for
src_addr
,
dst_addr
,
length
in
transfer_blocks
:
status
=
self
.
engine
.
transfer_sync
(
mooncake_session_id
,
src_addr
,
dst_addr
,
length
)
if
status
!=
0
:
return
status
return
0
else
:
src_addrs
,
dst_addrs
,
lengths
=
zip
(
*
transfer_blocks
)
src_addrs
,
dst_addrs
,
lengths
=
zip
(
*
transfer_blocks
)
return
self
.
engine
.
batch_transfer_sync
(
return
self
.
engine
.
batch_transfer_sync
(
mooncake_session_id
,
list
(
src_addrs
),
list
(
dst_addrs
),
list
(
lengths
)
mooncake_session_id
,
list
(
src_addrs
),
list
(
dst_addrs
),
list
(
lengths
)
...
@@ -570,11 +582,14 @@ class MooncakeKVManager(BaseKVManager):
...
@@ -570,11 +582,14 @@ class MooncakeKVManager(BaseKVManager):
def
send_aux
(
def
send_aux
(
self
,
self
,
mooncake_session_id
:
str
,
req
:
TransferInfo
,
prefill_aux_index
:
int
,
prefill_aux_index
:
int
,
dst_aux_ptrs
:
list
[
int
],
dst_aux_ptrs
:
list
[
int
],
dst_aux_index
:
int
,
):
):
# TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
if
self
.
enable_custom_mem_pool
:
return
self
.
send_aux_tcp
(
req
,
prefill_aux_index
,
dst_aux_ptrs
)
transfer_blocks
=
[]
transfer_blocks
=
[]
prefill_aux_ptrs
=
self
.
kv_args
.
aux_data_ptrs
prefill_aux_ptrs
=
self
.
kv_args
.
aux_data_ptrs
prefill_aux_item_lens
=
self
.
kv_args
.
aux_item_lens
prefill_aux_item_lens
=
self
.
kv_args
.
aux_item_lens
...
@@ -582,10 +597,59 @@ class MooncakeKVManager(BaseKVManager):
...
@@ -582,10 +597,59 @@ class MooncakeKVManager(BaseKVManager):
for
i
,
dst_aux_ptr
in
enumerate
(
dst_aux_ptrs
):
for
i
,
dst_aux_ptr
in
enumerate
(
dst_aux_ptrs
):
length
=
prefill_aux_item_lens
[
i
]
length
=
prefill_aux_item_lens
[
i
]
src_addr
=
prefill_aux_ptrs
[
i
]
+
length
*
prefill_aux_index
src_addr
=
prefill_aux_ptrs
[
i
]
+
length
*
prefill_aux_index
dst_addr
=
dst_aux_ptrs
[
i
]
+
length
*
dst_aux_index
dst_addr
=
dst_aux_ptrs
[
i
]
+
length
*
req
.
dst_aux_index
transfer_blocks
.
append
((
src_addr
,
dst_addr
,
length
))
transfer_blocks
.
append
((
src_addr
,
dst_addr
,
length
))
return
self
.
_transfer_data
(
mooncake_session_id
,
transfer_blocks
)
return
self
.
_transfer_data
(
req
.
mooncake_session_id
,
transfer_blocks
)
def send_aux_tcp(
    self,
    req: TransferInfo,
    prefill_aux_index: int,
    dst_aux_ptrs: list[int],
):
    """Send every local aux item for one request to the decode side.

    For each auxiliary buffer in ``self.kv_args`` the item at
    ``prefill_aux_index`` is copied out of memory and handed to
    ``send_aux_data_to_endpoint`` together with the request's routing fields.

    Args:
        req: Transfer request; ``endpoint``, ``dst_port``, ``room`` and
            ``dst_aux_index`` are read from it.
        prefill_aux_index: Slot index of the item inside each local buffer.
        dst_aux_ptrs: Not used by this method.

    Returns:
        Always ``0``.
    """
    base_ptrs = self.kv_args.aux_data_ptrs
    item_lens = self.kv_args.aux_item_lens

    for buffer_idx, base_ptr in enumerate(base_ptrs):
        item_len = item_lens[buffer_idx]
        # Each buffer holds fixed-size items laid out back to back.
        payload = AuxDataCodec.serialize_data_from_buffer(
            base_ptr + item_len * prefill_aux_index, item_len
        )
        self.send_aux_data_to_endpoint(
            remote=req.endpoint,
            dst_port=req.dst_port,
            room=req.room,
            buffer_index=buffer_idx,
            aux_index=req.dst_aux_index,
            data=payload,
        )

    return 0
def send_aux_data_to_endpoint(
    self,
    remote: str,
    dst_port: int,
    room: int,
    buffer_index: int,
    aux_index: int,
    data: bytes,
):
    """Send one aux item to a remote endpoint as a multipart message.

    The message has six frames: the ``AUX_DATA_HEADER`` marker, then ``room``,
    ``buffer_index`` and ``aux_index`` as ASCII decimal text, then the payload
    length as a big-endian unsigned 32-bit integer, then the payload itself.

    Args:
        remote: Host of the receiver.
        dst_port: Port of the receiver.
        room: Room identifier placed in the second frame.
        buffer_index: Index of the auxiliary buffer the payload belongs to.
        aux_index: Slot index placed in the fourth frame.
        data: Raw payload bytes.
    """
    address = format_tcp_address(remote, dst_port)
    use_ipv6 = is_valid_ipv6_address(remote)
    sock = self._connect(address, is_ipv6=use_ipv6)

    text_frames = [
        str(field).encode("ascii") for field in (room, buffer_index, aux_index)
    ]
    length_frame = struct.pack(">I", len(data))
    sock.send_multipart(
        [MooncakeKVManager.AUX_DATA_HEADER, *text_frames, length_frame, data]
    )
def
sync_status_to_decode_endpoint
(
def
sync_status_to_decode_endpoint
(
self
,
remote
:
str
,
dst_port
:
int
,
room
:
int
,
status
:
int
,
prefill_rank
:
int
self
,
remote
:
str
,
dst_port
:
int
,
room
:
int
,
status
:
int
,
prefill_rank
:
int
...
@@ -699,10 +763,9 @@ class MooncakeKVManager(BaseKVManager):
...
@@ -699,10 +763,9 @@ class MooncakeKVManager(BaseKVManager):
if
self
.
pp_group
.
is_last_rank
:
if
self
.
pp_group
.
is_last_rank
:
# Only the last chunk we need to send the aux data
# Only the last chunk we need to send the aux data
ret
=
self
.
send_aux
(
ret
=
self
.
send_aux
(
req
.
mooncake_session_id
,
req
,
kv_chunk
.
prefill_aux_index
,
kv_chunk
.
prefill_aux_index
,
target_rank_registration_info
.
dst_aux_ptrs
,
target_rank_registration_info
.
dst_aux_ptrs
,
req
.
dst_aux_index
,
)
)
polls
.
append
(
True
if
ret
==
0
else
False
)
polls
.
append
(
True
if
ret
==
0
else
False
)
dst_ranks_infos
.
append
(
dst_ranks_infos
.
append
(
...
@@ -778,15 +841,38 @@ class MooncakeKVManager(BaseKVManager):
...
@@ -778,15 +841,38 @@ class MooncakeKVManager(BaseKVManager):
threading
.
Thread
(
target
=
bootstrap_thread
).
start
()
threading
.
Thread
(
target
=
bootstrap_thread
).
start
()
def _handle_aux_data(self, msg: List[bytes]):
    """Handle an AUX_DATA message received by the decode thread.

    Frames are read by position: ``msg[1]``, ``msg[2]`` and ``msg[3]`` are the
    room, buffer index and aux index as ASCII decimal text, ``msg[4]`` is the
    payload length packed as ``">I"``, and ``msg[5]`` is the payload.

    The payload is written straight into local memory, so every field is
    validated first. A message that fails validation is logged and dropped
    rather than raising into the receive loop.

    Args:
        msg: The multipart message, including the header frame at index 0.
    """
    # Indexing below needs frames 0..5; extra trailing frames are ignored.
    if len(msg) < 6:
        logger.error(
            f"Malformed AUX_DATA message: expected 6 frames, got {len(msg)}"
        )
        return

    try:
        room = int(msg[1].decode("ascii"))
        buffer_index = int(msg[2].decode("ascii"))
        aux_index = int(msg[3].decode("ascii"))
        data_length = struct.unpack(">I", msg[4])[0]
    except (ValueError, struct.error) as err:
        # UnicodeDecodeError is a ValueError subclass, so it is covered too.
        logger.error(f"Malformed AUX_DATA message header: {err}")
        return

    data = msg[5]
    if len(data) != data_length:
        logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
        return

    # Reject indices and sizes that would make the raw write land outside the
    # addressed slot.
    aux_item_lens = self.kv_args.aux_item_lens
    if not 0 <= buffer_index < len(aux_item_lens) or aux_index < 0:
        logger.error(
            f"AUX_DATA index out of range for bootstrap_room {room}: "
            f"buffer_index={buffer_index}, aux_index={aux_index}"
        )
        return
    if data_length > aux_item_lens[buffer_index]:
        logger.error(
            f"AUX_DATA payload too large for bootstrap_room {room}: "
            f"{data_length} > {aux_item_lens[buffer_index]}"
        )
        return

    AuxDataCodec.deserialize_data_to_buffer(
        self.kv_args, buffer_index, aux_index, data
    )
    logger.debug(
        f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
    )
def
start_decode_thread
(
self
):
def
start_decode_thread
(
self
):
self
.
rank_port
=
get_free_port
()
self
.
rank_port
=
get_free_port
()
self
.
_bind_server_socket
()
self
.
_bind_server_socket
()
def
decode_thread
():
def
decode_thread
():
while
True
:
while
True
:
(
bootstrap_room
,
status
,
prefill_rank
)
=
(
msg
=
self
.
server_socket
.
recv_multipart
()
self
.
server_socket
.
recv_multipart
()
if
msg
[
0
]
==
MooncakeKVManager
.
AUX_DATA_HEADER
:
)
self
.
_handle_aux_data
(
msg
)
continue
(
bootstrap_room
,
status
,
prefill_rank
)
=
msg
status
=
int
(
status
.
decode
(
"ascii"
))
status
=
int
(
status
.
decode
(
"ascii"
))
bootstrap_room
=
int
(
bootstrap_room
.
decode
(
"ascii"
))
bootstrap_room
=
int
(
bootstrap_room
.
decode
(
"ascii"
))
prefill_rank
=
int
(
prefill_rank
.
decode
(
"ascii"
))
prefill_rank
=
int
(
prefill_rank
.
decode
(
"ascii"
))
...
...
python/sglang/srt/disaggregation/utils.py
View file @
25ef53f0
...
@@ -99,7 +99,8 @@ class MetadataBuffers:
...
@@ -99,7 +99,8 @@ class MetadataBuffers:
# For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.
# For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.
device
=
"npu"
device
=
"npu"
elif
self
.
custom_mem_pool
:
elif
self
.
custom_mem_pool
:
device
=
"cuda"
# TODO(shangming): Fix me (use 'cuda') when nvlink_transport of Mooncake is bug-free
device
=
"cpu"
with
(
with
(
torch
.
cuda
.
use_mem_pool
(
self
.
custom_mem_pool
)
torch
.
cuda
.
use_mem_pool
(
self
.
custom_mem_pool
)
if
self
.
custom_mem_pool
if
self
.
custom_mem_pool
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment