Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f0c503f6
Unverified
Commit
f0c503f6
authored
Sep 03, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Sep 03, 2025
Browse files
[Nixl] Heterogeneous TP support FlashInfer (#20189)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
f38035c1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
53 additions
and
9 deletions
+53
-9
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+53
-9
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
f0c503f6
...
@@ -715,7 +715,7 @@ class NixlConnectorWorker:
...
@@ -715,7 +715,7 @@ class NixlConnectorWorker:
# are non-contiguous (it's not locally guaranteed that they will be)
# are non-contiguous (it's not locally guaranteed that they will be)
# Disadvantage is that the encoded NixlAgentMetadata is now larger
# Disadvantage is that the encoded NixlAgentMetadata is now larger
# (roughly 8KB vs 5KB).
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are
transfer
red in the same
tensor
# Conversely for FlashInfer, K and V are
registe
red in the same
region
# to better exploit the memory layout (ie num_blocks is the first dim).
# to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v
=
not
(
self
.
use_mla
or
self
.
_use_pallas_v1
split_k_and_v
=
not
(
self
.
use_mla
or
self
.
_use_pallas_v1
or
self
.
_use_flashinfer
)
or
self
.
_use_flashinfer
)
...
@@ -758,12 +758,21 @@ class NixlConnectorWorker:
...
@@ -758,12 +758,21 @@ class NixlConnectorWorker:
assert
tensor_size_bytes
%
self
.
num_blocks
==
0
assert
tensor_size_bytes
%
self
.
num_blocks
==
0
self
.
block_len
=
tensor_size_bytes
//
self
.
num_blocks
self
.
block_len
=
tensor_size_bytes
//
self
.
num_blocks
self
.
slot_size_bytes
=
self
.
block_len
//
self
.
block_size
self
.
slot_size_bytes
=
self
.
block_len
//
self
.
block_size
self
.
device_kv_caches
=
kv_caches
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
if
self
.
_use_flashinfer
:
if
self
.
_use_flashinfer
:
assert
self
.
slot_size_bytes
%
2
==
0
assert
self
.
slot_size_bytes
%
2
==
0
self
.
slot_size_bytes
/=
2
self
.
slot_size_bytes
/=
2
self
.
device_kv_caches
=
kv_caches
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
# NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in
# registerMem allowing faster descs queries. In order to be able to
# split on kv_heads dim as required by heterogeneous TP, one must
# be able to index K/V separately. Hence the we double the number
# of 'virtual' regions here and halve `block_len` below.
self
.
num_regions
*=
2
kv_block_len
=
self
.
get_backend_aware_kv_block_len
()
# Register local/src descr for NIXL xfer.
# Register local/src descr for NIXL xfer.
blocks_data
=
[]
blocks_data
=
[]
for
base_addr
in
seen_base_addresses
:
for
base_addr
in
seen_base_addresses
:
...
@@ -776,8 +785,18 @@ class NixlConnectorWorker:
...
@@ -776,8 +785,18 @@ class NixlConnectorWorker:
block_offset
=
block_id
*
self
.
block_len
block_offset
=
block_id
*
self
.
block_len
addr
=
base_addr
+
block_offset
addr
=
base_addr
+
block_offset
# (addr, len, device id)
# (addr, len, device id)
# TODO: does device_id matter to DRAM?
blocks_data
.
append
((
addr
,
kv_block_len
,
self
.
tp_rank
))
blocks_data
.
append
((
addr
,
self
.
block_len
,
self
.
tp_rank
))
if
self
.
_use_flashinfer
:
# Separate and interleave K/V regions to maintain the same
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
for
block_id
in
range
(
self
.
num_blocks
):
block_offset
=
block_id
*
self
.
block_len
addr
=
base_addr
+
block_offset
# Register addresses for V cache (K registered first).
v_addr
=
addr
+
kv_block_len
blocks_data
.
append
((
v_addr
,
kv_block_len
,
self
.
tp_rank
))
logger
.
debug
(
"Created %s blocks for src engine %s and rank %s"
,
logger
.
debug
(
"Created %s blocks for src engine %s and rank %s"
,
len
(
blocks_data
),
self
.
engine_id
,
self
.
tp_rank
)
len
(
blocks_data
),
self
.
engine_id
,
self
.
tp_rank
)
...
@@ -903,7 +922,7 @@ class NixlConnectorWorker:
...
@@ -903,7 +922,7 @@ class NixlConnectorWorker:
remote_block_size
=
nixl_agent_meta
.
block_len
//
(
remote_block_size
=
nixl_agent_meta
.
block_len
//
(
self
.
slot_size_bytes
*
tp_ratio
)
self
.
slot_size_bytes
*
tp_ratio
)
if
self
.
_use_flashinfer
:
if
self
.
_use_flashinfer
:
#
Account for joint KV in FlashInfer
.
#
With flashinfer, KV are sent in the same message
.
remote_block_size
//=
2
remote_block_size
//=
2
if
tp_ratio
>
1
:
if
tp_ratio
>
1
:
# Heterogeneous TP expects same kv_cache_layout.
# Heterogeneous TP expects same kv_cache_layout.
...
@@ -929,10 +948,10 @@ class NixlConnectorWorker:
...
@@ -929,10 +948,10 @@ class NixlConnectorWorker:
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
# Only register the remote's descriptors if current rank pulls from it.
self
.
kv_caches_base_addr
[
self
.
kv_caches_base_addr
[
engine_id
]
=
nixl_agent_meta
.
kv_caches_base_addr
engine_id
]
=
nixl_agent_meta
.
kv_caches_base_addr
rank_offset
=
self
.
tp_rank
%
tp_ratio
*
self
.
block_len
\
kv_block_len
=
self
.
get_backend_aware_kv_block_len
()
rank_offset
=
self
.
tp_rank
%
tp_ratio
*
kv_block_len
\
if
not
(
self
.
use_mla
or
is_kv_replicated
)
else
0
if
not
(
self
.
use_mla
or
is_kv_replicated
)
else
0
# Register all remote blocks, but only the corresponding kv heads.
# Register all remote blocks, but only the corresponding kv heads.
for
base_addr
in
nixl_agent_meta
.
kv_caches_base_addr
:
for
base_addr
in
nixl_agent_meta
.
kv_caches_base_addr
:
...
@@ -943,7 +962,16 @@ class NixlConnectorWorker:
...
@@ -943,7 +962,16 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes.
# self.block_len == remote_block_len//tp_ratio bytes.
addr
=
base_addr
+
block_offset
+
rank_offset
addr
=
base_addr
+
block_offset
+
rank_offset
# (addr, len, device id)
# (addr, len, device id)
blocks_data
.
append
((
addr
,
self
.
block_len
,
remote_tp_rank
))
blocks_data
.
append
((
addr
,
kv_block_len
,
remote_tp_rank
))
if
self
.
_use_flashinfer
:
# With FlashInfer index V separately to allow head splitting.
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
block_offset
=
block_id
*
nixl_agent_meta
.
block_len
addr
=
base_addr
+
block_offset
+
rank_offset
v_addr
=
addr
+
nixl_agent_meta
.
block_len
//
2
blocks_data
.
append
((
v_addr
,
kv_block_len
,
remote_tp_rank
))
logger
.
debug
(
logger
.
debug
(
"Created %s blocks for dst engine %s with remote rank %s and "
"Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s"
,
len
(
blocks_data
),
engine_id
,
remote_tp_rank
,
"local rank %s"
,
len
(
blocks_data
),
engine_id
,
remote_tp_rank
,
...
@@ -1249,6 +1277,22 @@ class NixlConnectorWorker:
...
@@ -1249,6 +1277,22 @@ class NixlConnectorWorker:
descs_ids
.
append
(
reg_id
*
num_blocks
+
block_id
)
descs_ids
.
append
(
reg_id
*
num_blocks
+
block_id
)
return
descs_ids
return
descs_ids
def
get_backend_aware_kv_block_len
(
self
):
"""
Get the block length for one K/V element (K and V have the same size).
For FA and other backends, this is equal to the length of the whole
block, as K and V are in separate regions.
For FlashInfer, this is half the length of the whole block, as K and V
share the same region.
"""
if
self
.
_use_flashinfer
:
# For indexing only half (either just the K or V part).
block_len
=
self
.
block_len
//
2
else
:
block_len
=
self
.
block_len
return
block_len
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
zmq_ctx
(
socket_type
:
Any
,
addr
:
str
)
->
Iterator
[
zmq
.
Socket
]:
def
zmq_ctx
(
socket_type
:
Any
,
addr
:
str
)
->
Iterator
[
zmq
.
Socket
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment