Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
98aa16ff
Unverified
Commit
98aa16ff
authored
Aug 26, 2025
by
Russell Bryant
Committed by
GitHub
Aug 26, 2025
Browse files
[v1] Add cross-attention KV cache support for encoder-decoder models (#23664)
Signed-off-by:
Russell Bryant
<
rbryant@redhat.com
>
parent
227e231b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
153 additions
and
14 deletions
+153
-14
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+19
-0
vllm/v1/core/kv_cache_coordinator.py
vllm/v1/core/kv_cache_coordinator.py
+25
-9
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+4
-2
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+36
-1
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+54
-2
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+15
-0
No files found.
vllm/multimodal/registry.py
View file @
98aa16ff
...
@@ -372,3 +372,22 @@ class MultiModalRegistry:
...
@@ -372,3 +372,22 @@ class MultiModalRegistry:
)
)
return
dummy_data
return
dummy_data
def
get_encdec_max_encoder_len
(
self
,
model_config
:
"ModelConfig"
)
->
int
:
"""
Get the maximum length of the encoder input for encoder-decoder models.
"""
if
not
model_config
.
is_encoder_decoder
:
return
0
max_tokens
=
self
.
\
get_max_tokens_per_item_by_nonzero_modality
(
model_config
)
if
not
max_tokens
:
# TODO - this function assumes encoder-decoder models are
# multimodal. This will need to change when adding support for more
# than whisper.
return
0
assert
len
(
max_tokens
)
==
1
,
"Encoder-decoder models are expected
\
to implement the multimodal interface with at most one modality."
first_modality
=
next
(
iter
(
max_tokens
))
return
max_tokens
[
first_modality
]
vllm/v1/core/kv_cache_coordinator.py
View file @
98aa16ff
...
@@ -6,7 +6,7 @@ from typing import Optional
...
@@ -6,7 +6,7 @@ from typing import Optional
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.core.single_type_kv_cache_manager
import
(
from
vllm.v1.core.single_type_kv_cache_manager
import
(
FullAttentionManager
,
get_manager_for_kv_cache_spec
)
CrossAttentionManager
,
FullAttentionManager
,
get_manager_for_kv_cache_spec
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
KVCacheSpec
)
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -42,9 +42,10 @@ class KVCacheCoordinator(ABC):
...
@@ -42,9 +42,10 @@ class KVCacheCoordinator(ABC):
)
for
i
,
kv_cache_group
in
enumerate
(
)
for
i
,
kv_cache_group
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
))
self
.
kv_cache_config
.
kv_cache_groups
))
def
get_num_blocks_to_allocate
(
def
get_num_blocks_to_allocate
(
self
,
request_id
:
str
,
num_tokens
:
int
,
self
,
request_id
:
str
,
num_tokens
:
int
,
new_computed_blocks
:
tuple
[
new_computed_blocks
:
tuple
[
list
[
KVCacheBlock
],
...])
->
int
:
list
[
KVCacheBlock
],
...],
num_encoder_tokens
:
int
)
->
int
:
"""
"""
Get the number of blocks needed to be allocated for the request.
Get the number of blocks needed to be allocated for the request.
...
@@ -54,14 +55,22 @@ class KVCacheCoordinator(ABC):
...
@@ -54,14 +55,22 @@ class KVCacheCoordinator(ABC):
tokens that are already allocated).
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
prefix caching.
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.
Returns:
Returns:
The number of blocks.
The number of blocks.
"""
"""
num_blocks_to_allocate
=
0
num_blocks_to_allocate
=
0
for
i
,
manager
in
enumerate
(
self
.
single_type_managers
):
for
i
,
manager
in
enumerate
(
self
.
single_type_managers
):
num_blocks_to_allocate
+=
manager
.
get_num_blocks_to_allocate
(
if
isinstance
(
manager
,
CrossAttentionManager
):
request_id
,
num_tokens
,
new_computed_blocks
[
i
])
# For cross-attention, we issue a single static allocation
# of blocks based on the number of encoder input tokens.
num_blocks_to_allocate
+=
manager
.
get_num_blocks_to_allocate
(
request_id
,
num_encoder_tokens
,
[])
else
:
num_blocks_to_allocate
+=
manager
.
get_num_blocks_to_allocate
(
request_id
,
num_tokens
,
new_computed_blocks
[
i
])
return
num_blocks_to_allocate
return
num_blocks_to_allocate
def
save_new_computed_blocks
(
def
save_new_computed_blocks
(
...
@@ -79,8 +88,11 @@ class KVCacheCoordinator(ABC):
...
@@ -79,8 +88,11 @@ class KVCacheCoordinator(ABC):
manager
.
save_new_computed_blocks
(
request_id
,
manager
.
save_new_computed_blocks
(
request_id
,
new_computed_blocks
[
i
])
new_computed_blocks
[
i
])
def
allocate_new_blocks
(
self
,
request_id
:
str
,
def
allocate_new_blocks
(
num_tokens
:
int
)
->
tuple
[
list
[
KVCacheBlock
],
...]:
self
,
request_id
:
str
,
num_tokens
:
int
,
num_encoder_tokens
:
int
=
0
)
->
tuple
[
list
[
KVCacheBlock
],
...]:
"""
"""
Allocate new blocks for the request to give it at least `num_tokens`
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
token slots.
...
@@ -89,12 +101,16 @@ class KVCacheCoordinator(ABC):
...
@@ -89,12 +101,16 @@ class KVCacheCoordinator(ABC):
request_id: The request ID.
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
tokens that are already allocated).
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.
Returns:
Returns:
The new allocated blocks.
The new allocated blocks.
"""
"""
return
tuple
(
return
tuple
(
manager
.
allocate_new_blocks
(
request_id
,
num_tokens
)
manager
.
allocate_new_blocks
(
request_id
,
num_encoder_tokens
if
isinstance
(
manager
,
CrossAttentionManager
)
else
num_tokens
)
for
manager
in
self
.
single_type_managers
)
for
manager
in
self
.
single_type_managers
)
def
cache_blocks
(
self
,
request
:
Request
,
num_computed_tokens
:
int
)
->
None
:
def
cache_blocks
(
self
,
request
:
Request
,
num_computed_tokens
:
int
)
->
None
:
...
...
vllm/v1/core/kv_cache_manager.py
View file @
98aa16ff
...
@@ -187,6 +187,7 @@ class KVCacheManager:
...
@@ -187,6 +187,7 @@ class KVCacheManager:
new_computed_blocks
:
Optional
[
KVCacheBlocks
]
=
None
,
new_computed_blocks
:
Optional
[
KVCacheBlocks
]
=
None
,
num_lookahead_tokens
:
int
=
0
,
num_lookahead_tokens
:
int
=
0
,
delay_cache_blocks
:
bool
=
False
,
delay_cache_blocks
:
bool
=
False
,
num_encoder_tokens
:
int
=
0
,
)
->
Optional
[
KVCacheBlocks
]:
)
->
Optional
[
KVCacheBlocks
]:
"""Add slots for a request with new tokens to append.
"""Add slots for a request with new tokens to append.
...
@@ -253,6 +254,7 @@ class KVCacheManager:
...
@@ -253,6 +254,7 @@ class KVCacheManager:
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
num_tokens
=
num_tokens_need_slot
,
num_tokens
=
num_tokens_need_slot
,
new_computed_blocks
=
new_computed_block_list
,
new_computed_blocks
=
new_computed_block_list
,
num_encoder_tokens
=
num_encoder_tokens
,
)
)
if
num_blocks_to_allocate
>
self
.
block_pool
.
get_num_free_blocks
():
if
num_blocks_to_allocate
>
self
.
block_pool
.
get_num_free_blocks
():
...
@@ -273,7 +275,7 @@ class KVCacheManager:
...
@@ -273,7 +275,7 @@ class KVCacheManager:
new_computed_block_list
)
new_computed_block_list
)
new_blocks
=
self
.
coordinator
.
allocate_new_blocks
(
new_blocks
=
self
.
coordinator
.
allocate_new_blocks
(
request
.
request_id
,
num_tokens_need_slot
)
request
.
request_id
,
num_tokens_need_slot
,
num_encoder_tokens
)
# P/D: delay caching blocks if we have to recv from
# P/D: delay caching blocks if we have to recv from
# remote. Update state for locally cached blocks.
# remote. Update state for locally cached blocks.
...
@@ -292,7 +294,7 @@ class KVCacheManager:
...
@@ -292,7 +294,7 @@ class KVCacheManager:
def
free
(
self
,
request
:
Request
)
->
None
:
def
free
(
self
,
request
:
Request
)
->
None
:
"""Free the blocks allocated for the request.
"""Free the blocks allocated for the request.
We free the blocks in reverse order so that he tail blocks are evicted
We free the blocks in reverse order so that
t
he tail blocks are evicted
first when caching is enabled.
first when caching is enabled.
Args:
Args:
...
...
vllm/v1/core/sched/scheduler.py
View file @
98aa16ff
...
@@ -58,6 +58,7 @@ class Scheduler(SchedulerInterface):
...
@@ -58,6 +58,7 @@ class Scheduler(SchedulerInterface):
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
self
.
structured_output_manager
=
structured_output_manager
self
.
structured_output_manager
=
structured_output_manager
self
.
is_encoder_decoder
=
vllm_config
.
model_config
.
is_encoder_decoder
# include_finished_set controls whether a separate set of finished
# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
# request ids should be included in the EngineCoreOutputs returned
...
@@ -83,6 +84,9 @@ class Scheduler(SchedulerInterface):
...
@@ -83,6 +84,9 @@ class Scheduler(SchedulerInterface):
assert
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
1
,
(
assert
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
1
,
(
"Multiple KV cache groups are not currently supported "
"Multiple KV cache groups are not currently supported "
"with KV connectors"
)
"with KV connectors"
)
assert
not
self
.
is_encoder_decoder
,
(
"Encoder-decoder models are not currently supported "
"with KV connectors"
)
self
.
connector
=
KVConnectorFactory
.
create_connector
(
self
.
connector
=
KVConnectorFactory
.
create_connector
(
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
...
@@ -431,6 +435,22 @@ class Scheduler(SchedulerInterface):
...
@@ -431,6 +435,22 @@ class Scheduler(SchedulerInterface):
==
0
else
==
0
else
self
.
num_lookahead_tokens
)
self
.
num_lookahead_tokens
)
# Determine if we need to allocate cross-attention blocks.
if
self
.
is_encoder_decoder
and
request
.
has_encoder_inputs
:
# TODO(russellb): For Whisper, we know that the input is
# always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed.
assert
(
"whisper"
in
self
.
vllm_config
.
model_config
.
model
.
lower
()),
(
"Whisper is the only supported "
"encoder-decoder model."
)
num_encoder_tokens
=
MULTIMODAL_REGISTRY
.
\
get_encdec_max_encoder_len
(
self
.
vllm_config
.
model_config
)
else
:
num_encoder_tokens
=
0
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
request
,
num_new_tokens
+
num_external_computed_tokens
,
num_new_tokens
+
num_external_computed_tokens
,
...
@@ -438,6 +458,7 @@ class Scheduler(SchedulerInterface):
...
@@ -438,6 +458,7 @@ class Scheduler(SchedulerInterface):
new_computed_blocks
,
new_computed_blocks
,
num_lookahead_tokens
=
effective_lookahead_tokens
,
num_lookahead_tokens
=
effective_lookahead_tokens
,
delay_cache_blocks
=
load_kv_async
,
delay_cache_blocks
=
load_kv_async
,
num_encoder_tokens
=
num_encoder_tokens
,
)
)
if
new_blocks
is
None
:
if
new_blocks
is
None
:
...
@@ -703,7 +724,21 @@ class Scheduler(SchedulerInterface):
...
@@ -703,7 +724,21 @@ class Scheduler(SchedulerInterface):
# The encoder input is not needed in this step.
# The encoder input is not needed in this step.
break
break
if
start_pos
+
num_encoder_tokens
<=
num_computed_tokens
:
if
self
.
is_encoder_decoder
and
num_computed_tokens
>
0
:
assert
start_pos
==
0
,
(
"Encoder input should be processed at the beginning of "
"the sequence when encoder-decoder models are used."
)
# Encoder input has already been computed
# The calculation here is a bit different. We don't turn encoder
# output into tokens that get processed by the decoder and
# reflected in num_computed_tokens. Instead, start_pos reflects
# the position where we need to ensure we calculate encoder
# inputs. This should always be 0 to ensure we calculate encoder
# inputs before running the decoder. Once we've calculated some
# decoder tokens (num_computed_tokens > 0), then we know we
# already calculated encoder inputs and can skip here.
continue
elif
start_pos
+
num_encoder_tokens
<=
num_computed_tokens
:
# The encoder input is already computed and stored
# The encoder input is already computed and stored
# in the decoder's KV cache.
# in the decoder's KV cache.
continue
continue
...
...
vllm/v1/core/single_type_kv_cache_manager.py
View file @
98aa16ff
...
@@ -8,8 +8,9 @@ from vllm.utils import cdiv
...
@@ -8,8 +8,9 @@ from vllm.utils import cdiv
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.kv_cache_interface
import
(
ChunkedLocalAttentionSpec
,
from
vllm.v1.kv_cache_interface
import
(
ChunkedLocalAttentionSpec
,
FullAttentionSpec
,
KVCacheSpec
,
CrossAttentionSpec
,
FullAttentionSpec
,
MambaSpec
,
SlidingWindowSpec
)
KVCacheSpec
,
MambaSpec
,
SlidingWindowSpec
)
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -552,11 +553,62 @@ class MambaManager(SingleTypeKVCacheManager):
...
@@ -552,11 +553,62 @@ class MambaManager(SingleTypeKVCacheManager):
return
new_blocks
return
new_blocks
class
CrossAttentionManager
(
SingleTypeKVCacheManager
):
"""Manager for cross-attention KV cache in encoder-decoder models."""
def
save_new_computed_blocks
(
self
,
request_id
:
str
,
new_computed_blocks
:
list
[
KVCacheBlock
])
->
None
:
# We do not cache blocks for cross-attention to be shared between
# requests, so `new_computed_blocks` should always be empty.
assert
len
(
new_computed_blocks
)
==
0
def
cache_blocks
(
self
,
request
:
Request
,
num_tokens
:
int
)
->
None
:
# We do not cache blocks for cross-attention to be shared between
# requests, so this method is not relevant.
raise
ValueError
(
"Should not be called as prefix caching is disabled."
)
def
get_num_common_prefix_blocks
(
self
,
request_id
:
str
,
num_running_requests
:
int
)
->
int
:
# Cross-attention blocks contain request-specific encoder states
# and are not shared between different requests
return
0
@
classmethod
def
find_longest_cache_hit
(
cls
,
block_hashes
:
list
[
BlockHash
],
max_length
:
int
,
kv_cache_group_ids
:
list
[
int
],
block_pool
:
BlockPool
,
kv_cache_spec
:
KVCacheSpec
,
use_eagle
:
bool
,
)
->
tuple
[
list
[
KVCacheBlock
],
...]:
assert
isinstance
(
kv_cache_spec
,
CrossAttentionSpec
),
(
"CrossAttentionManager can only be used for cross-attention groups"
)
# Cross-attention does not benefit from prefix caching since:
# 1. Encoder states are unique per request (different audio/image
# inputs)
# 2. Encoder states are computed once per request, not incrementally
# 3. No reusable prefix exists between different multimodal inputs
# Return empty blocks to indicate no cache hits
raise
NotImplementedError
(
"CrossAttentionManager does not support caching"
)
def
remove_skipped_blocks
(
self
,
request_id
:
str
,
num_computed_tokens
:
int
)
->
None
:
# Cross-attention blocks represent encoder states which are needed
# for the entire decoding process, so no blocks should be skipped
pass
spec_manager_map
:
dict
[
type
[
KVCacheSpec
],
type
[
SingleTypeKVCacheManager
]]
=
{
spec_manager_map
:
dict
[
type
[
KVCacheSpec
],
type
[
SingleTypeKVCacheManager
]]
=
{
FullAttentionSpec
:
FullAttentionManager
,
FullAttentionSpec
:
FullAttentionManager
,
SlidingWindowSpec
:
SlidingWindowManager
,
SlidingWindowSpec
:
SlidingWindowManager
,
ChunkedLocalAttentionSpec
:
ChunkedLocalAttentionManager
,
ChunkedLocalAttentionSpec
:
ChunkedLocalAttentionManager
,
MambaSpec
:
MambaManager
,
MambaSpec
:
MambaManager
,
CrossAttentionSpec
:
CrossAttentionManager
,
}
}
...
...
vllm/v1/kv_cache_interface.py
View file @
98aa16ff
...
@@ -11,6 +11,7 @@ from typing_extensions import Self
...
@@ -11,6 +11,7 @@ from typing_extensions import Self
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.utils
import
cdiv
,
get_dtype_size
from
vllm.utils
import
cdiv
,
get_dtype_size
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -211,6 +212,20 @@ class EncoderOnlyAttentionSpec(AttentionSpec):
...
@@ -211,6 +212,20 @@ class EncoderOnlyAttentionSpec(AttentionSpec):
return
0
return
0
@
dataclass
(
frozen
=
True
)
class
CrossAttentionSpec
(
AttentionSpec
):
"""
KV cache spec for cross-attention layers in encoder-decoder models.
"""
def
max_memory_usage_bytes
(
self
,
vllm_config
:
VllmConfig
)
->
int
:
# For cross-attention, we need to cache encoder states
# Get encoder length (e.g., 1500 for Whisper).
max_encoder_len
=
MULTIMODAL_REGISTRY
.
\
get_encdec_max_encoder_len
(
vllm_config
.
model_config
)
return
cdiv
(
max_encoder_len
,
self
.
block_size
)
*
self
.
page_size_bytes
@
dataclass
@
dataclass
class
KVCacheTensor
:
class
KVCacheTensor
:
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment