Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
db56a599
Unverified
Commit
db56a599
authored
Nov 14, 2025
by
Lucas Wilkinson
Committed by
GitHub
Nov 14, 2025
Browse files
[BugFix] Fix FA3 IMA with FULL_AND_PIECEWISE and cascade attention (default) (#28702)
parent
9324e102
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
2 deletions
+5
-2
tests/kernels/attention/test_cascade_flash_attn.py
tests/kernels/attention/test_cascade_flash_attn.py
+1
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+4
-2
No files found.
tests/kernels/attention/test_cascade_flash_attn.py
View file @
db56a599
...
@@ -170,6 +170,7 @@ def test_cascade(
...
@@ -170,6 +170,7 @@ def test_cascade(
logits_soft_cap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
logits_soft_cap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
block_table
=
block_tables
,
block_table
=
block_tables
,
common_prefix_len
=
common_prefix_len
,
common_prefix_len
=
common_prefix_len
,
max_num_splits
=
0
,
# no max
fa_version
=
fa_version
,
fa_version
=
fa_version
,
)
)
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
db56a599
...
@@ -704,6 +704,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -704,6 +704,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap
=
self
.
logits_soft_cap
,
logits_soft_cap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
max_num_splits
=
attn_metadata
.
max_num_splits
,
fa_version
=
self
.
vllm_flash_attn_version
,
fa_version
=
self
.
vllm_flash_attn_version
,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
...
@@ -950,6 +951,7 @@ def cascade_attention(
...
@@ -950,6 +951,7 @@ def cascade_attention(
logits_soft_cap
:
float
,
logits_soft_cap
:
float
,
block_table
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
common_prefix_len
:
int
,
common_prefix_len
:
int
,
max_num_splits
:
int
,
fa_version
:
int
,
fa_version
:
int
,
prefix_scheduler_metadata
:
torch
.
Tensor
|
None
=
None
,
prefix_scheduler_metadata
:
torch
.
Tensor
|
None
=
None
,
suffix_scheduler_metadata
:
torch
.
Tensor
|
None
=
None
,
suffix_scheduler_metadata
:
torch
.
Tensor
|
None
=
None
,
...
@@ -994,7 +996,7 @@ def cascade_attention(
...
@@ -994,7 +996,7 @@ def cascade_attention(
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
# enabling its effect during the final attention merge.
s_aux
=
s_aux
,
s_aux
=
s_aux
,
num_splits
=
1
if
vllm_is_batch_invariant
()
else
0
,
num_splits
=
1
if
vllm_is_batch_invariant
()
else
max_num_splits
,
)
)
descale_shape
=
(
cu_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
descale_shape
=
(
cu_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
...
@@ -1019,7 +1021,7 @@ def cascade_attention(
...
@@ -1019,7 +1021,7 @@ def cascade_attention(
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
num_splits
=
1
if
vllm_is_batch_invariant
()
else
0
,
num_splits
=
1
if
vllm_is_batch_invariant
()
else
max_num_splits
,
)
)
# Merge prefix and suffix outputs, and store the result in output.
# Merge prefix and suffix outputs, and store the result in output.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment