Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a01ef3fa
Unverified
Commit
a01ef3fa
authored
Feb 01, 2026
by
Yifan Qiao
Committed by
GitHub
Feb 02, 2026
Browse files
[Fix] prefix cache hit rate == 0 bug with gpt-oss style models (#33524)
Signed-off-by:
Yifan Qiao
<
yifanqiao@berkeley.edu
>
parent
7320ca39
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
366 additions
and
47 deletions
+366
-47
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+343
-40
vllm/v1/core/kv_cache_coordinator.py
vllm/v1/core/kv_cache_coordinator.py
+23
-7
No files found.
tests/v1/core/test_prefix_caching.py
View file @
a01ef3fa
...
@@ -107,7 +107,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
...
@@ -107,7 +107,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
def
make_kv_cache_config_hybrid_model
(
def
make_kv_cache_config_hybrid_model
(
block_size
:
int
,
num_blocks
:
int
,
second_spec_type
:
str
=
"sliding_window"
block_size
:
int
,
num_blocks
:
int
,
sliding_window_blocks
:
int
,
second_spec_type
:
str
=
"sliding_window"
,
)
->
KVCacheConfig
:
)
->
KVCacheConfig
:
if
second_spec_type
==
"sliding_window"
:
if
second_spec_type
==
"sliding_window"
:
second_spec
=
SlidingWindowSpec
(
second_spec
=
SlidingWindowSpec
(
...
@@ -115,7 +118,7 @@ def make_kv_cache_config_hybrid_model(
...
@@ -115,7 +118,7 @@ def make_kv_cache_config_hybrid_model(
num_kv_heads
=
1
,
num_kv_heads
=
1
,
head_size
=
1
,
head_size
=
1
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
sliding_window
=
2
*
block_size
,
sliding_window
=
sliding_window_blocks
*
block_size
,
)
)
elif
second_spec_type
==
"mamba"
:
elif
second_spec_type
==
"mamba"
:
second_spec
=
MambaSpec
(
second_spec
=
MambaSpec
(
...
@@ -325,7 +328,7 @@ def test_prefill(hash_fn):
...
@@ -325,7 +328,7 @@ def test_prefill(hash_fn):
def
test_prefill_hybrid_model
():
def
test_prefill_hybrid_model
():
block_size
=
16
block_size
=
16
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
make_kv_cache_config_hybrid_model
(
block_size
,
21
),
make_kv_cache_config_hybrid_model
(
block_size
,
21
,
2
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
hash_block_size
=
block_size
,
hash_block_size
=
block_size
,
...
@@ -334,7 +337,8 @@ def test_prefill_hybrid_model():
...
@@ -334,7 +337,8 @@ def test_prefill_hybrid_model():
hash_fn
=
sha256
hash_fn
=
sha256
# Complete 3 blocks (48 tokens)
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
num_full_blocks
=
3
common_token_ids
=
[
i
for
i
in
range
(
num_full_blocks
)
for
_
in
range
(
block_size
)]
# Fully cache miss
# Fully cache miss
# Incomplete 1 block (7 tokens)
# Incomplete 1 block (7 tokens)
...
@@ -375,6 +379,7 @@ def test_prefill_hybrid_model():
...
@@ -375,6 +379,7 @@ def test_prefill_hybrid_model():
# Cache hit in the common prefix
# Cache hit in the common prefix
# Incomplete 1 block (5 tokens)
# Incomplete 1 block (5 tokens)
unique_token_ids
=
[
3
]
*
5
unique_token_ids
=
[
3
]
*
5
all_token_ids
=
common_token_ids
+
unique_token_ids
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
,
block_size
,
hash_fn
)
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
,
block_size
,
hash_fn
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
req1
.
block_hashes
)
==
3
assert
len
(
req1
.
block_hashes
)
==
3
...
@@ -394,34 +399,13 @@ def test_prefill_hybrid_model():
...
@@ -394,34 +399,13 @@ def test_prefill_hybrid_model():
manager
.
free
(
req0
)
manager
.
free
(
req0
)
manager
.
free
(
req1
)
manager
.
free
(
req1
)
cached_block_hash_to_block_bak
=
copy
.
copy
(
manager
.
block_pool
.
cached_block_hash_to_block
.
_cache
)
def
test_partial_request_hit
(
request_id
:
str
,
hash_to_evict
:
list
[
BlockHashWithGroupId
],
expect_hit_length
:
int
,
):
req
=
make_request
(
request_id
,
common_token_ids
+
unique_token_ids
,
block_size
,
sha256
)
for
hash_with_group_id
in
hash_to_evict
:
manager
.
block_pool
.
cached_block_hash_to_block
.
_cache
.
pop
(
hash_with_group_id
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
len
(
req
.
block_hashes
)
==
3
assert
num_computed_tokens
==
expect_hit_length
*
block_size
for
block_per_group
in
computed_blocks
.
blocks
:
assert
len
(
block_per_group
)
==
num_computed_tokens
//
block_size
for
hash_with_group_id
in
hash_to_evict
:
manager
.
block_pool
.
cached_block_hash_to_block
.
_cache
[
hash_with_group_id
]
=
(
cached_block_hash_to_block_bak
[
hash_with_group_id
]
)
manager
.
free
(
req
)
# Evict the blocks outside sliding window, does not affect the hit length.
# Evict the blocks outside sliding window, does not affect the hit length.
test_partial_request_hit
(
_test_partial_request_hit
(
manager
,
block_size
,
num_full_blocks
,
"2"
,
"2"
,
all_token_ids
,
[
[
make_block_hash_with_group_id
(
block_hashes
[
0
],
1
),
make_block_hash_with_group_id
(
block_hashes
[
0
],
1
),
make_block_hash_with_group_id
(
block_hashes
[
0
],
2
),
make_block_hash_with_group_id
(
block_hashes
[
0
],
2
),
...
@@ -430,13 +414,23 @@ def test_prefill_hybrid_model():
...
@@ -430,13 +414,23 @@ def test_prefill_hybrid_model():
)
)
# Evict the first block of full attention, makes total cache miss.
# Evict the first block of full attention, makes total cache miss.
test_partial_request_hit
(
_test_partial_request_hit
(
"3"
,
[
make_block_hash_with_group_id
(
block_hashes
[
0
],
0
)],
0
manager
,
block_size
,
num_full_blocks
,
"3"
,
all_token_ids
,
[
make_block_hash_with_group_id
(
block_hashes
[
0
],
0
)],
0
,
)
)
# Evict the last block of all layers, reduces the hit length to 2.
# Evict the last block of all layers, reduces the hit length to 2.
test_partial_request_hit
(
_test_partial_request_hit
(
manager
,
block_size
,
num_full_blocks
,
"4"
,
"4"
,
all_token_ids
,
[
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
0
),
make_block_hash_with_group_id
(
block_hashes
[
2
],
0
),
make_block_hash_with_group_id
(
block_hashes
[
2
],
1
),
make_block_hash_with_group_id
(
block_hashes
[
2
],
1
),
...
@@ -446,18 +440,36 @@ def test_prefill_hybrid_model():
...
@@ -446,18 +440,36 @@ def test_prefill_hybrid_model():
)
)
# Evict the last block of full attention, reduces the hit length to 2.
# Evict the last block of full attention, reduces the hit length to 2.
test_partial_request_hit
(
_test_partial_request_hit
(
"5"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
0
)],
2
manager
,
block_size
,
num_full_blocks
,
"5"
,
all_token_ids
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
0
)],
2
,
)
)
# Evict the last block of sliding window, reduces the hit length to 2.
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit
(
_test_partial_request_hit
(
"6"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
1
)],
2
manager
,
block_size
,
num_full_blocks
,
"6"
,
all_token_ids
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
1
)],
2
,
)
)
# Evict the last block of sliding window, reduces the hit length to 2.
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit
(
_test_partial_request_hit
(
"7"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
2
)],
2
manager
,
block_size
,
num_full_blocks
,
"7"
,
all_token_ids
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
2
)],
2
,
)
)
# Evict different set of blocks for full attention and sliding window makes
# Evict different set of blocks for full attention and sliding window makes
...
@@ -466,8 +478,12 @@ def test_prefill_hybrid_model():
...
@@ -466,8 +478,12 @@ def test_prefill_hybrid_model():
# The cache hit length of sliding window is 2 * block_size.
# The cache hit length of sliding window is 2 * block_size.
# Then it is cache miss as the two type of layers
# Then it is cache miss as the two type of layers
# have different hit length.
# have different hit length.
test_partial_request_hit
(
_test_partial_request_hit
(
manager
,
block_size
,
num_full_blocks
,
"8"
,
"8"
,
all_token_ids
,
[
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
0
),
make_block_hash_with_group_id
(
block_hashes
[
2
],
0
),
make_block_hash_with_group_id
(
block_hashes
[
0
],
1
),
make_block_hash_with_group_id
(
block_hashes
[
0
],
1
),
...
@@ -477,6 +493,214 @@ def test_prefill_hybrid_model():
...
@@ -477,6 +493,214 @@ def test_prefill_hybrid_model():
)
)
def test_prefill_hybrid_model_eagle():
    """Prefix-caching test for a hybrid KV cache config with ``use_eagle=True``.

    Flow:
      1. Request "0" misses the cache entirely and allocates 7 blocks in each
         of the three KV cache groups (6 full blocks + 1 partial block).
      2. Request "1" shares the 6-block prefix and is expected to hit
         4 blocks per group.
      3. After freeing both requests, ``_test_partial_request_hit`` is called
         repeatedly, each time removing selected (block hash, group id)
         entries from the prefix-cache index and checking the resulting hit
         length.
    """
    block_size = 16
    # 31 blocks in the pool; the third argument is `sliding_window_blocks`
    # (the sliding window spans 3 blocks).
    kv_cache_config = make_kv_cache_config_hybrid_model(block_size, 31, 3)
    manager = KVCacheManager(
        kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    hash_fn = sha256

    # Complete 6 blocks (96 tokens)
    num_full_blocks = 6
    common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)]

    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [6] * 7
    all_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    # Only full blocks get a hash; the trailing partial block does not.
    assert len(req0.block_hashes) == len(all_token_ids) // block_size
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req0, len(all_token_ids), num_computed_tokens, computed_blocks
    )
    # Expected block ids, one row per KV cache group; the last id of each row
    # is the partial block.
    block_ids = (
        [1, 2, 3, 4, 5, 6, 7],
        [8, 9, 10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19, 20, 21],
    )
    assert blocks is not None and blocks.get_block_ids() == block_ids

    # Check full block metadata
    parent_block_hash = None
    # zip(*rows) walks the groups in lockstep: iteration i yields the i-th
    # full block id of every group.
    for i, full_block_ids in enumerate(zip(*(row[:-1] for row in block_ids))):
        block_tokens = tuple(all_token_ids[i * block_size : (i + 1) * block_size])
        block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens)
        for group_id, block_id in enumerate(full_block_ids):
            blk_hash = manager.block_pool.blocks[block_id].block_hash
            assert blk_hash is not None
            assert get_block_hash(blk_hash) == block_hash
            assert get_group_id(blk_hash) == group_id
            assert manager.block_pool.blocks[block_id].ref_cnt == 1
        # Each block hash is chained on the hash of the preceding block.
        parent_block_hash = block_hash

    # Check partial block metadata
    for partial_block_id in (row[-1] for row in block_ids):
        assert manager.block_pool.blocks[partial_block_id].block_hash is None
        assert manager.block_pool.blocks[partial_block_id].ref_cnt == 1

    # Cache hit in the common prefix
    # Incomplete 1 block (5 tokens)
    unique_token_ids = [6] * 5
    all_token_ids = common_token_ids + unique_token_ids
    req1 = make_request("1", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(req1.block_hashes) == num_full_blocks
    # NOTE(review): the leading 0 in the second and third rows is presumably
    # the null block's id (see the `null_block` check below) -- confirm.
    assert computed_blocks.get_block_ids() == (
        [1, 2, 3, 4],
        [0, 9, 10, 11],
        [0, 16, 17, 18],
    )
    assert num_computed_tokens == 4 * block_size
    num_new_tokens = len(all_token_ids) - num_computed_tokens
    blocks = manager.allocate_slots(
        req1, num_new_tokens, num_computed_tokens, computed_blocks
    )
    assert blocks is not None and blocks.get_block_ids() == (
        [22, 23, 24],
        [25, 26, 27],
        [28, 29, 30],
    )
    # Every reused (non-null) block is now referenced by both req0 and req1.
    for block_per_group in computed_blocks.blocks:
        for block in block_per_group:
            if block != manager.block_pool.null_block:
                assert block.ref_cnt == 2

    # Keep the hashes for the eviction scenarios below before freeing.
    block_hashes = req1.block_hashes
    manager.free(req0)
    manager.free(req1)

    # Evict the blocks outside sliding window, does not affect the hit length.
    _test_partial_request_hit(
        manager,
        block_size,
        num_full_blocks,
        "2",
        all_token_ids,
        [
            make_block_hash_with_group_id(block_hashes[0], 1),
            make_block_hash_with_group_id(block_hashes[0], 2),
        ],
        4,
    )

    # Evict the first block of full attention, makes total cache miss.
    _test_partial_request_hit(
        manager,
        block_size,
        num_full_blocks,
        "3",
        all_token_ids,
        [make_block_hash_with_group_id(block_hashes[0], 0)],
        0,
    )

    # Evict the last block of all layers, reduces the hit length to 3.
    _test_partial_request_hit(
        manager,
        block_size,
        num_full_blocks,
        "4",
        all_token_ids,
        [
            make_block_hash_with_group_id(block_hashes[-1], 0),
            make_block_hash_with_group_id(block_hashes[-1], 1),
            make_block_hash_with_group_id(block_hashes[-1], 2),
        ],
        3,
    )

    # Evict the last block of full attention, reduces the hit length to 3.
    _test_partial_request_hit(
        manager,
        block_size,
        num_full_blocks,
        "5",
        all_token_ids,
        [make_block_hash_with_group_id(block_hashes[-1], 0)],
        3,
    )

    # Since the last block of full attention is dropped for eagle, evict
    # the second last block of sliding window (group 1), reduces the hit
    # length to 3.
    _test_partial_request_hit(
        manager,
        block_size,
        num_full_blocks,
        "6",
        all_token_ids,
        [make_block_hash_with_group_id(block_hashes[-2], 1)],
        3,
    )

    # Same scenario as the previous case, but evicting the second last block
    # of group 2 instead of group 1; the hit length is again reduced to 3.
    _test_partial_request_hit(
        manager,
        block_size,
        num_full_blocks,
        "7",
        all_token_ids,
        [make_block_hash_with_group_id(block_hashes[-2], 2)],
        3,
    )

    # Evict different set of blocks for full attention and sliding window makes
    # total cache miss.
    # The cache hit length of full attention is 4 * block_size.
    # The cache hit length of sliding window is 3 * block_size.
    # Then it is cache miss as the two type of layers
    # have different hit length.
    _test_partial_request_hit(
        manager,
        block_size,
        num_full_blocks,
        "8",
        all_token_ids,
        [
            make_block_hash_with_group_id(block_hashes[-1], 0),
            make_block_hash_with_group_id(block_hashes[0], 1),
            make_block_hash_with_group_id(block_hashes[0], 2),
        ],
        0,
    )
def _test_partial_request_hit(
    manager: KVCacheManager,
    block_size: int,
    num_full_blocks,
    request_id: str,
    prompt_token_ids: list[int],
    hash_to_evict: list[BlockHashWithGroupId],
    expect_hit_length: int,
):
    """Check the prefix-cache hit length after dropping selected cache entries.

    The entries named in ``hash_to_evict`` are removed from the block pool's
    hash-to-block index, a fresh request is looked up, and the hit is
    compared against ``expect_hit_length`` (measured in blocks). The removed
    entries are put back and the request is freed before returning, so the
    same ``manager`` can be reused for the next scenario.

    Args:
        manager: KV cache manager whose prefix-cache index is manipulated.
        block_size: Number of tokens per block.
        num_full_blocks: Expected number of block hashes of the request.
        request_id: Id given to the temporary request.
        prompt_token_ids: Prompt of the temporary request.
        hash_to_evict: Index keys to remove for the duration of the lookup.
        expect_hit_length: Expected number of hit blocks per group.
    """
    # Shallow snapshot taken up front so evicted entries can be restored.
    snapshot = copy.copy(manager.block_pool.cached_block_hash_to_block._cache)
    request = make_request(request_id, prompt_token_ids, block_size, sha256)

    live_index = manager.block_pool.cached_block_hash_to_block._cache
    for evicted_key in hash_to_evict:
        live_index.pop(evicted_key)

    hit_blocks, hit_tokens = manager.get_computed_blocks(request)

    assert len(request.block_hashes) == num_full_blocks
    assert hit_tokens == expect_hit_length * block_size
    # Every group must report the same number of hit blocks.
    for group_blocks in hit_blocks.blocks:
        assert len(group_blocks) == hit_tokens // block_size

    # Undo the eviction using the snapshot.
    live_index = manager.block_pool.cached_block_hash_to_block._cache
    for evicted_key in hash_to_evict:
        live_index[evicted_key] = snapshot[evicted_key]

    manager.free(request)
def
_make_hybrid_kv_cache_config
(
def
_make_hybrid_kv_cache_config
(
block_size
:
int
,
num_blocks
:
int
,
spec_types
:
list
[
str
]
block_size
:
int
,
num_blocks
:
int
,
spec_types
:
list
[
str
]
)
->
KVCacheConfig
:
)
->
KVCacheConfig
:
...
@@ -655,6 +879,85 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]):
...
@@ -655,6 +879,85 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]):
manager
.
free
(
req1
)
manager
.
free
(
req1
)
# Test cases with eagle enabled: Only test a single simple case for now.
# - 2 groups: 1 full + 1 other
# Each entry is (spec_types, expect_hit_length), matching the argument names
# of the parametrized test that consumes this list.
_EAGLE_HYBRID_MODEL_TEST_CASES = [
    # 2 groups: 1 full + 1 other
    pytest.param(["full", "sliding_window"], 2, id="2g-full+sw"),
]
@pytest.mark.parametrize("spec_types,expect_hit_length", _EAGLE_HYBRID_MODEL_TEST_CASES)
def test_prefill_hybrid_model_combinations_eagle(
    spec_types: list[str], expect_hit_length: int
):
    """
    Test prefix caching with hybrid models (1 full attn + 1 other) with EAGLE.
    More complex hybrid models with EAGLE are not yet supported (see issue #32802).
    """
    block_size = 16
    num_groups = len(spec_types)
    # Ten blocks per group is enough for every group in this test.
    num_blocks = 10 * num_groups

    manager = KVCacheManager(
        _make_hybrid_kv_cache_config(block_size, num_blocks, spec_types),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    hash_fn = sha256

    # Shared prefix: 4 complete blocks (64 tokens); block k holds token id k.
    num_full_blocks = 4
    common_token_ids = []
    for block_index in range(num_full_blocks):
        common_token_ids.extend([block_index] * block_size)

    # First request: shared prefix plus an incomplete block of 7 tokens.
    first_prompt = common_token_ids + [4] * 7
    req0 = make_request("0", first_prompt, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(req0.block_hashes) == num_full_blocks
    # Nothing is cached yet, so the lookup must miss.
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0

    allocated = manager.allocate_slots(
        req0, len(first_prompt), num_computed_tokens, computed_blocks
    )
    assert allocated is not None
    # One list of block ids per KV cache group.
    assert len(allocated.get_block_ids()) == num_groups

    # Second request: same prefix, different 5-token tail -> prefix cache hit.
    second_prompt = common_token_ids + [6] * 5
    req1 = make_request("1", second_prompt, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert num_computed_tokens == expect_hit_length * block_size
    assert len(computed_blocks.blocks) == num_groups
    # Each group must report the same number of hit blocks.
    for group_blocks in computed_blocks.blocks:
        assert len(group_blocks) == expect_hit_length

    # Allocate the remaining (non-cached) tokens of the second request.
    remaining_tokens = len(second_prompt) - num_computed_tokens
    allocated = manager.allocate_slots(
        req1,
        remaining_tokens,
        num_computed_tokens,
        computed_blocks,
    )
    assert allocated is not None
    assert len(allocated.get_block_ids()) == num_groups

    manager.free(req0)
    manager.free(req1)
def
test_prefill_plp
():
def
test_prefill_plp
():
"""Test prefill with APC and some prompt logprobs (plp) requests.
"""Test prefill with APC and some prompt logprobs (plp) requests.
...
...
vllm/v1/core/kv_cache_coordinator.py
View file @
a01ef3fa
...
@@ -479,6 +479,16 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
...
@@ -479,6 +479,16 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
hit_length
=
max_cache_hit_length
hit_length
=
max_cache_hit_length
hit_blocks_by_group
:
list
[
list
[
KVCacheBlock
]
|
None
]
=
[
None
]
*
num_groups
hit_blocks_by_group
:
list
[
list
[
KVCacheBlock
]
|
None
]
=
[
None
]
*
num_groups
# Simple hybrid (1 full attn + 1 other): one iteration suffices.
# Full attn is always first if it exists. This avoids EAGLE drops
# being applied multiple times to non-full-attn groups.
# FIXME (yifan): However, for complex hybrid models with multiple attn
# groups, we still have the EAGLE spiral block dropping problem. See
# discussion in issue https://github.com/vllm-project/vllm/issues/32802.
is_simple_hybrid
=
len
(
self
.
attention_groups
)
==
2
and
isinstance
(
self
.
attention_groups
[
0
][
0
],
FullAttentionSpec
)
while
True
:
while
True
:
curr_hit_length
=
hit_length
curr_hit_length
=
hit_length
...
@@ -495,10 +505,6 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
...
@@ -495,10 +505,6 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
# the last iteration.
# the last iteration.
num_blocks
=
curr_hit_length
//
spec
.
block_size
num_blocks
=
curr_hit_length
//
spec
.
block_size
curr_hit_length
=
num_blocks
*
spec
.
block_size
curr_hit_length
=
num_blocks
*
spec
.
block_size
for
group_id
in
group_ids
:
blocks
=
hit_blocks_by_group
[
group_id
]
assert
blocks
is
not
None
del
blocks
[
num_blocks
:]
else
:
else
:
hit_blocks
=
manager_cls
.
find_longest_cache_hit
(
hit_blocks
=
manager_cls
.
find_longest_cache_hit
(
block_hashes
=
_get_block_hashes
(
spec
),
block_hashes
=
_get_block_hashes
(
spec
),
...
@@ -513,10 +519,20 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
...
@@ -513,10 +519,20 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
for
group_id
,
blocks
in
zip
(
group_ids
,
hit_blocks
):
for
group_id
,
blocks
in
zip
(
group_ids
,
hit_blocks
):
hit_blocks_by_group
[
group_id
]
=
blocks
hit_blocks_by_group
[
group_id
]
=
blocks
if
curr_hit_length
<
hit_length
:
if
curr_hit_length
>=
hit_length
:
hit_length
=
curr_hit_length
else
:
break
break
hit_length
=
curr_hit_length
# Simple hybrid: exit after one iteration
if
is_simple_hybrid
:
break
# Truncate full attention blocks to final hit_length (if present)
spec
,
group_ids
,
_
=
self
.
attention_groups
[
0
]
if
isinstance
(
spec
,
FullAttentionSpec
):
num_blocks
=
hit_length
//
spec
.
block_size
for
group_id
in
group_ids
:
if
(
blks
:
=
hit_blocks_by_group
[
group_id
])
is
not
None
:
del
blks
[
num_blocks
:]
return
tuple
(
return
tuple
(
blocks
if
blocks
is
not
None
else
[]
for
blocks
in
hit_blocks_by_group
blocks
if
blocks
is
not
None
else
[]
for
blocks
in
hit_blocks_by_group
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment