sglang · Commits · 9ab72f98

add variable TP Decode > Prefill size support (#9960)

Unverified commit 9ab72f98, authored Sep 09, 2025 by shaharmor98, committed by GitHub on Sep 09, 2025.

Signed-off-by: Shahar Mor <smor@nvidia.com>

Parent: f3817cb0
Showing 3 changed files with 181 additions and 18 deletions (+181 −18):

    python/sglang/srt/disaggregation/common/conn.py     +0   −3
    python/sglang/srt/disaggregation/mooncake/conn.py   +3   −1
    python/sglang/srt/disaggregation/nixl/conn.py       +178 −14
python/sglang/srt/disaggregation/common/conn.py

@@ -168,9 +168,6 @@ class CommonKVReceiver(BaseKVReceiver):
                 self.required_dst_info_num = 1
                 self.target_tp_ranks = [self.target_tp_rank]
             elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
-                assert (
-                    self.kv_mgr.is_mla_backend
-                ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
                 self.target_tp_rank = (
                     self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
                 ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
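A side note on the surrounding context: when the decode side runs more TP ranks per DP rank than prefill, target_tp_rank maps each decode rank onto the prefill rank that owns its slice of the KV cache. The arithmetic is easy to sanity-check in isolation; the sketch below uses made-up sizes and an invented helper name, and is not part of the commit.

    def target_prefill_rank(engine_rank: int, local_tp: int, prefill_tp: int) -> int:
        # Mirrors (engine_rank % local_tp_size_per_dp_rank)
        #         // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
        ranks_per_prefill_rank = local_tp // prefill_tp  # decode ranks sharing one prefill rank
        return (engine_rank % local_tp) // ranks_per_prefill_rank

    # With decode TP = 4 and prefill TP = 2, decode ranks 0-1 target prefill rank 0
    # and decode ranks 2-3 target prefill rank 1:
    print([target_prefill_rank(r, 4, 2) for r in range(4)])  # -> [0, 0, 1, 1]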
python/sglang/srt/disaggregation/mooncake/conn.py

@@ -459,7 +459,9 @@ class MooncakeKVManager(BaseKVManager):
             dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
         else:
             # Send KVCache from 1 prefill instance to multiple decode instances
-            src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank
+            src_head_start_offset = (
+                dst_tp_rank_in_group * dst_heads_per_rank
+            ) % src_heads_per_rank
             num_heads_to_send = dst_heads_per_rank
             dst_head_start_offset = 0
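Why the modulo matters: src_heads_per_rank counts heads held by one prefill rank, while dst_tp_rank_in_group counts decode ranks across the whole group, so without the wrap the source offset can point past the heads this prefill rank actually owns. A standalone sketch with made-up sizes (16 KV heads in total, prefill TP = 2, decode TP = 8), not part of the commit:

    src_heads_per_rank = 16 // 2                       # 8 heads on each prefill rank
    dst_heads_per_rank = src_heads_per_rank * 2 // 8   # 2 heads on each decode rank

    for dst_tp_rank_in_group in range(8):
        old_offset = dst_tp_rank_in_group * dst_heads_per_rank                         # 0, 2, ..., 14
        new_offset = (dst_tp_rank_in_group * dst_heads_per_rank) % src_heads_per_rank  # 0, 2, 4, 6, 0, 2, 4, 6
        print(dst_tp_rank_in_group, old_offset, new_offset)

    # Decode ranks 4-7 are served by the second prefill rank, so their head slice must
    # restart at local offset 0 on that rank; the old formula kept growing past offset 7.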
python/sglang/srt/disaggregation/nixl/conn.py

@@ -78,6 +78,9 @@ class KVArgsRegisterInfo:
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
     gpu_id: int
+    decode_tp_size: int
+    decode_tp_rank: int
+    dst_kv_item_len: int
 
     @classmethod
     def from_zmq(cls, msg: List[bytes]):
@@ -90,6 +93,9 @@ class KVArgsRegisterInfo:
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[5]) // 8}Q", msg[5])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[6]) // 8}Q", msg[6])),
             gpu_id=int(msg[7].decode("ascii")),
+            decode_tp_size=int(msg[8].decode("ascii")),
+            decode_tp_rank=int(msg[9].decode("ascii")),
+            dst_kv_item_len=int(msg[10].decode("ascii")),
         )
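For reference, the wire format parsed by from_zmq above is straightforward: pointer lists travel as packed unsigned 64-bit integers, and the scalar fields added by this commit travel as ASCII-encoded integers. A minimal standalone round-trip, with made-up pointer values (not code from the commit):

    import struct

    dst_kv_ptrs = [0x7F0000000000, 0x7F0010000000]             # hypothetical device addresses
    blob = struct.pack(f"{len(dst_kv_ptrs)}Q", *dst_kv_ptrs)   # 8 bytes per pointer
    assert list(struct.unpack(f"{len(blob) // 8}Q", blob)) == dst_kv_ptrs

    # Scalars such as decode_tp_size, decode_tp_rank and dst_kv_item_len travel as ASCII text:
    decode_tp_size = int(b"4".decode("ascii"))
    assert decode_tp_size == 4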
@@ -166,7 +172,7 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
         ):
             kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
-        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
+        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM")
         logger.debug(f"Register kv tensors, len(kv_addr)={len(kv_addrs)}")
         if not self.kv_descs:
             raise Exception("NIXL memory registration failed for kv tensors")
@@ -175,7 +181,7 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         ):
             aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
-        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
+        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM")
         logger.debug(f"Register aux tensors, len(aux_addrs)={len(aux_addrs)}")
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")
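Both hunks above drop the is_sorted=False keyword from register_memory, and later hunks do the same for get_xfer_descs. If the same code had to run against NIXL bindings from before and after that keyword went away, a small shim like the following could bridge the two call forms; it is purely illustrative, the helper name is invented, and it assumes newer bindings reject the keyword with a TypeError.

    def register_memory_compat(agent, addrs, mem_type):
        """Register memory with a NIXL agent, tolerating bindings with or without is_sorted."""
        try:
            # Older bindings accepted (and the previous sglang code passed) is_sorted=False.
            return agent.register_memory(addrs, mem_type, is_sorted=False)
        except TypeError:
            # Newer bindings no longer take the keyword, matching the updated calls in this diff.
            return agent.register_memory(addrs, mem_type)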
@@ -222,8 +228,8 @@ class NixlKVManager(CommonKVManager):
         logger.debug(
             f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
         )
-        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",
@@ -239,6 +245,140 @@ class NixlKVManager(CommonKVManager):
             raise Exception("KVSender failed to post transfer")
         return xfer_handle
 
+    def send_kvcache_slice(
+        self,
+        peer_name: str,
+        prefill_kv_indices: npt.NDArray[np.int32],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int32],
+        dst_gpu_id: int,
+        notif: str,
+        prefill_tp_size: int,
+        decode_tp_size: int,
+        decode_tp_rank: int,
+        dst_kv_item_len: int,
+    ):
+        # Get configuration from kv_args
+        local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size
+        dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
+        num_kv_heads = self.kv_args.kv_head_num
+
+        # Calculate head distribution
+        src_heads_per_rank = num_kv_heads
+        dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size
+
+        src_kv_item_len = self.kv_args.kv_item_lens[0]
+        page_size = self.kv_args.page_size
+
+        bytes_per_head_slice_to_send = (
+            dst_kv_item_len // page_size // dst_heads_per_rank
+        )
+
+        # Determine which heads to send
+        if prefill_tp_size > decode_tp_size:
+            # Multiple prefill ranks to one decode rank
+            src_head_start_offset = 0
+            num_heads_to_send = src_heads_per_rank
+            dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
+        else:
+            # Send KVCache from 1 prefill instance to multiple decode instances
+            src_head_start_offset = (
+                dst_tp_rank_in_group * dst_heads_per_rank
+            ) % src_heads_per_rank
+            num_heads_to_send = dst_heads_per_rank
+            dst_head_start_offset = 0
+
+        # Create transfer descriptors
+        src_addrs = []
+        dst_addrs = []
+
+        bytes_per_token_on_prefill = src_kv_item_len // page_size
+        bytes_per_token_on_decode = dst_kv_item_len // page_size
+
+        num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+        src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+        src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+        dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
+        dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
+
+        # Calculate precise byte offset and length for the sub-slice within the token
+        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+        dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+        heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
+
+        src_dst_ptr_pairs = [
+            (
+                src_k_ptrs[layer_id],
+                dst_k_ptrs[layer_id],
+            )
+            for layer_id in range(len(src_k_ptrs))
+        ] + [
+            (
+                src_v_ptrs[layer_id],
+                dst_v_ptrs[layer_id],
+            )
+            for layer_id in range(len(src_v_ptrs))
+        ]
+
+        src_addrs = []
+        dst_addrs = []
+
+        # Calculate strides for a single token slot
+        bytes_per_token_on_prefill = src_kv_item_len // page_size
+        bytes_per_token_on_decode = dst_kv_item_len // page_size
+
+        for src_ptr, dst_ptr in src_dst_ptr_pairs:
+            for i in range(len(prefill_kv_indices)):
+                prefill_page_idx = int(prefill_kv_indices[i])
+                decode_page_idx = int(dst_kv_indices[i])
+
+                # Get the starting addresses for the current src and dst pages
+                src_page_start_addr = src_ptr + prefill_page_idx * src_kv_item_len
+                dst_page_start_addr = dst_ptr + decode_page_idx * dst_kv_item_len
+
+                # Iterate through each valid token slot within the current page
+                for token_slot_in_page in range(page_size):
+                    # Calculate the start address of the current token slot
+                    src_token_slot_start_addr = (
+                        src_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_prefill
+                    )
+                    dst_token_slot_start_addr = (
+                        dst_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_decode
+                    )
+
+                    # Calculate final src and dst addresses by applying head-slice offsets
+                    src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
+                    dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
+
+                    src_addrs.append(
+                        (
+                            src_slice_addr,
+                            heads_bytes_per_token_to_send,
+                            self.kv_args.gpu_id,
+                        )
+                    )
+                    dst_addrs.append(
+                        (dst_slice_addr, heads_bytes_per_token_to_send, dst_gpu_id)
+                    )
+
+        # Use NIXL agent for transfer
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
+
+        xfer_handle = self.agent.initialize_xfer(
+            "WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii")
+        )
+        if not xfer_handle:
+            raise Exception("Failed to create sliced KV transfer")
+
+        state = self.agent.transfer(xfer_handle)
+        if state == "ERR":
+            raise Exception("Failed to post sliced KV transfer")
+        return xfer_handle
+
     def send_aux(
         self,
         peer_name: str,
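To make the head-slice arithmetic in send_kvcache_slice concrete, here is a standalone sketch that reproduces only the offset math with made-up sizes (8 KV heads per prefill rank, prefill TP = 2, decode TP = 4, page_size = 16, 128 bytes per head per token); none of these numbers come from the commit.

    page_size = 16
    bytes_per_head = 128
    src_heads_per_rank = 8                    # heads held by one prefill rank
    prefill_tp_size, decode_tp_size = 2, 4

    dst_heads_per_rank = src_heads_per_rank * prefill_tp_size // decode_tp_size      # 4
    src_kv_item_len = page_size * src_heads_per_rank * bytes_per_head                # 16384 bytes per page
    dst_kv_item_len = page_size * dst_heads_per_rank * bytes_per_head                # 8192 bytes per page
    bytes_per_head_slice = dst_kv_item_len // page_size // dst_heads_per_rank        # 128

    for decode_tp_rank in range(decode_tp_size):
        dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
        src_head_start = (dst_tp_rank_in_group * dst_heads_per_rank) % src_heads_per_rank
        bytes_per_token = dst_heads_per_rank * bytes_per_head_slice                  # 512
        print(decode_tp_rank, src_head_start, bytes_per_token)

    # Each prefill rank ends up serving two decode ranks: one takes its first 4 local heads
    # (offset 0) and the other takes the remaining 4 (offset 4), 512 bytes per token slot.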
@@ -255,8 +395,8 @@ class NixlKVManager(CommonKVManager):
         decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
         src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
         dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
-        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
+        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",
@@ -296,14 +436,35 @@ class NixlKVManager(CommonKVManager):
             assert req.agent_name in self.decode_kv_args_table
             notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
-            kv_xfer_handle = self.send_kvcache(
-                req.agent_name,
-                kv_indices,
-                self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
-                chunked_dst_kv_indice,
-                self.decode_kv_args_table[req.agent_name].gpu_id,
-                notif,
-            )
+            decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
+
+            if decode_tp_size == self.tp_size:
+                kv_xfer_handle = self.send_kvcache(
+                    req.agent_name,
+                    kv_indices,
+                    self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
+                    chunked_dst_kv_indice,
+                    self.decode_kv_args_table[req.agent_name].gpu_id,
+                    notif,
+                )
+            else:
+                kv_xfer_handle = self.send_kvcache_slice(
+                    req.agent_name,
+                    kv_indices,
+                    self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
+                    chunked_dst_kv_indice,
+                    self.decode_kv_args_table[req.agent_name].gpu_id,
+                    notif,
+                    prefill_tp_size=self.tp_size,
+                    decode_tp_size=decode_tp_size,
+                    decode_tp_rank=self.decode_kv_args_table[
+                        req.agent_name
+                    ].decode_tp_rank,
+                    dst_kv_item_len=self.decode_kv_args_table[
+                        req.agent_name
+                    ].dst_kv_item_len,
+                )
             handles.append(kv_xfer_handle)
             # Only the last chunk we need to send the aux data.
             if is_last:
@@ -521,6 +682,9 @@ class NixlKVReceiver(CommonKVReceiver):
                 packed_kv_data_ptrs,
                 packed_aux_data_ptrs,
                 str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+                str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"),
+                str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),
+                str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"),
             ]
         )
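Reading the receiver-side packing above together with KVArgsRegisterInfo.from_zmq, the registration message now carries three extra ASCII fields after gpu_id. A rough index map as implied by this diff; the earlier message slots are not shown in these hunks, so they are labeled only generically here.

    # Registration multipart message layout implied by this diff:
    # msg[0..4]  -> identification / bootstrap fields (not shown in these hunks)
    # msg[5]     -> packed dst_kv_ptrs,  read as struct.unpack(f"{len(msg[5]) // 8}Q", msg[5])
    # msg[6]     -> packed dst_aux_ptrs, read as struct.unpack(f"{len(msg[6]) // 8}Q", msg[6])
    # msg[7]     -> gpu_id          (ASCII int)
    # msg[8]     -> decode_tp_size  (ASCII int, new in this commit)
    # msg[9]     -> decode_tp_rank  (ASCII int; the receiver sends its engine_rank)
    # msg[10]    -> dst_kv_item_len (ASCII int; the receiver sends kv_item_lens[0])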