Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fc9f821d
Unverified
Commit
fc9f821d
authored
Nov 21, 2025
by
who who who
Committed by
GitHub
Nov 21, 2025
Browse files
fix cross attention (#28346)
Signed-off-by:
fsx950223
<
fsx950223@outlook.com
>
parent
94528630
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
8 deletions
+9
-8
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+9
-8
No files found.
vllm/v1/attention/backends/triton_attn.py
View file @
fc9f821d
...
@@ -244,14 +244,11 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -244,14 +244,11 @@ class TritonAttentionImpl(AttentionImpl):
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
attn_type
!=
AttentionType
.
DECODER
:
if
attn_type
not
in
[
AttentionType
.
DECODER
,
AttentionType
.
ENCODER_
DECODER
]
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Encoder self-attention and "
"Encoder self-attention is not implemented for TritonAttentionImpl"
"encoder/decoder cross-attention "
"are not implemented for "
"TritonAttentionImpl"
)
)
self
.
attn_type
=
attn_type
self
.
fp8_dtype
=
current_platform
.
fp8_dtype
()
self
.
fp8_dtype
=
current_platform
.
fp8_dtype
()
self
.
sinks
=
sinks
self
.
sinks
=
sinks
...
@@ -312,7 +309,11 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -312,7 +309,11 @@ class TritonAttentionImpl(AttentionImpl):
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
key_cache
,
value_cache
=
kv_cache
.
unbind
(
1
)
key_cache
,
value_cache
=
kv_cache
.
unbind
(
1
)
if
self
.
kv_sharing_target_layer_name
is
None
:
if
(
self
.
kv_sharing_target_layer_name
is
None
and
key
is
not
None
and
value
is
not
None
):
# Reshape the input keys and values and store them in the cache.
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# Skip this if sharing KV cache with an earlier attention layer.
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
...
@@ -346,7 +347,7 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -346,7 +347,7 @@ class TritonAttentionImpl(AttentionImpl):
max_seqlen_k
=
attn_metadata
.
max_seq_len
max_seqlen_k
=
attn_metadata
.
max_seq_len
block_table
=
attn_metadata
.
block_table
block_table
=
attn_metadata
.
block_table
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key
_cache
.
shape
[
2
])
unified_attention
(
unified_attention
(
q
=
query
[:
num_actual_tokens
],
q
=
query
[:
num_actual_tokens
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment