Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
269bf46d
Unverified
Commit
269bf46d
authored
Mar 19, 2026
by
tianshu-Michael-yu
Committed by
GitHub
Mar 20, 2026
Browse files
fix: disambiguate multimodal prefix cache keys (#36708)
Signed-off-by:
tianshu.yu
<
tianshuyu.formal@gmail.com
>
parent
e5a77a50
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
16 deletions
+29
-16
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+10
-8
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+12
-4
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+7
-4
No files found.
tests/v1/core/test_kv_cache_utils.py
View file @
269bf46d
...
@@ -447,12 +447,12 @@ def test_generate_block_hash_extra_keys():
...
@@ -447,12 +447,12 @@ def test_generate_block_hash_extra_keys():
# Test with a block that contains the first mm input
# Test with a block that contains the first mm input
extra_keys
,
next_mm_idx
=
generate_block_hash_extra_keys
(
request
,
0
,
5
,
0
)
extra_keys
,
next_mm_idx
=
generate_block_hash_extra_keys
(
request
,
0
,
5
,
0
)
assert
extra_keys
==
(
"hash1"
,)
assert
extra_keys
==
(
(
"hash1"
,
0
),
)
assert
next_mm_idx
==
1
assert
next_mm_idx
==
1
# Test with partial overlap
# Test with partial overlap
extra_keys
,
next_mm_idx
=
generate_block_hash_extra_keys
(
request
,
3
,
8
,
0
)
extra_keys
,
next_mm_idx
=
generate_block_hash_extra_keys
(
request
,
3
,
8
,
0
)
assert
extra_keys
==
(
"hash1"
,)
assert
extra_keys
==
(
(
"hash1"
,
-
3
),
)
assert
next_mm_idx
==
1
assert
next_mm_idx
==
1
# Test with no overlap
# Test with no overlap
...
@@ -462,7 +462,7 @@ def test_generate_block_hash_extra_keys():
...
@@ -462,7 +462,7 @@ def test_generate_block_hash_extra_keys():
# Test with multiple extra keys
# Test with multiple extra keys
extra_keys
,
next_mm_idx
=
generate_block_hash_extra_keys
(
request
,
0
,
15
,
0
)
extra_keys
,
next_mm_idx
=
generate_block_hash_extra_keys
(
request
,
0
,
15
,
0
)
assert
extra_keys
==
(
"hash1"
,
"hash2"
)
assert
extra_keys
==
(
(
"hash1"
,
0
),
(
"hash2"
,
10
)
)
assert
next_mm_idx
==
2
assert
next_mm_idx
==
2
...
@@ -513,7 +513,7 @@ def test_generate_block_hash_extra_keys_cache_salt():
...
@@ -513,7 +513,7 @@ def test_generate_block_hash_extra_keys_cache_salt():
# Test with an mm extra key combined with the cache salt
# Test with an mm extra key combined with the cache salt
extra_keys
,
next_mm_idx
=
generate_block_hash_extra_keys
(
request_mm
,
0
,
5
,
0
)
extra_keys
,
next_mm_idx
=
generate_block_hash_extra_keys
(
request_mm
,
0
,
5
,
0
)
assert
extra_keys
==
(
"hash1"
,
"salt"
)
assert
extra_keys
==
(
(
"hash1"
,
0
),
"salt"
)
assert
next_mm_idx
==
1
assert
next_mm_idx
==
1
...
@@ -637,8 +637,10 @@ def test_request_block_hasher(hash_fn):
...
@@ -637,8 +637,10 @@ def test_request_block_hasher(hash_fn):
block_hashes
=
request
.
block_hashes
block_hashes
=
request
.
block_hashes
assert
len
(
block_hashes
)
==
2
assert
len
(
block_hashes
)
==
2
assert
block_hashes
[
0
]
==
hash_fn
((
kv_cache_utils
.
NONE_HASH
,
(
0
,
1
,
2
),
(
"hash1"
,)))
assert
block_hashes
[
0
]
==
hash_fn
(
assert
block_hashes
[
1
]
==
hash_fn
((
block_hashes
[
0
],
(
3
,
4
,
5
),
(
"hash2"
,)))
(
kv_cache_utils
.
NONE_HASH
,
(
0
,
1
,
2
),
((
"hash1"
,
0
),))
)
assert
block_hashes
[
1
]
==
hash_fn
((
block_hashes
[
0
],
(
3
,
4
,
5
),
((
"hash2"
,
0
),)))
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
...
@@ -1973,7 +1975,7 @@ def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes
...
@@ -1973,7 +1975,7 @@ def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes
(
(
kv_cache_utils
.
NONE_HASH
,
kv_cache_utils
.
NONE_HASH
,
tuple
(
prompt_token_ids
[:
block_size
]),
tuple
(
prompt_token_ids
[:
block_size
]),
(
"hash1"
,
block1_embeds_hash
),
(
(
"hash1"
,
0
),
block1_embeds_hash
),
)
)
)
)
assert
block_hashes
[
0
]
==
expected_hash1
assert
block_hashes
[
0
]
==
expected_hash1
...
@@ -1985,7 +1987,7 @@ def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes
...
@@ -1985,7 +1987,7 @@ def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes
(
(
block_hashes
[
0
],
block_hashes
[
0
],
tuple
(
prompt_token_ids
[
block_size
:
num_tokens
]),
tuple
(
prompt_token_ids
[
block_size
:
num_tokens
]),
(
"hash2"
,
block2_embeds_hash
),
(
(
"hash2"
,
0
),
block2_embeds_hash
),
)
)
)
)
assert
block_hashes
[
1
]
==
expected_hash2
assert
block_hashes
[
1
]
==
expected_hash2
...
...
tests/v1/core/test_prefix_caching.py
View file @
269bf46d
...
@@ -1570,20 +1570,24 @@ def test_mm_prefix_caching():
...
@@ -1570,20 +1570,24 @@ def test_mm_prefix_caching():
block_hashes
=
req0
.
block_hashes
block_hashes
=
req0
.
block_hashes
assert
len
(
block_hashes
)
==
3
assert
len
(
block_hashes
)
==
3
assert
block_hashes
[
0
]
==
sha256
(
assert
block_hashes
[
0
]
==
sha256
(
(
kv_cache_utils
.
NONE_HASH
,
tuple
(
all_token_ids
[:
block_size
]),
(
"aaa"
,))
(
kv_cache_utils
.
NONE_HASH
,
tuple
(
all_token_ids
[:
block_size
]),
((
"aaa"
,
11
),),
)
)
)
assert
block_hashes
[
1
]
==
sha256
(
assert
block_hashes
[
1
]
==
sha256
(
(
(
block_hashes
[
0
],
block_hashes
[
0
],
tuple
(
all_token_ids
[
block_size
:
block_size
*
2
]),
tuple
(
all_token_ids
[
block_size
:
block_size
*
2
]),
(
"aaa"
,
"bbb"
),
(
(
"aaa"
,
-
5
),
(
"bbb"
,
14
)
),
)
)
)
)
assert
block_hashes
[
2
]
==
sha256
(
assert
block_hashes
[
2
]
==
sha256
(
(
(
block_hashes
[
1
],
block_hashes
[
1
],
tuple
(
all_token_ids
[
block_size
*
2
:
block_size
*
3
]),
tuple
(
all_token_ids
[
block_size
*
2
:
block_size
*
3
]),
(
"bbb"
,),
(
(
"bbb"
,
-
2
),
),
)
)
)
)
...
@@ -1603,7 +1607,11 @@ def test_mm_prefix_caching():
...
@@ -1603,7 +1607,11 @@ def test_mm_prefix_caching():
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
])
==
0
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
])
==
0
assert
len
(
block_hashes
)
==
4
assert
len
(
block_hashes
)
==
4
assert
block_hashes
[
3
]
==
sha256
(
assert
block_hashes
[
3
]
==
sha256
(
(
block_hashes
[
2
],
tuple
(
all_token_ids
[
3
*
block_size
:]
+
[
8
]
*
5
),
(
"ccc"
,))
(
block_hashes
[
2
],
tuple
(
all_token_ids
[
3
*
block_size
:]
+
[
8
]
*
5
),
((
"ccc"
,
0
),),
)
)
)
# Cache hit.
# Cache hit.
...
...
vllm/v1/core/kv_cache_utils.py
View file @
269bf46d
...
@@ -413,7 +413,7 @@ def _gen_mm_extra_hash_keys(
...
@@ -413,7 +413,7 @@ def _gen_mm_extra_hash_keys(
# We do not need to check all mm inputs if the start token index is out of
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
# range. This usually happens in the late prefill phase and decoding phase.
last_pos
=
mm_features
[
-
1
].
mm_position
last_pos
=
mm_features
[
-
1
].
mm_position
if
last_pos
.
offset
+
last_pos
.
length
<
start_token_idx
:
if
last_pos
.
offset
+
last_pos
.
length
<
=
start_token_idx
:
return
extra_keys
,
start_mm_idx
return
extra_keys
,
start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
# Support start_mm_idx == -1 to indicate the last mm input.
...
@@ -428,13 +428,16 @@ def _gen_mm_extra_hash_keys(
...
@@ -428,13 +428,16 @@ def _gen_mm_extra_hash_keys(
offset
=
mm_feature
.
mm_position
.
offset
offset
=
mm_feature
.
mm_position
.
offset
length
=
mm_feature
.
mm_position
.
length
length
=
mm_feature
.
mm_position
.
length
if
end_token_idx
>
offset
:
if
end_token_idx
>
offset
:
if
start_token_idx
>
offset
+
length
:
if
start_token_idx
>
=
offset
+
length
:
# This block has passed the current mm input.
# This block has passed the current mm input.
curr_mm_idx
+=
1
curr_mm_idx
+=
1
continue
continue
# The block contains the current mm input.
# The block contains the current mm input. Include its offset
extra_keys
.
append
(
mm_feature
.
identifier
)
# relative to the start of the block so prefix-cache keys stay
# distinct when the same MM item appears at different positions
# within otherwise-identical placeholder blocks.
extra_keys
.
append
((
mm_feature
.
identifier
,
offset
-
start_token_idx
))
if
end_token_idx
>=
offset
+
length
:
if
end_token_idx
>=
offset
+
length
:
# If this block contains the end of the current mm input,
# If this block contains the end of the current mm input,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment