Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
755fa8b6
Unverified
Commit
755fa8b6
authored
Jul 29, 2025
by
Chen Zhang
Committed by
GitHub
Jul 29, 2025
Browse files
[KVCache] Make KVCacheSpec hashable (#21791)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
24704191
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
100 additions
and
88 deletions
+100
-88
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+33
-1
tests/v1/e2e/test_correctness_sliding_window.py
tests/v1/e2e/test_correctness_sliding_window.py
+6
-2
vllm/v1/core/kv_cache_coordinator.py
vllm/v1/core/kv_cache_coordinator.py
+14
-17
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+18
-17
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+29
-51
No files found.
tests/v1/core/test_kv_cache_utils.py
View file @
755fa8b6
...
...
@@ -17,7 +17,7 @@ from vllm.v1.core.kv_cache_utils import (
estimate_max_model_len
,
generate_block_hash_extra_keys
,
get_kv_cache_config
,
get_max_concurrency_for_kv_cache_config
,
hash_block_tokens
,
hash_request_tokens
,
init_none_hash
,
unify_kv_cache_configs
)
is_kv_cache_type_uniform
,
unify_kv_cache_configs
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
,
SlidingWindowSpec
)
...
...
@@ -685,6 +685,38 @@ def test_merge_kv_cache_spec():
assert
merged_layer_spec
.
sliding_window
==
1
def
test_is_kv_cache_type_uniform
():
kv_cache_spec
=
{
"layer_1"
:
new_kv_cache_spec
(
num_kv_heads
=
32
),
"layer_2"
:
new_kv_cache_spec
(
num_kv_heads
=
32
),
}
assert
is_kv_cache_type_uniform
(
kv_cache_spec
)
kv_cache_spec
=
{
"layer_1"
:
new_kv_cache_spec
(
num_kv_heads
=
32
),
"layer_2"
:
new_kv_cache_spec
(
num_kv_heads
=
32
,
sliding_window
=
1
),
}
assert
is_kv_cache_type_uniform
(
kv_cache_spec
)
kv_cache_spec
=
{
"layer_1"
:
new_kv_cache_spec
(
num_kv_heads
=
32
),
"layer_2"
:
new_sliding_window_spec
(
num_kv_heads
=
32
,
sliding_window
=
1
),
}
assert
not
is_kv_cache_type_uniform
(
kv_cache_spec
)
kv_cache_spec
=
{
"layer_1"
:
new_sliding_window_spec
(
num_kv_heads
=
32
,
sliding_window
=
1
),
"layer_2"
:
new_sliding_window_spec
(
num_kv_heads
=
32
,
sliding_window
=
1
),
}
assert
is_kv_cache_type_uniform
(
kv_cache_spec
)
kv_cache_spec
=
{
"layer_1"
:
new_sliding_window_spec
(
num_kv_heads
=
32
,
sliding_window
=
1
),
"layer_2"
:
new_sliding_window_spec
(
num_kv_heads
=
32
,
sliding_window
=
2
),
}
assert
not
is_kv_cache_type_uniform
(
kv_cache_spec
)
@
pytest
.
mark
.
parametrize
(
(
"model_id"
,
"max_model_len"
,
"want_estimated_max_len"
),
[
(
"Qwen/Qwen1.5-7B"
,
16385
,
16384
),
...
...
tests/v1/e2e/test_correctness_sliding_window.py
View file @
755fa8b6
...
...
@@ -30,7 +30,9 @@ model_config = {
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_sliding_window_retrieval
(
monkeypatch
,
model
,
batch_size
,
seed
):
@
pytest
.
mark
.
parametrize
(
"disable_hybrid_kv_cache_manager"
,
[
True
,
False
])
def
test_sliding_window_retrieval
(
monkeypatch
,
model
,
batch_size
,
seed
,
disable_hybrid_kv_cache_manager
):
"""
The test does a bunch of assignments "x1 = 10
\n
x2 = 33
\n
..." and then
asks for value of one of them (which is outside the sliding window).
...
...
@@ -42,7 +44,9 @@ def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed):
test_config
=
model_config
[
model
]
llm
=
LLM
(
model
=
model
)
llm
=
LLM
(
model
=
model
,
disable_hybrid_kv_cache_manager
=
disable_hybrid_kv_cache_manager
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
100
)
prompts
,
answer
,
indices
=
prep_prompts
(
batch_size
,
...
...
vllm/v1/core/kv_cache_coordinator.py
View file @
755fa8b6
...
...
@@ -7,7 +7,8 @@ from vllm.v1.core.block_pool import BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.core.single_type_kv_cache_manager
import
(
FullAttentionManager
,
get_manager_for_kv_cache_spec
)
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
,
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
from
vllm.v1.request
import
Request
...
...
@@ -258,44 +259,40 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
one of them is full attention. Then, split the kv cache groups into full
attention groups and other groups.
"""
full_attention_
type_id
:
Optional
[
str
]
=
None
other_
type_id
:
Optional
[
str
]
=
None
full_attention_
spec
:
Optional
[
FullAttentionSpec
]
=
None
other_
spec
:
Optional
[
KVCacheSpec
]
=
None
self
.
full_attention_group_ids
:
list
[
int
]
=
[]
self
.
other_group_ids
:
list
[
int
]
=
[]
for
i
,
g
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
if
isinstance
(
g
.
kv_cache_spec
,
FullAttentionSpec
):
if
full_attention_
type_id
is
None
:
full_attention_
type_id
=
g
.
kv_cache_spec
.
type_id
if
full_attention_
spec
is
None
:
full_attention_
spec
=
g
.
kv_cache_spec
else
:
assert
full_attention_
type_id
==
g
.
kv_cache_spec
.
type_id
,
(
assert
full_attention_
spec
==
g
.
kv_cache_spec
,
(
"HybridKVCacheCoordinator assumes exactly one type of "
"full attention groups now."
)
self
.
full_attention_group_ids
.
append
(
i
)
else
:
if
other_
type_id
is
None
:
other_
type_id
=
g
.
kv_cache_spec
.
type_id
if
other_
spec
is
None
:
other_
spec
=
g
.
kv_cache_spec
else
:
assert
other_
type_id
==
g
.
kv_cache_spec
.
type_id
,
(
assert
other_
spec
==
g
.
kv_cache_spec
,
(
"HybridKVCacheCoordinator assumes "
"exactly one other type of groups now."
)
self
.
other_group_ids
.
append
(
i
)
assert
full_attention_
type_id
is
not
None
,
(
assert
full_attention_
spec
is
not
None
,
(
"HybridKVCacheCoordinator assumes exactly one type of full "
"attention groups now."
)
assert
other_
type_id
is
not
None
,
(
assert
other_
spec
is
not
None
,
(
"HybridKVCacheCoordinator assumes exactly one type of other "
"groups now."
)
self
.
full_attention_manager_cls
=
FullAttentionManager
self
.
other_attention_cls
=
self
.
single_type_managers
[
self
.
other_group_ids
[
0
]].
__class__
self
.
full_attention_spec
=
self
.
kv_cache_config
.
kv_cache_groups
[
self
.
full_attention_group_ids
[
0
]].
kv_cache_spec
self
.
other_spec
=
self
.
kv_cache_config
.
kv_cache_groups
[
self
.
other_group_ids
[
0
]].
kv_cache_spec
self
.
full_attention_spec
=
full_attention_spec
self
.
other_spec
=
other_spec
self
.
full_attention_block_size
=
self
.
full_attention_spec
.
block_size
self
.
other_block_size
=
self
.
other_spec
.
block_size
...
...
vllm/v1/core/kv_cache_utils.py
View file @
755fa8b6
...
...
@@ -5,7 +5,7 @@
import
os
from
collections
import
defaultdict
,
deque
from
collections.abc
import
Iterable
,
Sequence
from
dataclasses
import
dataclass
from
dataclasses
import
astuple
,
dataclass
from
typing
import
Any
,
Callable
,
NamedTuple
,
Optional
from
vllm.config
import
VllmConfig
...
...
@@ -727,7 +727,9 @@ def create_kv_cache_group_specs(
def
is_kv_cache_type_uniform
(
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
])
->
bool
:
"""
Whether all layers in the given KVCacheSpec have the same type of KV cache.
Whether all layers in the given KVCacheSpec have the same KV cache spec.
Note that we regard FullAttentionSpec with and without sliding window as
the same type.
Args:
kv_cache_spec: The kv cache spec of each attention layer in the model
...
...
@@ -736,8 +738,12 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
True if all layers have the same type, False otherwise.
"""
layer_keys
=
set
(
layer
.
type_id
for
layer
in
kv_cache_spec
.
values
())
return
len
(
layer_keys
)
==
1
try
:
kv_cache_spec_values
=
list
(
kv_cache_spec
.
values
())
_
=
kv_cache_spec_values
[
0
].
merge
(
kv_cache_spec_values
)
except
AssertionError
:
return
False
return
True
def
get_max_concurrency_for_kv_cache_config
(
...
...
@@ -928,12 +934,12 @@ def _get_kv_cache_config_uniform_page_size(
Returns:
The generated KVCacheConfig
"""
# Group all layers by
type_id
.
# Group all layers by
kv_cache_spec
.
# E.g., 2 full attention layers and 3 sliding window attention layers,
# -> (full.0, full.1), (sw.0, sw.1, sw.2).
same_type_layers
:
dict
[
str
,
list
[
str
]]
=
defaultdict
(
list
)
same_type_layers
:
dict
[
KVCacheSpec
,
list
[
str
]]
=
defaultdict
(
list
)
for
layer_name
,
layer_spec
in
kv_cache_spec
.
items
():
same_type_layers
[
layer_spec
.
type_id
].
append
(
layer_name
)
same_type_layers
[
layer_spec
].
append
(
layer_name
)
# Split each group into smaller groups, to make the number of layers in each
# group identical. Add padding to the last group of each type if necessary.
...
...
@@ -1017,12 +1023,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
kv_cache_spec: The kv cache spec of each attention layer in the model
"""
def
is_hybrid
(
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
])
->
bool
:
type_ids
=
set
(
layer_spec
.
type_id
for
layer_spec
in
kv_cache_spec
.
values
())
return
len
(
type_ids
)
>
1
if
not
is_hybrid
(
kv_cache_spec
):
if
is_kv_cache_type_uniform
(
kv_cache_spec
):
return
logger
.
warning
(
...
...
@@ -1060,7 +1061,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
attention_chunk_size
=
spec
.
attention_chunk_size
,
)
if
is_hybrid
(
kv_cache_spec
):
if
not
is_kv_cache_type_uniform
(
kv_cache_spec
):
raise
ValueError
(
"Hybrid KV cache manager is disabled but failed to "
"convert the KV cache specs to one unified type."
)
...
...
@@ -1119,11 +1120,11 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
in-place modified to make them consistent.
"""
# Sort the kv cache groups by
the type_id of
their KV cache spec.
# Sort the kv cache groups by their KV cache spec.
# This can avoid the inconsistency caused by the order of groups.
for
kv_cache_config
in
kv_cache_configs
:
kv_cache_config
.
kv_cache_groups
.
sort
(
key
=
lambda
x
:
x
.
kv_cache_spec
.
type_id
)
kv_cache_config
.
kv_cache_groups
.
sort
(
key
=
lambda
x
:
(
type
(
x
.
kv_cache_spec
).
__name__
,
astuple
(
x
.
kv_cache_spec
))
)
# Verify that the groups of each rank are the same.
for
kv_cache_config
in
kv_cache_configs
[
1
:]:
...
...
vllm/v1/kv_cache_interface.py
View file @
755fa8b6
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
fields
from
math
import
prod
from
typing
import
Optional
...
...
@@ -16,7 +16,7 @@ from vllm.utils import cdiv, get_dtype_size
logger
=
init_logger
(
__name__
)
@
dataclass
@
dataclass
(
frozen
=
True
)
class
KVCacheSpec
:
"""
A base class for specifying the KV cache format of one layer.
...
...
@@ -25,20 +25,6 @@ class KVCacheSpec:
# number of tokens in a block
block_size
:
int
@
property
def
type_id
(
self
)
->
str
:
"""
The type identifier of this KV cache.
Return different strings for layers with different KV cache type (e.g.,
different number of tokens like full attention vs sliding window
attention, different KV cache size per token like layers with different
number of heads)
Returns:
The type identifier of this KV cache.
"""
raise
NotImplementedError
@
property
def
page_size_bytes
(
self
)
->
int
:
"""
...
...
@@ -63,13 +49,12 @@ class KVCacheSpec:
"""
Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
"""
assert
all
(
spec
.
type_id
==
specs
[
0
].
type_id
for
spec
in
specs
[
1
:]),
(
"All layers in the same KV cache group must share the same "
"type_id."
)
assert
all
(
spec
==
specs
[
0
]
for
spec
in
specs
[
1
:]),
(
"All layers in the same KV cache group must be the same."
)
return
copy
.
deepcopy
(
specs
[
0
])
@
dataclass
@
dataclass
(
frozen
=
True
)
class
AttentionSpec
(
KVCacheSpec
):
num_kv_heads
:
int
head_size
:
int
...
...
@@ -84,7 +69,7 @@ class AttentionSpec(KVCacheSpec):
*
get_dtype_size
(
self
.
dtype
)
@
dataclass
@
dataclass
(
frozen
=
True
)
class
FullAttentionSpec
(
AttentionSpec
):
sliding_window
:
Optional
[
int
]
=
None
attention_chunk_size
:
Optional
[
int
]
=
None
...
...
@@ -98,10 +83,6 @@ class FullAttentionSpec(AttentionSpec):
Default to None for not using sliding window attention.
"""
@
property
def
type_id
(
self
)
->
str
:
return
f
"full_attention_
{
self
.
block_size
}
_
{
self
.
page_size_bytes
}
"
def
max_memory_usage_bytes
(
self
,
vllm_config
:
VllmConfig
)
->
int
:
max_model_len
=
vllm_config
.
model_config
.
max_model_len
return
cdiv
(
max_model_len
,
self
.
block_size
)
*
self
.
page_size_bytes
...
...
@@ -123,15 +104,28 @@ class FullAttentionSpec(AttentionSpec):
Merge a list of FullAttentionSpec objects into a single
FullAttentionSpec object.
"""
merged_spec
=
super
().
merge
(
specs
)
assert
all
(
isinstance
(
spec
,
FullAttentionSpec
)
for
spec
in
specs
),
(
"All attention layers in the same KV cache group must be "
"FullAttentionSpec."
)
sliding_window
=
set
(
spec
.
sliding_window
for
spec
in
specs
if
spec
.
sliding_window
is
not
None
)
attention_chunk_size
=
set
(
spec
.
attention_chunk_size
for
spec
in
specs
if
spec
.
attention_chunk_size
is
not
None
)
merged_spec
.
sliding_window
=
cls
.
merge_window_sizes
(
sliding_window
)
merged_spec
.
attention_chunk_size
=
(
cls
.
merge_window_sizes
(
attention_chunk_size
))
merged_spec
=
cls
(
block_size
=
specs
[
0
].
block_size
,
num_kv_heads
=
specs
[
0
].
num_kv_heads
,
head_size
=
specs
[
0
].
head_size
,
dtype
=
specs
[
0
].
dtype
,
use_mla
=
specs
[
0
].
use_mla
,
sliding_window
=
cls
.
merge_window_sizes
(
sliding_window
),
attention_chunk_size
=
cls
.
merge_window_sizes
(
attention_chunk_size
),
)
for
spec
in
specs
:
for
f
in
fields
(
AttentionSpec
):
assert
getattr
(
spec
,
f
.
name
)
==
getattr
(
merged_spec
,
f
.
name
),
(
"All attention layers in the same KV cache group must have "
"the same attention spec."
)
assert
(
(
merged_spec
.
sliding_window
is
not
None
)
+
(
merged_spec
.
attention_chunk_size
is
not
None
)
<=
1
...
...
@@ -140,16 +134,10 @@ class FullAttentionSpec(AttentionSpec):
return
merged_spec
@
dataclass
@
dataclass
(
frozen
=
True
)
class
ChunkedLocalAttentionSpec
(
AttentionSpec
):
attention_chunk_size
:
int
@
property
def
type_id
(
self
)
->
str
:
return
(
f
"local_attention_
{
self
.
attention_chunk_size
}
_
{
self
.
block_size
}
_
{
self
.
page_size_bytes
}
"
)
# noqa
def
max_memory_usage_bytes
(
self
,
vllm_config
:
VllmConfig
)
->
int
:
max_model_len
=
vllm_config
.
model_config
.
max_model_len
max_num_batched_tokens
=
(
...
...
@@ -165,17 +153,13 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
return
cdiv
(
num_tokens
,
self
.
block_size
)
*
self
.
page_size_bytes
@
dataclass
@
dataclass
(
frozen
=
True
)
class
SlidingWindowSpec
(
AttentionSpec
):
sliding_window
:
int
def
__post_init__
(
self
):
assert
not
self
.
use_mla
,
"MLA is not supported for sliding window"
@
property
def
type_id
(
self
)
->
str
:
return
f
"sliding_window_
{
self
.
sliding_window
}
_
{
self
.
block_size
}
_
{
self
.
page_size_bytes
}
"
# noqa
def
max_memory_usage_bytes
(
self
,
vllm_config
:
VllmConfig
)
->
int
:
max_model_len
=
vllm_config
.
model_config
.
max_model_len
max_num_batched_tokens
=
(
...
...
@@ -195,23 +179,17 @@ class SlidingWindowSpec(AttentionSpec):
return
(
cdiv
(
num_tokens
,
self
.
block_size
)
+
1
)
*
self
.
page_size_bytes
@
dataclass
@
dataclass
(
frozen
=
True
)
class
MambaSpec
(
KVCacheSpec
):
shapes
:
tuple
[
tuple
[
int
,
...],
...]
dtype
:
torch
.
dtype
page_size_padded
:
Optional
[
int
]
=
None
mamba_type
:
str
=
"mamba2"
def
__post_init__
(
self
):
self
.
num_elements
=
sum
(
prod
(
shape
)
for
shape
in
self
.
shapes
)
@
property
def
type_id
(
self
)
->
str
:
return
f
"mamba_
{
self
.
shapes
}
_
{
self
.
dtype
}
_
{
self
.
mamba_type
}
"
@
property
def
page_size_bytes
(
self
)
->
int
:
page_size
=
self
.
num_elements
*
get_dtype_size
(
self
.
dtype
)
num_elements
=
sum
(
prod
(
shape
)
for
shape
in
self
.
shapes
)
page_size
=
num_elements
*
get_dtype_size
(
self
.
dtype
)
if
self
.
page_size_padded
is
not
None
:
assert
self
.
page_size_padded
>=
page_size
return
self
.
page_size_padded
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment