Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
98aa16ff
Unverified
Commit
98aa16ff
authored
Aug 26, 2025
by
Russell Bryant
Committed by
GitHub
Aug 26, 2025
Browse files
[v1] Add cross-attention KV cache support for encoder-decoder models (#23664)
Signed-off-by:
Russell Bryant
<
rbryant@redhat.com
>
parent
227e231b
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
153 additions
and
14 deletions
+153
-14
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+19
-0
vllm/v1/core/kv_cache_coordinator.py
vllm/v1/core/kv_cache_coordinator.py
+25
-9
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+4
-2
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+36
-1
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+54
-2
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+15
-0
No files found.
vllm/multimodal/registry.py
View file @
98aa16ff
...
...
@@ -372,3 +372,22 @@ class MultiModalRegistry:
)
return
dummy_data
def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
    """
    Get the maximum length of the encoder input for encoder-decoder models.

    Args:
        model_config: The model configuration; only its
            `is_encoder_decoder` flag is read here before it is forwarded
            to `get_max_tokens_per_item_by_nonzero_modality`.

    Returns:
        0 if the model is not an encoder-decoder model or if no modality
        reports a token count; otherwise the token count of the single
        reported modality.

    Raises:
        AssertionError: If more than one modality is reported.
    """
    if not model_config.is_encoder_decoder:
        return 0

    max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(
        model_config)
    if not max_tokens:
        # TODO - this function assumes encoder-decoder models are
        # multimodal. This will need to change when adding support for more
        # than whisper.
        return 0

    # Implicit string concatenation keeps the message free of the source
    # indentation that a backslash continuation inside the literal would
    # embed into it.
    assert len(max_tokens) == 1, (
        "Encoder-decoder models are expected to implement the "
        "multimodal interface with at most one modality.")

    first_modality = next(iter(max_tokens))
    return max_tokens[first_modality]
vllm/v1/core/kv_cache_coordinator.py
View file @
98aa16ff
...
...
@@ -6,7 +6,7 @@ from typing import Optional
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.core.single_type_kv_cache_manager
import
(
FullAttentionManager
,
get_manager_for_kv_cache_spec
)
CrossAttentionManager
,
FullAttentionManager
,
get_manager_for_kv_cache_spec
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
from
vllm.v1.request
import
Request
...
...
@@ -42,9 +42,10 @@ class KVCacheCoordinator(ABC):
)
for
i
,
kv_cache_group
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
))
def
get_num_blocks_to_allocate
(
self
,
request_id
:
str
,
num_tokens
:
int
,
new_computed_blocks
:
tuple
[
list
[
KVCacheBlock
],
...])
->
int
:
def
get_num_blocks_to_allocate
(
self
,
request_id
:
str
,
num_tokens
:
int
,
new_computed_blocks
:
tuple
[
list
[
KVCacheBlock
],
...],
num_encoder_tokens
:
int
)
->
int
:
"""
Get the number of blocks needed to be allocated for the request.
...
...
@@ -54,12 +55,20 @@ class KVCacheCoordinator(ABC):
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.
Returns:
The number of blocks.
"""
num_blocks_to_allocate
=
0
for
i
,
manager
in
enumerate
(
self
.
single_type_managers
):
if
isinstance
(
manager
,
CrossAttentionManager
):
# For cross-attention, we issue a single static allocation
# of blocks based on the number of encoder input tokens.
num_blocks_to_allocate
+=
manager
.
get_num_blocks_to_allocate
(
request_id
,
num_encoder_tokens
,
[])
else
:
num_blocks_to_allocate
+=
manager
.
get_num_blocks_to_allocate
(
request_id
,
num_tokens
,
new_computed_blocks
[
i
])
return
num_blocks_to_allocate
...
...
@@ -79,8 +88,11 @@ class KVCacheCoordinator(ABC):
manager
.
save_new_computed_blocks
(
request_id
,
new_computed_blocks
[
i
])
def
allocate_new_blocks
(
self
,
request_id
:
str
,
num_tokens
:
int
)
->
tuple
[
list
[
KVCacheBlock
],
...]:
def
allocate_new_blocks
(
self
,
request_id
:
str
,
num_tokens
:
int
,
num_encoder_tokens
:
int
=
0
)
->
tuple
[
list
[
KVCacheBlock
],
...]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
...
...
@@ -89,12 +101,16 @@ class KVCacheCoordinator(ABC):
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.
Returns:
The new allocated blocks.
"""
return
tuple
(
manager
.
allocate_new_blocks
(
request_id
,
num_tokens
)
manager
.
allocate_new_blocks
(
request_id
,
num_encoder_tokens
if
isinstance
(
manager
,
CrossAttentionManager
)
else
num_tokens
)
for
manager
in
self
.
single_type_managers
)
def
cache_blocks
(
self
,
request
:
Request
,
num_computed_tokens
:
int
)
->
None
:
...
...
vllm/v1/core/kv_cache_manager.py
View file @
98aa16ff
...
...
@@ -187,6 +187,7 @@ class KVCacheManager:
new_computed_blocks
:
Optional
[
KVCacheBlocks
]
=
None
,
num_lookahead_tokens
:
int
=
0
,
delay_cache_blocks
:
bool
=
False
,
num_encoder_tokens
:
int
=
0
,
)
->
Optional
[
KVCacheBlocks
]:
"""Add slots for a request with new tokens to append.
...
...
@@ -253,6 +254,7 @@ class KVCacheManager:
request_id
=
request
.
request_id
,
num_tokens
=
num_tokens_need_slot
,
new_computed_blocks
=
new_computed_block_list
,
num_encoder_tokens
=
num_encoder_tokens
,
)
if
num_blocks_to_allocate
>
self
.
block_pool
.
get_num_free_blocks
():
...
...
@@ -273,7 +275,7 @@ class KVCacheManager:
new_computed_block_list
)
new_blocks
=
self
.
coordinator
.
allocate_new_blocks
(
request
.
request_id
,
num_tokens_need_slot
)
request
.
request_id
,
num_tokens_need_slot
,
num_encoder_tokens
)
# P/D: delay caching blocks if we have to recv from
# remote. Update state for locally cached blocks.
...
...
@@ -292,7 +294,7 @@ class KVCacheManager:
def
free
(
self
,
request
:
Request
)
->
None
:
"""Free the blocks allocated for the request.
We free the blocks in reverse order so that he tail blocks are evicted
We free the blocks in reverse order so that
t
he tail blocks are evicted
first when caching is enabled.
Args:
...
...
vllm/v1/core/sched/scheduler.py
View file @
98aa16ff
...
...
@@ -58,6 +58,7 @@ class Scheduler(SchedulerInterface):
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
log_stats
=
log_stats
self
.
structured_output_manager
=
structured_output_manager
self
.
is_encoder_decoder
=
vllm_config
.
model_config
.
is_encoder_decoder
# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
...
...
@@ -83,6 +84,9 @@ class Scheduler(SchedulerInterface):
assert
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
1
,
(
"Multiple KV cache groups are not currently supported "
"with KV connectors"
)
assert
not
self
.
is_encoder_decoder
,
(
"Encoder-decoder models are not currently supported "
"with KV connectors"
)
self
.
connector
=
KVConnectorFactory
.
create_connector
(
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
...
...
@@ -431,6 +435,22 @@ class Scheduler(SchedulerInterface):
==
0
else
self
.
num_lookahead_tokens
)
# Determine if we need to allocate cross-attention blocks.
if
self
.
is_encoder_decoder
and
request
.
has_encoder_inputs
:
# TODO(russellb): For Whisper, we know that the input is
# always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed.
assert
(
"whisper"
in
self
.
vllm_config
.
model_config
.
model
.
lower
()),
(
"Whisper is the only supported "
"encoder-decoder model."
)
num_encoder_tokens
=
MULTIMODAL_REGISTRY
.
\
get_encdec_max_encoder_len
(
self
.
vllm_config
.
model_config
)
else
:
num_encoder_tokens
=
0
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
num_new_tokens
+
num_external_computed_tokens
,
...
...
@@ -438,6 +458,7 @@ class Scheduler(SchedulerInterface):
new_computed_blocks
,
num_lookahead_tokens
=
effective_lookahead_tokens
,
delay_cache_blocks
=
load_kv_async
,
num_encoder_tokens
=
num_encoder_tokens
,
)
if
new_blocks
is
None
:
...
...
@@ -703,7 +724,21 @@ class Scheduler(SchedulerInterface):
# The encoder input is not needed in this step.
break
if
start_pos
+
num_encoder_tokens
<=
num_computed_tokens
:
if
self
.
is_encoder_decoder
and
num_computed_tokens
>
0
:
assert
start_pos
==
0
,
(
"Encoder input should be processed at the beginning of "
"the sequence when encoder-decoder models are used."
)
# Encoder input has already been computed
# The calculation here is a bit different. We don't turn encoder
# output into tokens that get processed by the decoder and
# reflected in num_computed_tokens. Instead, start_pos reflects
# the position where we need to ensure we calculate encoder
# inputs. This should always be 0 to ensure we calculate encoder
# inputs before running the decoder. Once we've calculated some
# decoder tokens (num_computed_tokens > 0), then we know we
# already calculated encoder inputs and can skip here.
continue
elif
start_pos
+
num_encoder_tokens
<=
num_computed_tokens
:
# The encoder input is already computed and stored
# in the decoder's KV cache.
continue
...
...
vllm/v1/core/single_type_kv_cache_manager.py
View file @
98aa16ff
...
...
@@ -8,8 +8,9 @@ from vllm.utils import cdiv
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.kv_cache_interface
import
(
ChunkedLocalAttentionSpec
,
FullAttentionSpec
,
KVCacheSpec
,
MambaSpec
,
SlidingWindowSpec
)
CrossAttentionSpec
,
FullAttentionSpec
,
KVCacheSpec
,
MambaSpec
,
SlidingWindowSpec
)
from
vllm.v1.request
import
Request
...
...
@@ -552,11 +553,62 @@ class MambaManager(SingleTypeKVCacheManager):
return
new_blocks
class CrossAttentionManager(SingleTypeKVCacheManager):
    """Manager for cross-attention KV cache in encoder-decoder models.

    Cross-attention blocks are not cached for sharing between requests, so
    the prefix-caching hooks below either do nothing or refuse to run.
    """

    def save_new_computed_blocks(
            self, request_id: str,
            new_computed_blocks: list[KVCacheBlock]) -> None:
        """Verify that no prefix-cache hits were passed to this manager."""
        # Nothing is shared between requests for cross-attention, hence the
        # list of newly computed blocks must be empty.
        assert len(new_computed_blocks) == 0

    def cache_blocks(self, request: Request, num_tokens: int) -> None:
        """Always raise; cross-attention blocks are never cached for reuse.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError(
            "Should not be called as prefix caching is disabled.")

    def get_num_common_prefix_blocks(self, request_id: str,
                                     num_running_requests: int) -> int:
        """Return 0: encoder states are request-specific and never shared."""
        return 0

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
    ) -> tuple[list[KVCacheBlock], ...]:
        """Reject cache-hit lookups; this manager does not support caching.

        Raises:
            AssertionError: If `kv_cache_spec` is not a `CrossAttentionSpec`.
            NotImplementedError: Otherwise, unconditionally.
        """
        assert isinstance(kv_cache_spec, CrossAttentionSpec), (
            "CrossAttentionManager can only be used for cross-attention groups"
        )
        # Prefix caching brings no benefit for cross-attention:
        # 1. Encoder states are unique per request (different audio/image
        #    inputs).
        # 2. Encoder states are computed once per request, not incrementally.
        # 3. No reusable prefix exists between different multimodal inputs.
        raise NotImplementedError(
            "CrossAttentionManager does not support caching")

    def remove_skipped_blocks(self, request_id: str,
                              num_computed_tokens: int) -> None:
        """Do nothing: encoder states are needed for the entire decoding
        process, so no cross-attention block is ever skipped."""
# Maps each KV cache spec type to the single-type manager class that handles
# it. NOTE(review): presumably consulted by get_manager_for_kv_cache_spec,
# which is not visible here — confirm.
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
    FullAttentionSpec: FullAttentionManager,
    SlidingWindowSpec: SlidingWindowManager,
    ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
    MambaSpec: MambaManager,
    CrossAttentionSpec: CrossAttentionManager,
}
...
...
vllm/v1/kv_cache_interface.py
View file @
98aa16ff
...
...
@@ -11,6 +11,7 @@ from typing_extensions import Self
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.utils
import
cdiv
,
get_dtype_size
logger
=
init_logger
(
__name__
)
...
...
@@ -211,6 +212,20 @@ class EncoderOnlyAttentionSpec(AttentionSpec):
return
0
@dataclass(frozen=True)
class CrossAttentionSpec(AttentionSpec):
    """
    KV cache spec for cross-attention layers in encoder-decoder models.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """Return the memory needed to cache the longest encoder input.

        The result is the number of blocks required to cover the maximum
        encoder length (rounded up) multiplied by the page size in bytes.
        """
        # The encoder length comes from the multimodal registry
        # (e.g., 1500 for Whisper).
        encoder_len = MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(
            vllm_config.model_config)
        num_blocks = cdiv(encoder_len, self.block_size)
        return num_blocks * self.page_size_bytes
@
dataclass
class
KVCacheTensor
:
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment