Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
098d8447
Unverified
Commit
098d8447
authored
Mar 11, 2026
by
Nicolò Lucchesi
Committed by
GitHub
Mar 11, 2026
Browse files
[NIXL][1/N] Refactor `kernel_block_size` detection (#35752)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
a40ee486
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
126 additions
and
96 deletions
+126
-96
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+57
-23
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+4
-16
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+35
-25
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+25
-27
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+5
-5
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
098d8447
...
...
@@ -9,7 +9,7 @@ import textwrap
import
time
import
uuid
from
collections
import
defaultdict
from
typing
import
Any
from
typing
import
Any
,
cast
from
unittest.mock
import
MagicMock
,
patch
import
msgspec
...
...
@@ -332,14 +332,22 @@ def test_kv_transfer_handshake(dist_init):
# Prefill connector will register KV cache to populate proper handshake
# metadata.
# TODO this must match with values used in kv cache config
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
,
num_blocks
=
2
)
prefill_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
vllm_config
,
KVConnectorRole
.
WORKER
,
kv_cache_config
)
kv_cache_spec
=
cast
(
AttentionSpec
,
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
)
kv_cache_shape
=
FlashAttentionBackend
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
num_blocks
=
kv_cache_config
.
num_blocks
,
block_size
=
kv_cache_spec
.
block_size
,
num_kv_heads
=
kv_cache_spec
.
num_kv_heads
,
head_size
=
kv_cache_spec
.
head_size
,
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
unique_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
kv_cache_spec
.
dtype
)
unique_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
kv_cache_spec
.
dtype
)
kv_caches
=
{
"layer0"
:
shared_tensor
,
"layer1"
:
unique_tensor
,
...
...
@@ -383,7 +391,7 @@ def test_kv_transfer_handshake(dist_init):
# Decode connector will be able to create handshake with the prefill connector.
decode_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_
kv_cache_config
(
block_size
=
16
)
vllm_config
,
KVConnectorRole
.
WORKER
,
kv_cache_config
)
decode_connector
.
register_kv_caches
(
kv_caches
)
...
...
@@ -525,11 +533,13 @@ class TestNixlHandshake:
request_id
=
"req_id"
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
,
num_blocks
=
2
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
kv_cache_config
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
,
kv_cache_config
=
kv_cache_config
,
)
assert
isinstance
(
connector
.
connector_worker
.
nixl_wrapper
,
FakeNixlWrapper
)
worker
=
connector
.
connector_worker
...
...
@@ -1479,18 +1489,22 @@ def test_register_kv_caches(
patch
(
f
"
{
nixl_module
}
.threading.Event"
),
patch
(
f
"
{
nixl_module
}
.threading.Thread"
)
as
mock_thread
,
patch
(
f
"
{
nixl_module
}
.get_current_attn_backend"
)
as
mock_get_attn_backend
,
patch
(
f
"
{
nixl_module
}
.get_current_attn_backends"
)
as
mock_get_attn_backends
,
):
# Ensure get_attn_backend returns the correct value due to
# _cached_get_attn_backend returning the backend from previous
# test run if not mocking.
mock_get_attn_backend
.
return_value
=
backend_cls
mock_get_attn_backends
.
return_value
=
[
backend_cls
]
# Create connector
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
make_kv_cache_config
(
block_size
=
16
)
)
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
,
num_blocks
=
2
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
kv_cache_config
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
,
kv_cache_config
=
kv_cache_config
,
)
# Get the mock instance
...
...
@@ -1515,6 +1529,13 @@ def test_register_kv_caches(
num_layers
=
32
block_size
=
16
num_blocks
=
8
# Keep the fake worker's expected num_blocks in sync with the
# cross-layer tensor we are about to register.
worker_kv_cache_config
=
make_kv_cache_config
(
block_size
=
block_size
,
num_blocks
=
num_blocks
)
connector
.
connector_worker
.
kv_cache_config
=
worker_kv_cache_config
connector
.
connector_worker
.
num_blocks
=
worker_kv_cache_config
.
num_blocks
kv_cache_spec
=
AttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
4
,
...
...
@@ -1568,11 +1589,17 @@ def test_register_kv_caches(
else
:
# Create test kv cache tensors using proper backend shape
kv_cache_spec
=
cast
(
AttentionSpec
,
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
)
kv_cache_shape
=
backend_cls
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
num_blocks
=
kv_cache_config
.
num_blocks
,
block_size
=
kv_cache_spec
.
block_size
,
num_kv_heads
=
kv_cache_spec
.
num_kv_heads
,
head_size
=
kv_cache_spec
.
head_size
,
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
unique_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
kv_cache_spec
.
dtype
)
unique_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
kv_cache_spec
.
dtype
)
kv_caches
=
{
"layer0"
:
shared_tensor
,
"layer1"
:
unique_tensor
,
...
...
@@ -1606,7 +1633,7 @@ def test_register_kv_caches(
unique_tensor
[
1
].
data_ptr
(),
]
expected_num_entries
=
4
expected_blocks_count
=
8
expected_blocks_count
=
kv_cache_config
.
num_blocks
*
4
# Execute register_kv_caches
connector
.
register_kv_caches
(
kv_caches
)
...
...
@@ -1639,7 +1666,7 @@ def test_register_kv_caches(
num_blocks
=
8
expected_block_len
=
expected_tensor_size
//
num_blocks
else
:
num_blocks
=
2
num_blocks
=
kv_cache_config
.
num_blocks
if
is_blocks_first
:
expected_block_len
=
expected_tensor_size
//
num_blocks
//
2
else
:
...
...
@@ -2226,15 +2253,22 @@ def test_compatibility_hash_validation(
"enforce_handshake_compat"
:
enforce_handshake_compat
},
)
kv_cache_config
=
make_kv_cache_config
(
block_size
=
16
,
num_blocks
=
2
)
decode_connector
=
NixlConnector
(
local_vllm_config
,
KVConnectorRole
.
WORKER
,
make_
kv_cache_config
(
block_size
=
16
)
local_vllm_config
,
KVConnectorRole
.
WORKER
,
kv_cache_config
)
decode_worker
=
decode_connector
.
connector_worker
kv_cache_spec
=
cast
(
AttentionSpec
,
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
)
kv_cache_shape
=
decode_worker
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
num_blocks
=
kv_cache_config
.
num_blocks
,
block_size
=
kv_cache_spec
.
block_size
,
num_kv_heads
=
kv_cache_spec
.
num_kv_heads
,
head_size
=
kv_cache_spec
.
head_size
,
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
unique_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
kv_cache_spec
.
dtype
)
unique_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
kv_cache_spec
.
dtype
)
kv_caches
=
{
"layer0"
:
shared_tensor
,
"layer1"
:
unique_tensor
,
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
098d8447
...
...
@@ -38,7 +38,7 @@ from vllm.v1.kv_cache_interface import (
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.utils
import
AttentionGroup
,
select_common_block_size
from
vllm.v1.worker.utils
import
select_common_block_size
BLOCK_SIZE
=
16
NUM_BLOCKS
=
10
...
...
@@ -203,37 +203,25 @@ def _make_kv_cache_spec() -> FullAttentionSpec:
def
test_select_common_block_size_prefers_manager_block_size
():
backend_a
=
_make_mock_backend_for_kernel_block_size
([
MultipleOf
(
32
)])
backend_b
=
_make_mock_backend_for_kernel_block_size
([
64
,
MultipleOf
(
16
)])
attn_groups
=
[
AttentionGroup
(
backend_a
,
[],
[],
_make_kv_cache_spec
(),
0
),
AttentionGroup
(
backend_b
,
[],
[],
_make_kv_cache_spec
(),
0
),
]
selected_size
=
select_common_block_size
(
128
,
attn_groups
)
selected_size
=
select_common_block_size
(
128
,
[
backend_a
,
backend_b
]
)
assert
selected_size
==
128
def
test_select_common_block_size_uses_largest_shared_int
():
backend_a
=
_make_mock_backend_for_kernel_block_size
([
128
,
64
])
backend_b
=
_make_mock_backend_for_kernel_block_size
([
64
,
32
])
attn_groups
=
[
AttentionGroup
(
backend_a
,
[],
[],
_make_kv_cache_spec
(),
0
),
AttentionGroup
(
backend_b
,
[],
[],
_make_kv_cache_spec
(),
0
),
]
selected_size
=
select_common_block_size
(
256
,
attn_groups
)
selected_size
=
select_common_block_size
(
256
,
[
backend_a
,
backend_b
]
)
assert
selected_size
==
64
def
test_select_common_block_size_no_valid_option
():
backend_a
=
_make_mock_backend_for_kernel_block_size
([
64
])
backend_b
=
_make_mock_backend_for_kernel_block_size
([
MultipleOf
(
16
)])
attn_groups
=
[
AttentionGroup
(
backend_a
,
[],
[],
_make_kv_cache_spec
(),
0
),
AttentionGroup
(
backend_b
,
[],
[],
_make_kv_cache_spec
(),
0
),
]
with
pytest
.
raises
(
ValueError
):
select_common_block_size
(
48
,
attn_groups
)
select_common_block_size
(
48
,
[
backend_a
,
backend_b
]
)
def
test_update_states_new_request
(
model_runner
,
dist_init
):
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
098d8447
...
...
@@ -358,15 +358,6 @@ class TpKVTopology:
# stride_order to retrieve physical position of block_size
kv_cache_shape
=
tuple
(
kv_cache_shape
[
i
]
for
i
in
kv_cache_stride_order
)
# In the default non-cross layers layout the block_size position
# is logical while in the cross layers case it is the physical
# position. This matches the shape of the actual kv cache tensors
# passed at register_kv_caches()/register_cross_layers_kv_cache()
block_size_position
=
kv_cache_shape
.
index
(
_MOCK_BLOCK_SIZE
)
assert
block_size_position
is
not
None
self
.
_block_size_position
=
-
(
len
(
kv_cache_shape
)
-
block_size_position
)
@
property
def
is_kv_layout_blocks_first
(
self
)
->
bool
:
return
self
.
_is_kv_layout_blocks_first
...
...
@@ -390,10 +381,6 @@ class TpKVTopology:
def
cross_layers_blocks
(
self
)
->
bool
:
return
self
.
_cross_layers_blocks
@
property
def
block_size_position
(
self
)
->
int
:
return
self
.
_block_size_position
def
tp_ratio
(
self
,
remote_tp_size
:
int
,
...
...
@@ -484,23 +471,46 @@ class TpKVTopology:
return
self
.
get_target_remote_ranks
(
remote_tp_size
)
def
get_current_attn_backend
(
vllm_config
:
VllmConfig
):
def
get_current_attn_backends
(
vllm_config
:
VllmConfig
,
layer_names
:
list
[
str
]
|
None
=
None
)
->
list
[
type
[
AttentionBackend
]]:
"""Get all distinct attention backends for the given layers.
Args:
vllm_config: The current vLLM configuration.
layer_names: Optional list of layer names to scope the lookup.
When None, all attention layers are considered.
Returns:
Deduplicated list of attention backend classes.
"""
layer_type
=
cast
(
type
[
Any
],
AttentionLayerBase
)
layers
=
get_layers_from_vllm_config
(
vllm_config
,
layer_type
,
None
)
layers
=
get_layers_from_vllm_config
(
vllm_config
,
layer_type
,
layer_names
)
if
layers
:
backend
=
next
(
iter
(
layers
.
values
())).
get_attn_backend
()
else
:
# Fallback for tests, when static_forward_context is empty.
logger
.
debug
(
"No layers found in the vLLM config. "
"Falling back to default attention backend."
)
from
vllm.v1.attention.selector
import
get_attn_backend
seen
:
dict
[
str
,
type
[
AttentionBackend
]]
=
{}
for
layer
in
layers
.
values
():
backend
=
layer
.
get_attn_backend
()
seen
[
backend
.
full_cls_name
()]
=
backend
return
list
(
seen
.
values
())
# Fallback for tests, when static_forward_context is empty.
logger
.
debug
(
"No layers found in the vLLM config. Falling back to default attention backend."
)
from
vllm.v1.attention.selector
import
get_attn_backend
backend
=
get_attn_backend
(
return
[
get_attn_backend
(
head_size
=
vllm_config
.
model_config
.
get_head_size
(),
dtype
=
vllm_config
.
model_config
.
dtype
,
kv_cache_dtype
=
vllm_config
.
cache_config
.
cache_dtype
,
use_mla
=
vllm_config
.
model_config
.
use_mla
,
)
return
backend
]
def
get_current_attn_backend
(
vllm_config
:
VllmConfig
,
layer_names
:
list
[
str
]
|
None
=
None
)
->
type
[
AttentionBackend
]:
"""Get the first attention backend for the given layers."""
return
get_current_attn_backends
(
vllm_config
,
layer_names
)[
0
]
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
098d8447
...
...
@@ -13,7 +13,7 @@ from collections import defaultdict
from
collections.abc
import
Iterator
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
,
cast
import
msgspec
import
numpy
as
np
...
...
@@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId
,
TpKVTopology
,
get_current_attn_backend
,
get_current_attn_backends
,
kv_postprocess_blksize_and_layout_on_receive
,
kv_postprocess_blksize_on_receive
,
kv_postprocess_layout_on_receive
,
...
...
@@ -61,6 +62,7 @@ from vllm.v1.attention.backends.utils import get_kv_cache_layout
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
,
MambaSpec
,
SlidingWindowSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.utils
import
select_common_block_size
if
TYPE_CHECKING
:
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
...
...
@@ -945,7 +947,8 @@ class NixlConnectorWorker:
# Config.
self
.
vllm_config
=
vllm_config
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
# mypy will complain on re-assignment otherwise.
self
.
block_size
:
int
=
cast
(
int
,
vllm_config
.
cache_config
.
block_size
)
if
vllm_config
.
kv_transfer_config
is
None
:
raise
ValueError
(
"kv_transfer_config must be set for NixlConnector"
)
...
...
@@ -993,7 +996,7 @@ class NixlConnectorWorker:
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
world_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_group
=
get_tp_group
()
self
.
num_blocks
=
0
self
.
num_blocks
=
kv_cache_config
.
num_blocks
self
.
enable_permute_local_kv
=
False
# KV Caches and nixl tracking data.
...
...
@@ -1131,11 +1134,30 @@ class NixlConnectorWorker:
self
.
xfer_stats
=
NixlKVConnectorStats
()
self
.
_physical_blocks_per_logical_kv_block
=
1
self
.
_sync_block_size_with_kernel
()
self
.
enforce_compat_hash
=
self
.
kv_transfer_config
.
get_from_extra_config
(
"enforce_handshake_compat"
,
True
)
def
_sync_block_size_with_kernel
(
self
)
->
None
:
backends
=
get_current_attn_backends
(
self
.
vllm_config
)
kernel_block_size
=
select_common_block_size
(
self
.
block_size
,
backends
)
if
self
.
block_size
!=
kernel_block_size
:
logger
.
info_once
(
"User-specified logical block size (%s) does not match"
" physical kernel block size (%s). Using the latter."
,
self
.
block_size
,
kernel_block_size
,
)
assert
self
.
block_size
>
kernel_block_size
self
.
_physical_blocks_per_logical_kv_block
=
(
self
.
block_size
//
kernel_block_size
)
self
.
block_size
=
kernel_block_size
self
.
_block_size
[
self
.
engine_id
]
=
kernel_block_size
self
.
num_blocks
*=
self
.
_physical_blocks_per_logical_kv_block
def
_nixl_handshake
(
self
,
host
:
str
,
...
...
@@ -1469,7 +1491,6 @@ class NixlConnectorWorker:
# Enable different block lengths for different layers when MLA is used.
self
.
block_len_per_layer
=
list
[
int
]()
self
.
slot_size_per_layer
=
list
[
int
]()
# HD bytes in kv terms
for
layer_name
,
cache_or_caches
in
xfer_buffers
.
items
():
cache_list
=
(
cache_or_caches
if
self
.
kv_topo
.
split_k_and_v
else
[
cache_or_caches
]
...
...
@@ -1486,26 +1507,11 @@ class NixlConnectorWorker:
logger
.
debug
(
"Registering layer %s with cache shape: %s"
,
layer_name
,
cache
.
shape
)
kernel_block_size
=
cache
.
shape
[
self
.
kv_topo
.
block_size_position
]
if
self
.
block_size
!=
kernel_block_size
:
logger
.
info_once
(
"User-specified logical block size (%s) does not match"
" physical kernel block size (%s). Using the latter. "
,
self
.
block_size
,
kernel_block_size
,
)
self
.
_physical_blocks_per_logical_kv_block
=
(
self
.
block_size
//
kernel_block_size
)
self
.
block_size
=
kernel_block_size
self
.
_block_size
[
self
.
engine_id
]
=
kernel_block_size
seen_base_addresses
.
append
(
base_addr
)
curr_tensor_size_bytes
=
cache
.
numel
()
*
cache
.
element_size
()
if
tensor_size_bytes
is
None
:
tensor_size_bytes
=
curr_tensor_size_bytes
self
.
num_blocks
=
cache
.
shape
[
0
]
assert
cache
.
shape
[
0
]
==
self
.
num_blocks
,
(
"All kv cache tensors must have the same number of blocks"
...
...
@@ -1514,9 +1520,6 @@ class NixlConnectorWorker:
self
.
block_len_per_layer
.
append
(
curr_tensor_size_bytes
//
self
.
num_blocks
)
self
.
slot_size_per_layer
.
append
(
self
.
block_len_per_layer
[
-
1
]
//
self
.
block_size
)
if
not
self
.
use_mla
:
# Different kv cache shape is not supported by HeteroTP
...
...
@@ -1534,7 +1537,6 @@ class NixlConnectorWorker:
"Different block lengths collected: %s"
,
set
(
self
.
block_len_per_layer
)
)
assert
len
(
self
.
block_len_per_layer
)
==
len
(
seen_base_addresses
)
assert
self
.
num_blocks
!=
0
self
.
kv_caches_base_addr
[
self
.
engine_id
][
self
.
tp_rank
]
=
seen_base_addresses
self
.
num_regions
=
len
(
caches_data
)
...
...
@@ -1550,10 +1552,6 @@ class NixlConnectorWorker:
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
for
i
in
range
(
len
(
self
.
slot_size_per_layer
)):
assert
self
.
slot_size_per_layer
[
i
]
%
2
==
0
self
.
slot_size_per_layer
[
i
]
//=
2
# NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in
# registerMem allowing faster descs queries. In order to be able to
...
...
vllm/v1/worker/utils.py
View file @
098d8447
...
...
@@ -258,7 +258,8 @@ class AttentionGroup:
def
select_common_block_size
(
kv_manager_block_size
:
int
,
attn_groups
:
list
[
AttentionGroup
]
kv_manager_block_size
:
int
,
backends
:
list
[
type
[
AttentionBackend
]],
)
->
int
:
"""
Select a block size that is supported by all backends and is a factor of
...
...
@@ -269,7 +270,7 @@ def select_common_block_size(
Args:
kv_manager_block_size: Block size of KV cache.
attn_group
s: List of attention
group
s.
backend
s: List of attention
backend classe
s.
Returns:
The selected block size.
...
...
@@ -297,8 +298,6 @@ def select_common_block_size(
return
False
return
True
backends
=
[
group
.
backend
for
group
in
attn_groups
]
# Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly.
if
block_size_is_supported
(
backends
,
kv_manager_block_size
):
...
...
@@ -356,8 +355,9 @@ def prepare_kernel_block_sizes(
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
# This is an attention backend that supports virtual block splitting.
kv_manager_block_size
=
kv_cache_group
.
kv_cache_spec
.
block_size
group_backends
=
[
g
.
backend
for
g
in
attn_groups
[
kv_cache_gid
]]
selected_kernel_size
=
select_common_block_size
(
kv_manager_block_size
,
attn_groups
[
kv_cache_gid
]
kv_manager_block_size
,
group_backends
)
kernel_block_sizes
.
append
(
selected_kernel_size
)
elif
isinstance
(
kv_cache_spec
,
MambaSpec
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment