Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a7ef3eb0
Unverified
Commit
a7ef3eb0
authored
Nov 11, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Nov 11, 2025
Browse files
[NIXL] Generalize block-first backend layouts (FlashInfer-like) (#28282)
parent
f9a40871
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
12 deletions
+52
-12
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+15
-2
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+37
-10
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
a7ef3eb0
...
@@ -1096,7 +1096,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
...
@@ -1096,7 +1096,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
llm
.
llm_engine
.
engine_core
.
shutdown
()
llm
.
llm_engine
.
engine_core
.
shutdown
()
def
test_register_kv_caches
(
dist_init
):
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
[
"FLASH_ATTN"
,
"TRITON_ATTN"
])
def
test_register_kv_caches
(
dist_init
,
attn_backend
,
monkeypatch
):
"""
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
correct data.
...
@@ -1108,10 +1109,22 @@ def test_register_kv_caches(dist_init):
...
@@ -1108,10 +1109,22 @@ def test_register_kv_caches(dist_init):
block layout info
block layout info
"""
"""
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
# Import the appropriate backend based on the parameter
if
attn_backend
==
"FLASH_ATTN"
:
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
backend_cls
=
FlashAttentionBackend
else
:
# TRITON_ATTN
from
vllm.v1.attention.backends.triton_attn
import
TritonAttentionBackend
backend_cls
=
TritonAttentionBackend
# Create test kv cache tensors using proper backend shape
# Create test kv cache tensors using proper backend shape
kv_cache_shape
=
FlashAttentionB
ackend
.
get_kv_cache_shape
(
kv_cache_shape
=
b
ackend
_cls
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
)
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
a7ef3eb0
...
@@ -21,6 +21,7 @@ import torch
...
@@ -21,6 +21,7 @@ import torch
import
zmq
import
zmq
from
vllm
import
envs
from
vllm
import
envs
from
vllm.attention
import
AttentionBackend
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -669,6 +670,33 @@ class NixlConnectorWorker:
...
@@ -669,6 +670,33 @@ class NixlConnectorWorker:
remote_tp_size
:
dict
[
EngineId
,
int
]
remote_tp_size
:
dict
[
EngineId
,
int
]
is_mla
:
bool
is_mla
:
bool
total_num_kv_heads
:
int
total_num_kv_heads
:
int
attn_backend
:
type
[
AttentionBackend
]
def
__post_init__
(
self
):
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
=
1
,
block_size
=
16
,
num_kv_heads
=
1
,
head_size
=
1
)
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
# we just mock num_blocks to 1 for the dimension check below.
self
.
_is_kv_layout_blocks_first
=
(
len
(
kv_cache_shape
)
==
5
and
kv_cache_shape
[
0
]
==
1
)
attn_backend
=
AttentionBackendEnum
[
self
.
attn_backend
.
get_name
()]
self
.
_use_pallas
=
attn_backend
==
AttentionBackendEnum
.
PALLAS
@
property
def
is_kv_layout_blocks_first
(
self
)
->
bool
:
return
self
.
_is_kv_layout_blocks_first
@
property
def
split_k_and_v
(
self
)
->
bool
:
# Whether to register regions for K and V separately (when present).
return
not
(
self
.
is_mla
or
self
.
_use_pallas
or
self
.
is_kv_layout_blocks_first
)
def
tp_ratio
(
def
tp_ratio
(
self
,
self
,
...
@@ -876,9 +904,6 @@ class NixlConnectorWorker:
...
@@ -876,9 +904,6 @@ class NixlConnectorWorker:
use_mla
=
self
.
use_mla
,
use_mla
=
self
.
use_mla
,
)
)
self
.
backend_name
=
backend
.
get_name
()
self
.
backend_name
=
backend
.
get_name
()
attn_backend
=
AttentionBackendEnum
[
self
.
backend_name
]
self
.
_use_flashinfer
=
attn_backend
==
AttentionBackendEnum
.
FLASHINFER
self
.
_use_pallas
=
attn_backend
==
AttentionBackendEnum
.
PALLAS
self
.
kv_cache_layout
=
get_kv_cache_layout
()
self
.
kv_cache_layout
=
get_kv_cache_layout
()
self
.
host_buffer_kv_cache_layout
=
self
.
kv_cache_layout
self
.
host_buffer_kv_cache_layout
=
self
.
kv_cache_layout
logger
.
debug
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
debug
(
"Detected attention backend %s"
,
self
.
backend_name
)
...
@@ -896,7 +921,9 @@ class NixlConnectorWorker:
...
@@ -896,7 +921,9 @@ class NixlConnectorWorker:
remote_tp_size
=
self
.
_tp_size
,
# shared state
remote_tp_size
=
self
.
_tp_size
,
# shared state
is_mla
=
self
.
use_mla
,
is_mla
=
self
.
use_mla
,
total_num_kv_heads
=
self
.
model_config
.
get_total_num_kv_heads
(),
total_num_kv_heads
=
self
.
model_config
.
get_total_num_kv_heads
(),
attn_backend
=
backend
,
)
)
self
.
_use_pallas
=
self
.
kv_topo
.
_use_pallas
def
_nixl_handshake
(
def
_nixl_handshake
(
self
,
self
,
...
@@ -1076,7 +1103,7 @@ class NixlConnectorWorker:
...
@@ -1076,7 +1103,7 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are registered in the same region
# Conversely for FlashInfer, K and V are registered in the same region
# to better exploit the memory layout (ie num_blocks is the first dim).
# to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v
=
not
(
self
.
use_mla
or
self
.
_use_pallas
or
self
.
_use_flashinfer
)
split_k_and_v
=
self
.
kv_topo
.
split_k_and_v
tensor_size_bytes
=
None
tensor_size_bytes
=
None
# Enable different block lengths for different layers when MLA is used.
# Enable different block lengths for different layers when MLA is used.
self
.
block_len_per_layer
=
list
[
int
]()
self
.
block_len_per_layer
=
list
[
int
]()
...
@@ -1141,7 +1168,7 @@ class NixlConnectorWorker:
...
@@ -1141,7 +1168,7 @@ class NixlConnectorWorker:
self
.
device_kv_caches
=
kv_caches
self
.
device_kv_caches
=
kv_caches
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
if
self
.
_use_flashinfer
:
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
for
i
in
range
(
len
(
self
.
slot_size_per_layer
)):
for
i
in
range
(
len
(
self
.
slot_size_per_layer
)):
assert
self
.
slot_size_per_layer
[
i
]
%
2
==
0
assert
self
.
slot_size_per_layer
[
i
]
%
2
==
0
self
.
slot_size_per_layer
[
i
]
//=
2
self
.
slot_size_per_layer
[
i
]
//=
2
...
@@ -1169,7 +1196,7 @@ class NixlConnectorWorker:
...
@@ -1169,7 +1196,7 @@ class NixlConnectorWorker:
# (addr, len, device id)
# (addr, len, device id)
blocks_data
.
append
((
addr
,
kv_block_len
,
self
.
device_id
))
blocks_data
.
append
((
addr
,
kv_block_len
,
self
.
device_id
))
if
self
.
_use_flashinfer
:
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
# Separate and interleave K/V regions to maintain the same
# Separate and interleave K/V regions to maintain the same
# descs ordering. This is needed for selecting contiguous heads
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
# when split across TP ranks.
...
@@ -1331,7 +1358,7 @@ class NixlConnectorWorker:
...
@@ -1331,7 +1358,7 @@ class NixlConnectorWorker:
# (addr, len, device id)
# (addr, len, device id)
blocks_data
.
append
((
addr
,
kv_block_len
,
nixl_agent_meta
.
device_id
))
blocks_data
.
append
((
addr
,
kv_block_len
,
nixl_agent_meta
.
device_id
))
if
self
.
_use_flashinfer
:
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
# With FlashInfer index V separately to allow head splitting.
# With FlashInfer index V separately to allow head splitting.
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
block_offset
=
block_id
*
nixl_agent_meta
.
block_lens
[
i
]
block_offset
=
block_id
*
nixl_agent_meta
.
block_lens
[
i
]
...
@@ -1414,7 +1441,7 @@ class NixlConnectorWorker:
...
@@ -1414,7 +1441,7 @@ class NixlConnectorWorker:
remote_block_size
=
remote_block_len
//
(
remote_block_size
=
remote_block_len
//
(
self
.
slot_size_per_layer
[
0
]
*
tp_ratio
self
.
slot_size_per_layer
[
0
]
*
tp_ratio
)
)
if
self
.
_use_flashinfer
:
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
# With flashinfer, KV are sent in the same message.
# With flashinfer, KV are sent in the same message.
remote_block_size
//=
2
remote_block_size
//=
2
...
@@ -1494,7 +1521,7 @@ class NixlConnectorWorker:
...
@@ -1494,7 +1521,7 @@ class NixlConnectorWorker:
- cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
- cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
"""
"""
split_k_and_v
=
not
(
self
.
use_mla
or
self
.
_use_pallas
or
self
.
_use_flashinfer
)
split_k_and_v
=
self
.
kv_topo
.
split_k_and_v
inv_order
=
[
0
,
2
,
1
,
3
]
inv_order
=
[
0
,
2
,
1
,
3
]
sample_cache
=
list
(
self
.
device_kv_caches
.
values
())[
0
][
0
]
sample_cache
=
list
(
self
.
device_kv_caches
.
values
())[
0
][
0
]
target_shape
=
list
(
sample_cache
.
shape
)
target_shape
=
list
(
sample_cache
.
shape
)
...
@@ -1874,7 +1901,7 @@ class NixlConnectorWorker:
...
@@ -1874,7 +1901,7 @@ class NixlConnectorWorker:
For FlashInfer, this is half the length of the whole block, as K and V
For FlashInfer, this is half the length of the whole block, as K and V
share the same region.
share the same region.
"""
"""
if
self
.
_use_flashinfer
:
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
# For indexing only half (either just the K or V part).
# For indexing only half (either just the K or V part).
block_len
=
self
.
block_len_per_layer
[
layer_idx
]
//
2
block_len
=
self
.
block_len_per_layer
[
layer_idx
]
//
2
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment