Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
53c6eb1f
Commit
53c6eb1f
authored
Feb 13, 2024
by
skrider
Committed by
Woosuk Kwon
Mar 28, 2024
Browse files
add page size 16 to tests
parent
04aabfb7
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
10 deletions
+19
-10
tests/test_flash_attn.py
tests/test_flash_attn.py
+19
-10
No files found.
tests/test_flash_attn.py
View file @
53c6eb1f
...
@@ -1818,24 +1818,24 @@ def test_flash_attn_splitkv(
...
@@ -1818,24 +1818,24 @@ def test_flash_attn_splitkv(
# @pytest.mark.parametrize("num_splits", [1])
# @pytest.mark.parametrize("num_splits", [1])
@
pytest
.
mark
.
parametrize
(
"mha_type"
,
[
"mha"
,
"mqa"
,
"gqa"
])
@
pytest
.
mark
.
parametrize
(
"mha_type"
,
[
"mha"
,
"mqa"
,
"gqa"
])
# @pytest.mark.parametrize("mha_type", ["mha"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@
pytest
.
mark
.
parametrize
(
"new_kv"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"new_kv"
,
[
False
,
True
])
# @pytest.mark.parametrize("new_kv", [False])
# @pytest.mark.parametrize("new_kv", [False])
@
pytest
.
mark
.
parametrize
(
"alibi"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"alibi"
,
[
False
,
True
])
# @pytest.mark.parametrize("alibi", [False])
# @pytest.mark.parametrize("alibi", [False])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
,
True
])
# @pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [False])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
False
,
True
])
# @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("causal", [False])
@
pytest
.
mark
.
parametrize
(
"seqlen_new_eq_seqlen_q"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"seqlen_new_eq_seqlen_q"
,
[
True
,
False
])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@
pytest
.
mark
.
parametrize
(
"rotary_interleaved"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"rotary_interleaved"
,
[
False
,
True
])
# @pytest.mark.parametrize("rotary_interleaved", [False])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@
pytest
.
mark
.
parametrize
(
"rotary_fraction"
,
[
0.0
])
@
pytest
.
mark
.
parametrize
(
"rotary_fraction"
,
[
0.0
,
0.5
,
1.0
])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
@
pytest
.
mark
.
parametrize
(
"paged_kv_block_size"
,
[
16
,
48
,
256
,
512
])
@
pytest
.
mark
.
parametrize
(
"paged_kv_block_size"
,
[
16
,
256
,
512
])
#
@pytest.mark.parametrize("has_batch_idx", [False, True])
@
pytest
.
mark
.
parametrize
(
"has_batch_idx"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"has_batch_idx"
,
[
False
])
#
@pytest.mark.parametrize("has_batch_idx", [False])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
59
,
64
,
80
,
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
59
,
64
,
80
,
128
,
256
])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
...
@@ -1844,8 +1844,17 @@ def test_flash_attn_splitkv(
...
@@ -1844,8 +1844,17 @@ def test_flash_attn_splitkv(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"seqlen_q,seqlen_k"
,
"seqlen_q,seqlen_k"
,
[
[
(
1
,
10
*
1024
),
(
1
,
128
),
(
16
,
10
*
1024
),
(
1
,
339
),
(
3
,
1024
),
(
64
,
800
),
(
64
,
256
),
(
3
,
799
),
(
64
,
2048
),
(
16
,
20000
),
(
1
,
128
*
1024
),
(
16
,
128
*
1024
),
(
128
,
128
),
],
],
)
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment