Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0bf29fad
Unverified
Commit
0bf29fad
authored
Nov 10, 2025
by
Matthew Bonanni
Committed by
GitHub
Nov 10, 2025
Browse files
[Test] Remove old non-varlen FA2 test (#28420)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
a5a790ee
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
119 deletions
+0
-119
tests/kernels/attention/test_flash_attn.py
tests/kernels/attention/test_flash_attn.py
+0
-119
No files found.
tests/kernels/attention/test_flash_attn.py
View file @
0bf29fad
...
@@ -9,7 +9,6 @@ from vllm.platforms import current_platform
...
@@ -9,7 +9,6 @@ from vllm.platforms import current_platform
from
vllm.vllm_flash_attn
import
(
from
vllm.vllm_flash_attn
import
(
fa_version_unsupported_reason
,
fa_version_unsupported_reason
,
flash_attn_varlen_func
,
flash_attn_varlen_func
,
flash_attn_with_kvcache
,
is_fa_version_supported
,
is_fa_version_supported
,
)
)
...
@@ -83,124 +82,6 @@ def ref_paged_attn(
...
@@ -83,124 +82,6 @@ def ref_paged_attn(
return
torch
.
cat
(
outputs
,
dim
=
0
)
return
torch
.
cat
(
outputs
,
dim
=
0
)
@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
    use_out: bool,
    kv_lens: list[int],
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: float | None,
    num_blocks: int,
    sliding_window: int | None,
    fa_version: int,
    q_dtype: torch.dtype | None,
) -> None:
    """Check ``flash_attn_with_kvcache`` against ``ref_paged_attn``.

    Every sequence contributes a single query token (``query_lens`` is
    ``[1] * num_seqs``) that attends over a randomly populated paged KV cache
    addressed through a random block table.

    Args:
        use_out: When true, a preallocated ``out`` tensor is passed to the
            kernel and the result is read back from it instead of from the
            return value.
        kv_lens: Per-sequence KV lengths, passed as ``cache_seqlens``.
        num_heads: ``(num_query_heads, num_kv_heads)``; the first must be a
            multiple of the second.
        head_size: Last dimension of the query and cache tensors; also sets
            the softmax scale ``head_size ** -0.5``.
        dtype: dtype used to generate the query and the KV cache.
        block_size: Second dimension of the KV cache tensors.
        soft_cap: Soft-cap value; ``None`` is passed to the kernel as ``0``.
        num_blocks: First dimension of the KV cache tensors and the exclusive
            upper bound for block-table entries.
        sliding_window: When not ``None``, the kernel receives
            ``window_size=(sliding_window - 1, 0)``; otherwise ``(-1, -1)``.
        fa_version: Flash-attention version forwarded to the kernel; the test
            is skipped when ``is_fa_version_supported`` rejects it.
        q_dtype: When not ``None``, query/key/value are cast to this dtype
            before the kernel call and unit descale tensors are supplied.
    """
    torch.set_default_device("cuda")

    if not is_fa_version_supported(fa_version):
        pytest.skip(
            f"Flash attention version {fa_version} not supported due "
            f'to: "{fa_version_unsupported_reason(fa_version)}"'
        )
    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
        pytest.skip(
            "Flash attention with quantized inputs is only "
            "supported on version 3 with bfloat16 base type"
        )

    current_platform.seed_everything(0)

    num_seqs = len(kv_lens)
    num_query_heads, num_kv_heads = num_heads[0], num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5
    window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    key_cache = torch.randn(
        num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
    )
    value_cache = torch.randn_like(key_cache)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    # Ceil-divide so the longest sequence has enough block-table columns.
    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    # Insert a length-1 dimension at index 1 for the kernel call; it is
    # squeezed away again before comparing with the reference.
    q = query.unsqueeze(1)
    out = torch.empty_like(q) if use_out else None

    maybe_quantized_query = q
    maybe_quantized_key_cache = key_cache
    maybe_quantized_value_cache = value_cache
    q_descale = None
    k_descale = None
    v_descale = None
    if q_dtype is not None:
        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
        maybe_quantized_query = q.to(q_dtype)
        maybe_quantized_key_cache = key_cache.to(q_dtype)
        maybe_quantized_value_cache = value_cache.to(q_dtype)

        scale_shape = (num_seqs, num_kv_heads)
        q_descale = torch.ones(scale_shape, dtype=torch.float32)
        k_descale = torch.ones(scale_shape, dtype=torch.float32)
        v_descale = torch.ones(scale_shape, dtype=torch.float32)

    output = flash_attn_with_kvcache(
        q=maybe_quantized_query,
        k_cache=maybe_quantized_key_cache,
        v_cache=maybe_quantized_value_cache,
        out=out,
        softmax_scale=scale,
        causal=True,
        block_table=block_tables,
        cache_seqlens=kv_lens_tensor,
        softcap=soft_cap if soft_cap is not None else 0,
        window_size=window_size,
        fa_version=fa_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )
    # With use_out the result is taken from the preallocated buffer, so the
    # comparison below checks what the kernel wrote into ``out``.
    output = out if use_out else output
    output = output.squeeze(1)

    # Quantized inputs get a much looser tolerance than the base dtypes.
    if q_dtype is not None:
        atol, rtol = 1.5e-1, 1.5e-1
    else:
        atol, rtol = 1.5e-2, 1e-2

    # The reference always consumes the unquantized query and caches.
    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=[1] * num_seqs,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        soft_cap=soft_cap,
        sliding_window=sliding_window,
    )

    # The original wrapped this call in a tuple with an f-string holding the
    # max absolute difference. That string was never attached to the check and
    # was only evaluated (then discarded) when the check passed, so it is
    # dropped here.
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
@
pytest
.
mark
.
parametrize
(
"use_out"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_out"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
[[(
1
,
1328
),
(
5
,
18
),
(
129
,
463
)],
[(
1
,
523
),
(
1
,
37
),
(
1
,
2011
)]]
"seq_lens"
,
[[(
1
,
1328
),
(
5
,
18
),
(
129
,
463
)],
[(
1
,
523
),
(
1
,
37
),
(
1
,
2011
)]]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment