Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
281710ef
Unverified
Commit 281710ef, authored Aug 22, 2025 by Russell Bryant; committed by GitHub on Aug 22, 2025
Browse files
[Attention] Allow V1 flash_attn to support cross-attention (#23297)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
parent
808d2e9a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing 1 changed file with 7 additions and 10 deletions
+7
-10
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+7
-10
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
281710ef
...
...
@@ -405,13 +405,6 @@ class FlashAttentionImpl(AttentionImpl):
FlashAttentionBackend
.
validate_head_size
(
head_size
)
if
attn_type
not
in
[
AttentionType
.
DECODER
,
AttentionType
.
ENCODER_ONLY
]:
raise
NotImplementedError
(
"Encoder/decoder cross-attention "
"is not implemented for "
"FlashAttentionImpl"
)
self
.
attn_type
=
attn_type
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
\
...
...
@@ -477,7 +470,7 @@ class FlashAttentionImpl(AttentionImpl):
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
# Handle encoder attention differently - no KV cache needed
if
attn_type
in
(
AttentionType
.
ENCODER_ONLY
,
):
if
attn_type
in
(
AttentionType
.
ENCODER_ONLY
,
AttentionType
.
ENCODER
):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
self
.
_forward_encoder_attention
(
query
[:
num_actual_tokens
],
...
...
@@ -489,7 +482,11 @@ class FlashAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
if
self
.
kv_sharing_target_layer_name
is
None
:
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if
(
self
.
kv_sharing_target_layer_name
is
None
and
key
is
not
None
and
value
is
not
None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
...
...
@@ -528,7 +525,7 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
attn_metadata
.
block_table
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key
.
shape
[
1
]
)
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
self
.
num_kv_heads
)
flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
...
...
Write
Preview
Markdown is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment