Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
53415653
Unverified
Commit
53415653
authored
Aug 21, 2025
by
Flora Feng
Committed by
GitHub
Aug 21, 2025
Browse files
[P/D][Nixl] Make kv cache register compatible with hybrid memory allocator (#23079)
Signed-off-by:
sfeng33
<
4florafeng@gmail.com
>
parent
17373dcd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
150 additions
and
95 deletions
+150
-95
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+85
-1
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+2
-2
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+63
-92
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
53415653
...
...
@@ -14,6 +14,7 @@ from unittest.mock import patch
import
pytest
import
ray
import
torch
from
vllm
import
LLM
from
vllm.config
import
KVTransferConfig
...
...
@@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorWorker
)
from
vllm.forward_context
import
ForwardContext
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
...
...
@@ -98,7 +100,6 @@ class FakeNixlWrapper:
def
set_cycles_before_xfer_done
(
self
,
cycles
:
int
):
"""Set the number of cycles before a transfer is considered done."""
self
.
_cycles_before_xfer_done
=
cycles
@
contextlib
.
contextmanager
...
...
@@ -562,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
sampling_params
)
# Request-0 times out and is cleared!
assert
'0'
not
in
req_to_blocks
def
test_register_kv_caches
(
dist_init
):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
This test verifies:
1. nixl_wrapper.get_reg_descs() is called with caches_data containing
tensor metadata
2. nixl_wrapper.get_xfer_descs() is called with blocks_data containing
block layout info
"""
vllm_config
=
create_vllm_config
()
# Create test kv cache tensors using proper backend shape
kv_cache_shape
=
FlashAttentionBackend
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
unique_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
kv_caches
=
{
"layer0"
:
shared_tensor
,
"layer1"
:
unique_tensor
,
"layer2"
:
shared_tensor
,
}
# Store tensor info for validation
expected_tensor_size
=
shared_tensor
[
0
].
element_size
(
)
*
shared_tensor
[
0
].
numel
()
expected_base_addrs
=
[
shared_tensor
[
0
].
data_ptr
(),
shared_tensor
[
1
].
data_ptr
(),
unique_tensor
[
0
].
data_ptr
(),
unique_tensor
[
1
].
data_ptr
()
]
with
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
)
as
mock_nixl_wrapper
,
\
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"
),
\
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"
):
# noqa: E501
# Create connector
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
# Get the mock instance
mock_wrapper_instance
=
mock_nixl_wrapper
.
return_value
connector
.
connector_worker
.
nixl_wrapper
=
mock_wrapper_instance
# Execute register_kv_caches
connector
.
register_kv_caches
(
kv_caches
)
# Verify get_reg_descs was called with caches_data
assert
mock_wrapper_instance
.
get_reg_descs
.
called
caches_data
,
_
=
mock_wrapper_instance
.
get_reg_descs
.
call_args
[
0
]
assert
len
(
caches_data
)
==
4
for
i
,
cache_entry
in
enumerate
(
caches_data
):
base_addr
,
size
,
_tp_rank
,
_
=
cache_entry
assert
size
==
expected_tensor_size
,
\
f
"Entry
{
i
}
: Expected tensor size
{
expected_tensor_size
}
, "
\
f
"got
{
size
}
"
assert
base_addr
==
expected_base_addrs
[
i
],
\
f
"Entry
{
i
}
: Expected base address
{
expected_base_addrs
[
i
]
}
, "
\
f
"got
{
base_addr
}
"
# Verify get_xfer_descs was called with blocks_data
assert
mock_wrapper_instance
.
get_xfer_descs
.
called
blocks_data
,
_
=
mock_wrapper_instance
.
get_xfer_descs
.
call_args
[
0
]
# Validate blocks_data structure and size
expected_blocks_count
=
8
assert
len
(
blocks_data
)
==
expected_blocks_count
,
\
f
"Expected
{
expected_blocks_count
}
blocks, "
\
f
"got
{
len
(
blocks_data
)
}
"
expected_block_len
=
expected_tensor_size
//
2
for
i
,
block_entry
in
enumerate
(
blocks_data
):
block_start_addr
,
block_len
,
tp_rank
=
block_entry
assert
block_len
==
expected_block_len
,
\
f
"Block entry
{
i
}
: Expected block len
{
expected_block_len
}
, "
\
f
"got
{
block_len
}
"
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
53415653
...
...
@@ -131,8 +131,8 @@ class KVConnectorBase_V1(ABC):
Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
Args:
kv_caches:
dictionary of layer names, kv cache
Args:
kv_caches:
dictionary of layer names, kv cache
"""
return
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
53415653
...
...
@@ -686,9 +686,6 @@ class NixlConnectorWorker:
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
"""Register the KV Cache data in nixl."""
_
,
first_kv_cache
=
next
(
iter
(
kv_caches
.
items
()))
kv_elem_size
=
first_kv_cache
.
element_size
()
if
self
.
use_host_buffer
:
self
.
initialize_host_xfer_buffer
(
kv_caches
=
kv_caches
)
assert
len
(
self
.
host_xfer_buffers
)
==
len
(
kv_caches
),
(
...
...
@@ -701,66 +698,16 @@ class NixlConnectorWorker:
"host_xfer_buffer should not be initialized when "
f
"kv_buffer_device is
{
self
.
kv_buffer_device
}
"
)
# TODO(tms): Find a more robust way to detect and handle MLA
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
# will only affects the strides. For MLA instead, we make require no
# such thing and resort to the standard layout.
use_mla
=
len
(
first_kv_cache
.
shape
)
==
3
if
self
.
device_type
==
"tpu"
:
assert
not
use_mla
,
f
"
{
self
.
kv_buffer_device
}
does not support MLA."
assert
self
.
_use_pallas_v1
,
f
"attn backend:
{
self
.
backend_name
}
"
# tpu (v1) kv shape per layer:
# (num_blocks, block_size, num_kv_heads * 2, head_size)
self
.
num_blocks
=
first_kv_cache
.
shape
[
0
]
block_rank
=
3
# [block_size, kv_heads, head_dim]
block_shape
=
first_kv_cache
.
shape
[
-
block_rank
:]
block_size
,
n_kv_heads_x_2
,
head_dim
=
block_shape
self
.
slot_size_bytes
=
kv_elem_size
*
n_kv_heads_x_2
*
head_dim
elif
self
.
device_type
==
"cuda"
:
assert
use_mla
==
self
.
use_mla
# TODO (NickLucche) not compatible with hybrid allocator.
# Enforce check once it goes live, as a single kv layout
# is expected for xfers.
if
use_mla
:
# MLA case.
self
.
num_blocks
=
first_kv_cache
.
shape
[
0
]
block_rank
=
2
# [block_size, latent_dim]
block_shape
=
first_kv_cache
.
shape
[
-
block_rank
:]
block_size
,
kv_latent_dim
=
block_shape
self
.
slot_size_bytes
=
kv_elem_size
*
kv_latent_dim
else
:
# [2 (k and v), num_blocks, ...]
if
self
.
_use_flashinfer
:
# FlashInfer swaps 2<->num_blocks dimensions.
self
.
num_blocks
=
first_kv_cache
.
shape
[
0
]
block_rank
=
4
# [2, block_size, kv_heads, head_dim]
else
:
self
.
num_blocks
=
first_kv_cache
.
shape
[
1
]
block_rank
=
3
# [block_size, kv_heads, head_dim]
block_shape
=
first_kv_cache
.
shape
[
-
block_rank
:]
block_size
,
n_kv_heads
,
head_dim
=
block_shape
[
-
3
:]
# head size in bytes.
self
.
slot_size_bytes
=
kv_elem_size
*
n_kv_heads
*
head_dim
assert
block_size
==
self
.
block_size
else
:
raise
RuntimeError
(
f
"
{
self
.
device_type
}
(
{
self
.
backend_name
}
) is not supported."
)
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self
.
block_len
=
kv_elem_size
*
math
.
prod
(
block_shape
)
logger
.
info
(
"Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
"use_host_buffer: %s, num_blocks: %s, block_shape: %s, "
"per_layer_kv_cache_shape: %s"
,
use_mla
,
self
.
kv_buffer_device
,
self
.
use_host_buffer
,
self
.
num_blocks
,
block_shape
,
first_kv_cache
.
shape
)
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
self
.
device_kv_caches
=
kv_caches
kv_caches_base_addr
=
[]
"use_host_buffer: %s"
,
self
.
use_mla
,
self
.
kv_buffer_device
,
self
.
use_host_buffer
)
caches_data
=
[]
# With hybrid allocator, layers can share a kv cache tensor
seen_base_addresses
=
[]
xfer_buffers
=
(
self
.
host_xfer_buffers
if
self
.
use_host_buffer
else
kv_caches
)
# Note(tms): I modified this from the original region setup code.
# K and V are now in different regions. Advantage is that we can
...
...
@@ -770,42 +717,35 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).
for
cache_or_caches
in
xfer_buffers
.
values
():
# Normalize to always be a list of caches
cache_list
=
[
cache_or_caches
]
if
use_mla
\
or
self
.
_use_pallas_v1
or
self
.
_use_flashinfer
\
else
cache_or_caches
split_k_and_v
=
not
(
self
.
use_mla
or
self
.
_use_pallas_v1
or
self
.
_use_flashinfer
)
tensor_size_bytes
=
None
for
layer_name
,
cache_or_caches
in
xfer_buffers
.
items
():
cache_list
=
cache_or_caches
if
split_k_and_v
else
[
cache_or_caches
]
for
cache
in
cache_list
:
base_addr
=
cache
.
data_ptr
()
region_len
=
self
.
num_blocks
*
self
.
block_len
# NOTE: use tp_rank for device_id since multi-node TP
# is rarely used.
caches_data
.
append
((
base_addr
,
region_len
,
self
.
tp_rank
,
""
))
kv_caches_base_addr
.
append
(
base_addr
)
self
.
kv_caches_base_addr
[
self
.
engine_id
]
=
kv_caches_base_addr
if
base_addr
in
seen_base_addresses
:
continue
seen_base_addresses
.
append
(
base_addr
)
curr_tensor_size_bytes
=
cache
.
numel
()
*
cache
.
element_size
()
if
tensor_size_bytes
is
None
:
tensor_size_bytes
=
curr_tensor_size_bytes
self
.
num_blocks
=
cache
.
shape
[
0
]
assert
tensor_size_bytes
==
curr_tensor_size_bytes
,
\
"All kv cache tensors must have the same size"
caches_data
.
append
(
(
base_addr
,
tensor_size_bytes
,
self
.
tp_rank
,
""
))
self
.
kv_caches_base_addr
[
self
.
engine_id
]
=
seen_base_addresses
self
.
num_regions
=
len
(
caches_data
)
self
.
num_layers
=
len
(
xfer_buffers
.
keys
())
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
if
self
.
vllm_config
.
model_config
.
hf_config
.
model_type
==
"llama4"
:
from
transformers
import
Llama4TextConfig
assert
isinstance
(
self
.
vllm_config
.
model_config
.
hf_text_config
,
Llama4TextConfig
)
llama4_config
=
self
.
vllm_config
.
model_config
.
hf_text_config
no_rope_layers
=
llama4_config
.
no_rope_layers
chunk_size
=
llama4_config
.
attention_chunk_size
chunk_block_size
=
math
.
ceil
(
chunk_size
/
self
.
block_size
)
for
layer_idx
in
range
(
self
.
num_layers
):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention
=
no_rope_layers
[
layer_idx
]
!=
0
block_window
=
chunk_block_size
if
is_local_attention
else
None
self
.
block_window_per_layer
.
append
(
block_window
)
logger
.
debug
(
"Llama 4 block window per layer mapping: %s"
,
self
.
block_window_per_layer
)
assert
len
(
self
.
block_window_per_layer
)
==
self
.
num_layers
descs
=
self
.
nixl_wrapper
.
get_reg_descs
(
caches_data
,
self
.
nixl_memory_type
)
logger
.
debug
(
"Registering descs: %s"
,
caches_data
)
...
...
@@ -813,9 +753,20 @@ class NixlConnectorWorker:
logger
.
debug
(
"Done registering descs"
)
self
.
_registered_descs
.
append
(
descs
)
assert
tensor_size_bytes
is
not
None
assert
self
.
num_blocks
!=
0
assert
tensor_size_bytes
%
self
.
num_blocks
==
0
self
.
block_len
=
tensor_size_bytes
//
self
.
num_blocks
self
.
slot_size_bytes
=
self
.
block_len
//
self
.
block_size
if
self
.
_use_flashinfer
:
assert
self
.
slot_size_bytes
%
2
==
0
self
.
slot_size_bytes
/=
2
self
.
device_kv_caches
=
kv_caches
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
# Register local/src descr for NIXL xfer.
blocks_data
=
[]
for
base_addr
in
se
lf
.
kv_caches_base_addr
[
self
.
engine_id
]
:
for
base_addr
in
se
en_base_addresses
:
# NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to
...
...
@@ -836,6 +787,26 @@ class NixlConnectorWorker:
self
.
src_xfer_side_handle
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
# TODO(mgoin): Hybrid memory allocator is currently diabled for
# models with local attention (Llama 4). Can remove this once enabled.
if
self
.
vllm_config
.
model_config
.
hf_config
.
model_type
==
"llama4"
:
from
transformers
import
Llama4TextConfig
assert
isinstance
(
self
.
vllm_config
.
model_config
.
hf_text_config
,
Llama4TextConfig
)
llama4_config
=
self
.
vllm_config
.
model_config
.
hf_text_config
no_rope_layers
=
llama4_config
.
no_rope_layers
chunk_size
=
llama4_config
.
attention_chunk_size
chunk_block_size
=
math
.
ceil
(
chunk_size
/
self
.
block_size
)
for
layer_idx
in
range
(
self
.
num_layers
):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention
=
no_rope_layers
[
layer_idx
]
!=
0
block_window
=
chunk_block_size
if
is_local_attention
else
None
self
.
block_window_per_layer
.
append
(
block_window
)
logger
.
debug
(
"Llama 4 block window per layer mapping: %s"
,
self
.
block_window_per_layer
)
assert
len
(
self
.
block_window_per_layer
)
==
self
.
num_layers
# After KV Caches registered, listen for new connections.
metadata
=
NixlAgentMetadata
(
engine_id
=
self
.
engine_id
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment