Commit 53c6eb1f authored by skrider's avatar skrider Committed by Woosuk Kwon
Browse files

add page size 16 to tests

parent 04aabfb7
...@@ -1818,24 +1818,24 @@ def test_flash_attn_splitkv( ...@@ -1818,24 +1818,24 @@ def test_flash_attn_splitkv(
# @pytest.mark.parametrize("num_splits", [1]) # @pytest.mark.parametrize("num_splits", [1])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"]) # @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False]) @pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False]) # @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("alibi", [True]) @pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False]) # @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False]) # @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_interleaved", [False, True])
# @pytest.mark.parametrize("rotary_interleaved", [False]) # @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) # @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
@pytest.mark.parametrize("paged_kv_block_size", [16, 48, 256, 512]) @pytest.mark.parametrize("paged_kv_block_size", [16, 256, 512])
# @pytest.mark.parametrize("has_batch_idx", [False, True]) @pytest.mark.parametrize("has_batch_idx", [False, True])
@pytest.mark.parametrize("has_batch_idx", [False]) # @pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
...@@ -1844,8 +1844,17 @@ def test_flash_attn_splitkv( ...@@ -1844,8 +1844,17 @@ def test_flash_attn_splitkv(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"seqlen_q,seqlen_k", "seqlen_q,seqlen_k",
[ [
(1, 10 * 1024), (1, 128),
(16, 10 * 1024), (1, 339),
(3, 1024),
(64, 800),
(64, 256),
(3, 799),
(64, 2048),
(16, 20000),
(1, 128 * 1024),
(16, 128 * 1024),
(128, 128),
], ],
) )
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment