Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
a63157ea
Commit
a63157ea
authored
Mar 26, 2024
by
skrider
Browse files
add test for page table overflow
parent
135a1da6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
0 deletions
+44
-0
tests/test_flash_attn.py
tests/test_flash_attn.py
+44
-0
No files found.
tests/test_flash_attn.py
View file @
a63157ea
...
...
@@ -2461,3 +2461,47 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
assert
torch
.
equal
(
dv
,
dv
)
assert
torch
.
equal
(
dk
,
dk
)
assert
torch
.
equal
(
dq
,
dq
)
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("paged_kv_block_size", [16])
# @pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize("nheads", [32])
@pytest.mark.parametrize("b", [4])
@pytest.mark.parametrize("n", [10])
@pytest.mark.parametrize("seqlen_q,seqlen_k", [(170, 170)])
def test_flash_attn_paged_kvcache_overflow(
    seqlen_q,
    seqlen_k,
    d,
    nheads,
    b,
    n,
    paged_kv_block_size,
    causal,
    dtype,
):
    """Stress the paged-KV-cache path of flash_attn_with_kvcache.

    Runs the kernel ``n`` times, each time with a freshly randomized page
    table, to flush out page-table indexing problems (presumably
    out-of-bounds / overflow accesses — this test checks that the calls
    complete, not their numerics).
    """
    device = "cuda"
    # Total cache capacity is held at 1000 * 16 slots, independent of the
    # configured block size.
    num_blocks = 1000 * 16 // paged_kv_block_size
    cache_dims = [num_blocks, paged_kv_block_size, nheads, d]
    key_cache = torch.rand(cache_dims, dtype=dtype, device=device)
    value_cache = torch.rand(cache_dims, dtype=dtype, device=device)
    # All sequences start empty in the cache.
    cache_seqlens = torch.zeros(b, dtype=torch.int32, device=device)
    # Number of pages needed per sequence (ceiling division).
    pages_per_seq = (seqlen_k + paged_kv_block_size - 1) // paged_kv_block_size
    for _ in range(n):
        query = torch.rand([b, seqlen_q, nheads, d], dtype=dtype, device=device)
        key = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device)
        value = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device)
        # A brand-new random page table every iteration; each entry indexes
        # a block in the shared KV cache.
        block_tables = torch.randint(
            0,
            num_blocks,
            size=(b, pages_per_seq),
            dtype=torch.int32,
            device=device,
        )
        output = flash_attn_with_kvcache(
            query,
            key_cache,
            value_cache,
            k=key,
            v=value,
            cache_seqlens=cache_seqlens,
            block_table=block_tables,
            causal=causal,
        )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment