Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
53415653
Unverified
Commit
53415653
authored
Aug 21, 2025
by
Flora Feng
Committed by
GitHub
Aug 21, 2025
Browse files
[P/D][Nixl] Make kv cache register compatible with hybrid memory allocator (#23079)
Signed-off-by:
sfeng33
<
4florafeng@gmail.com
>
parent
17373dcd
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
150 additions
and
95 deletions
+150
-95
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+85
-1
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+2
-2
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+63
-92
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
53415653
...
@@ -14,6 +14,7 @@ from unittest.mock import patch
...
@@ -14,6 +14,7 @@ from unittest.mock import patch
import
pytest
import
pytest
import
ray
import
ray
import
torch
from
vllm
import
LLM
from
vllm
import
LLM
from
vllm.config
import
KVTransferConfig
from
vllm.config
import
KVTransferConfig
...
@@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
...
@@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorWorker
)
NixlConnectorWorker
)
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
...
@@ -98,7 +100,6 @@ class FakeNixlWrapper:
...
@@ -98,7 +100,6 @@ class FakeNixlWrapper:
def
set_cycles_before_xfer_done
(
self
,
cycles
:
int
):
def
set_cycles_before_xfer_done
(
self
,
cycles
:
int
):
"""Set the number of cycles before a transfer is considered done."""
"""Set the number of cycles before a transfer is considered done."""
self
.
_cycles_before_xfer_done
=
cycles
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
...
@@ -562,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
...
@@ -562,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
sampling_params
)
sampling_params
)
# Request-0 times out and is cleared!
# Request-0 times out and is cleared!
assert
'0'
not
in
req_to_blocks
assert
'0'
not
in
req_to_blocks
def test_register_kv_caches(dist_init):
    """
    Test that register_kv_caches() properly calls nixl_wrapper methods with
    correct data.

    This test verifies:
    1. nixl_wrapper.get_reg_descs() is called with caches_data containing
       tensor metadata
    2. nixl_wrapper.get_xfer_descs() is called with blocks_data containing
       block layout info

    "layer0" and "layer2" deliberately map to the same tensor object, so the
    expected descriptors cover only the two distinct tensors.

    Args:
        dist_init: pytest fixture; its value is never read in this test
            (presumably requested for its setup side effects - confirm
            against the fixture definition).
    """
    vllm_config = create_vllm_config()

    # Geometry of the fake KV cache. The expected counts and sizes below are
    # derived from these values instead of being repeated as bare literals.
    num_blocks = 2

    # Create test kv cache tensors using proper backend shape
    kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
        num_blocks=num_blocks, block_size=16, num_kv_heads=4, head_size=64)
    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,  # same object as layer0 on purpose
    }

    # Store tensor info for validation.
    # Each distinct tensor contributes one region per index along dim 0
    # (presumably the K and V planes - confirm against the backend layout).
    expected_tensor_size = (shared_tensor[0].element_size() *
                            shared_tensor[0].numel())
    expected_base_addrs = [
        shared_tensor[0].data_ptr(),
        shared_tensor[1].data_ptr(),
        unique_tensor[0].data_ptr(),
        unique_tensor[1].data_ptr(),
    ]
    expected_region_count = len(expected_base_addrs)

    with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \
         patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
         patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"):  # noqa: E501

        # Create connector
        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeNixlConnectorWorker(
            vllm_config, connector.engine_id, hand_shake_latency=0)

        # Get the mock instance and make the worker talk to it
        mock_wrapper_instance = mock_nixl_wrapper.return_value
        connector.connector_worker.nixl_wrapper = mock_wrapper_instance

        # Execute register_kv_caches
        connector.register_kv_caches(kv_caches)

        # Verify get_reg_descs was called with caches_data
        assert mock_wrapper_instance.get_reg_descs.called
        caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0]
        assert len(caches_data) == expected_region_count

        for i, cache_entry in enumerate(caches_data):
            base_addr, size, _tp_rank, _ = cache_entry
            assert size == expected_tensor_size, \
                f"Entry {i}: Expected tensor size {expected_tensor_size}, " \
                f"got {size}"
            assert base_addr == expected_base_addrs[i], \
                f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \
                f"got {base_addr}"

        # Verify get_xfer_descs was called with blocks_data
        assert mock_wrapper_instance.get_xfer_descs.called
        blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0]

        # Validate blocks_data structure and size: one descriptor per block
        # of every registered region.
        expected_blocks_count = expected_region_count * num_blocks
        assert len(blocks_data) == expected_blocks_count, \
            f"Expected {expected_blocks_count} blocks, " \
            f"got {len(blocks_data)}"

        expected_block_len = expected_tensor_size // num_blocks
        for i, block_entry in enumerate(blocks_data):
            # Unpacking still checks that every entry is a 3-tuple; only the
            # length field is asserted on.
            _block_start_addr, block_len, _block_tp_rank = block_entry
            assert block_len == expected_block_len, \
                f"Block entry {i}: Expected block len {expected_block_len}, " \
                f"got {block_len}"
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
53415653
...
@@ -131,8 +131,8 @@ class KVConnectorBase_V1(ABC):
...
@@ -131,8 +131,8 @@ class KVConnectorBase_V1(ABC):
Initialize with the KV caches. Useful for pre-registering the
Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
KV Caches in the KVConnector (e.g. for NIXL).
Args:
kv_caches:
Args:
dictionary of layer names, kv cache
kv_caches:
dictionary of layer names, kv cache
"""
"""
return
return
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
53415653
...
@@ -686,9 +686,6 @@ class NixlConnectorWorker:
...
@@ -686,9 +686,6 @@ class NixlConnectorWorker:
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
"""Register the KV Cache data in nixl."""
"""Register the KV Cache data in nixl."""
_
,
first_kv_cache
=
next
(
iter
(
kv_caches
.
items
()))
kv_elem_size
=
first_kv_cache
.
element_size
()
if
self
.
use_host_buffer
:
if
self
.
use_host_buffer
:
self
.
initialize_host_xfer_buffer
(
kv_caches
=
kv_caches
)
self
.
initialize_host_xfer_buffer
(
kv_caches
=
kv_caches
)
assert
len
(
self
.
host_xfer_buffers
)
==
len
(
kv_caches
),
(
assert
len
(
self
.
host_xfer_buffers
)
==
len
(
kv_caches
),
(
...
@@ -701,66 +698,16 @@ class NixlConnectorWorker:
...
@@ -701,66 +698,16 @@ class NixlConnectorWorker:
"host_xfer_buffer should not be initialized when "
"host_xfer_buffer should not be initialized when "
f
"kv_buffer_device is
{
self
.
kv_buffer_device
}
"
)
f
"kv_buffer_device is
{
self
.
kv_buffer_device
}
"
)
# TODO(tms): Find a more robust way to detect and handle MLA
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
# only affects the strides. For MLA, we instead require no
# such thing and resort to the standard layout.
use_mla
=
len
(
first_kv_cache
.
shape
)
==
3
if
self
.
device_type
==
"tpu"
:
assert
not
use_mla
,
f
"
{
self
.
kv_buffer_device
}
does not support MLA."
assert
self
.
_use_pallas_v1
,
f
"attn backend:
{
self
.
backend_name
}
"
# tpu (v1) kv shape per layer:
# (num_blocks, block_size, num_kv_heads * 2, head_size)
self
.
num_blocks
=
first_kv_cache
.
shape
[
0
]
block_rank
=
3
# [block_size, kv_heads, head_dim]
block_shape
=
first_kv_cache
.
shape
[
-
block_rank
:]
block_size
,
n_kv_heads_x_2
,
head_dim
=
block_shape
self
.
slot_size_bytes
=
kv_elem_size
*
n_kv_heads_x_2
*
head_dim
elif
self
.
device_type
==
"cuda"
:
assert
use_mla
==
self
.
use_mla
# TODO (NickLucche) not compatible with hybrid allocator.
# Enforce check once it goes live, as a single kv layout
# is expected for xfers.
if
use_mla
:
# MLA case.
self
.
num_blocks
=
first_kv_cache
.
shape
[
0
]
block_rank
=
2
# [block_size, latent_dim]
block_shape
=
first_kv_cache
.
shape
[
-
block_rank
:]
block_size
,
kv_latent_dim
=
block_shape
self
.
slot_size_bytes
=
kv_elem_size
*
kv_latent_dim
else
:
# [2 (k and v), num_blocks, ...]
if
self
.
_use_flashinfer
:
# FlashInfer swaps 2<->num_blocks dimensions.
self
.
num_blocks
=
first_kv_cache
.
shape
[
0
]
block_rank
=
4
# [2, block_size, kv_heads, head_dim]
else
:
self
.
num_blocks
=
first_kv_cache
.
shape
[
1
]
block_rank
=
3
# [block_size, kv_heads, head_dim]
block_shape
=
first_kv_cache
.
shape
[
-
block_rank
:]
block_size
,
n_kv_heads
,
head_dim
=
block_shape
[
-
3
:]
# head size in bytes.
self
.
slot_size_bytes
=
kv_elem_size
*
n_kv_heads
*
head_dim
assert
block_size
==
self
.
block_size
else
:
raise
RuntimeError
(
f
"
{
self
.
device_type
}
(
{
self
.
backend_name
}
) is not supported."
)
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self
.
block_len
=
kv_elem_size
*
math
.
prod
(
block_shape
)
logger
.
info
(
logger
.
info
(
"Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
"Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
"use_host_buffer: %s, num_blocks: %s, block_shape: %s, "
"use_host_buffer: %s"
,
self
.
use_mla
,
self
.
kv_buffer_device
,
"per_layer_kv_cache_shape: %s"
,
use_mla
,
self
.
kv_buffer_device
,
self
.
use_host_buffer
)
self
.
use_host_buffer
,
self
.
num_blocks
,
block_shape
,
first_kv_cache
.
shape
)
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
self
.
device_kv_caches
=
kv_caches
kv_caches_base_addr
=
[]
caches_data
=
[]
caches_data
=
[]
# With hybrid allocator, layers can share a kv cache tensor
seen_base_addresses
=
[]
xfer_buffers
=
(
self
.
host_xfer_buffers
if
self
.
use_host_buffer
else
kv_caches
)
# Note(tms): I modified this from the original region setup code.
# Note(tms): I modified this from the original region setup code.
# K and V are now in different regions. Advantage is that we can
# K and V are now in different regions. Advantage is that we can
...
@@ -770,42 +717,35 @@ class NixlConnectorWorker:
...
@@ -770,42 +717,35 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).
# to better exploit the memory layout (ie num_blocks is the first dim).
for
cache_or_caches
in
xfer_buffers
.
values
():
split_k_and_v
=
not
(
self
.
use_mla
or
self
.
_use_pallas_v1
# Normalize to always be a list of caches
or
self
.
_use_flashinfer
)
cache_list
=
[
cache_or_caches
]
if
use_mla
\
tensor_size_bytes
=
None
or
self
.
_use_pallas_v1
or
self
.
_use_flashinfer
\
for
layer_name
,
cache_or_caches
in
xfer_buffers
.
items
():
else
cache_or_caches
cache_list
=
cache_or_caches
if
split_k_and_v
else
[
cache_or_caches
]
for
cache
in
cache_list
:
for
cache
in
cache_list
:
base_addr
=
cache
.
data_ptr
()
base_addr
=
cache
.
data_ptr
()
region_len
=
self
.
num_blocks
*
self
.
block_len
if
base_addr
in
seen_base_addresses
:
# NOTE: use tp_rank for device_id since multi-node TP
continue
# is rarely used.
caches_data
.
append
((
base_addr
,
region_len
,
self
.
tp_rank
,
""
))
seen_base_addresses
.
append
(
base_addr
)
kv_caches_base_addr
.
append
(
base_addr
)
curr_tensor_size_bytes
=
cache
.
numel
()
*
cache
.
element_size
()
self
.
kv_caches_base_addr
[
self
.
engine_id
]
=
kv_caches_base_addr
if
tensor_size_bytes
is
None
:
tensor_size_bytes
=
curr_tensor_size_bytes
self
.
num_blocks
=
cache
.
shape
[
0
]
assert
tensor_size_bytes
==
curr_tensor_size_bytes
,
\
"All kv cache tensors must have the same size"
caches_data
.
append
(
(
base_addr
,
tensor_size_bytes
,
self
.
tp_rank
,
""
))
self
.
kv_caches_base_addr
[
self
.
engine_id
]
=
seen_base_addresses
self
.
num_regions
=
len
(
caches_data
)
self
.
num_regions
=
len
(
caches_data
)
self
.
num_layers
=
len
(
xfer_buffers
.
keys
())
self
.
num_layers
=
len
(
xfer_buffers
.
keys
())
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
if
self
.
vllm_config
.
model_config
.
hf_config
.
model_type
==
"llama4"
:
from
transformers
import
Llama4TextConfig
assert
isinstance
(
self
.
vllm_config
.
model_config
.
hf_text_config
,
Llama4TextConfig
)
llama4_config
=
self
.
vllm_config
.
model_config
.
hf_text_config
no_rope_layers
=
llama4_config
.
no_rope_layers
chunk_size
=
llama4_config
.
attention_chunk_size
chunk_block_size
=
math
.
ceil
(
chunk_size
/
self
.
block_size
)
for
layer_idx
in
range
(
self
.
num_layers
):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention
=
no_rope_layers
[
layer_idx
]
!=
0
block_window
=
chunk_block_size
if
is_local_attention
else
None
self
.
block_window_per_layer
.
append
(
block_window
)
logger
.
debug
(
"Llama 4 block window per layer mapping: %s"
,
self
.
block_window_per_layer
)
assert
len
(
self
.
block_window_per_layer
)
==
self
.
num_layers
descs
=
self
.
nixl_wrapper
.
get_reg_descs
(
caches_data
,
descs
=
self
.
nixl_wrapper
.
get_reg_descs
(
caches_data
,
self
.
nixl_memory_type
)
self
.
nixl_memory_type
)
logger
.
debug
(
"Registering descs: %s"
,
caches_data
)
logger
.
debug
(
"Registering descs: %s"
,
caches_data
)
...
@@ -813,9 +753,20 @@ class NixlConnectorWorker:
...
@@ -813,9 +753,20 @@ class NixlConnectorWorker:
logger
.
debug
(
"Done registering descs"
)
logger
.
debug
(
"Done registering descs"
)
self
.
_registered_descs
.
append
(
descs
)
self
.
_registered_descs
.
append
(
descs
)
assert
tensor_size_bytes
is
not
None
assert
self
.
num_blocks
!=
0
assert
tensor_size_bytes
%
self
.
num_blocks
==
0
self
.
block_len
=
tensor_size_bytes
//
self
.
num_blocks
self
.
slot_size_bytes
=
self
.
block_len
//
self
.
block_size
if
self
.
_use_flashinfer
:
assert
self
.
slot_size_bytes
%
2
==
0
self
.
slot_size_bytes
/=
2
self
.
device_kv_caches
=
kv_caches
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
# Register local/src descr for NIXL xfer.
# Register local/src descr for NIXL xfer.
blocks_data
=
[]
blocks_data
=
[]
for
base_addr
in
se
lf
.
kv_caches_base_addr
[
self
.
engine_id
]
:
for
base_addr
in
se
en_base_addresses
:
# NOTE With heter-TP, more blocks are prepared than what are
# NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to
# could create fewer, but then _get_block_descs_ids needs to
...
@@ -836,6 +787,26 @@ class NixlConnectorWorker:
...
@@ -836,6 +787,26 @@ class NixlConnectorWorker:
self
.
src_xfer_side_handle
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
self
.
src_xfer_side_handle
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
"NIXL_INIT_AGENT"
,
descs
)
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
if
self
.
vllm_config
.
model_config
.
hf_config
.
model_type
==
"llama4"
:
from
transformers
import
Llama4TextConfig
assert
isinstance
(
self
.
vllm_config
.
model_config
.
hf_text_config
,
Llama4TextConfig
)
llama4_config
=
self
.
vllm_config
.
model_config
.
hf_text_config
no_rope_layers
=
llama4_config
.
no_rope_layers
chunk_size
=
llama4_config
.
attention_chunk_size
chunk_block_size
=
math
.
ceil
(
chunk_size
/
self
.
block_size
)
for
layer_idx
in
range
(
self
.
num_layers
):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention
=
no_rope_layers
[
layer_idx
]
!=
0
block_window
=
chunk_block_size
if
is_local_attention
else
None
self
.
block_window_per_layer
.
append
(
block_window
)
logger
.
debug
(
"Llama 4 block window per layer mapping: %s"
,
self
.
block_window_per_layer
)
assert
len
(
self
.
block_window_per_layer
)
==
self
.
num_layers
# After KV Caches registered, listen for new connections.
# After KV Caches registered, listen for new connections.
metadata
=
NixlAgentMetadata
(
metadata
=
NixlAgentMetadata
(
engine_id
=
self
.
engine_id
,
engine_id
=
self
.
engine_id
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment