Commit 3fe8479d authored by skrider's avatar skrider Committed by Woosuk Kwon
Browse files

allow smaller page sizes in varlen api

parent 7e92c3b9
Pipeline #2012 failed with stages
in 0 seconds
...@@ -561,7 +561,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s ...@@ -561,7 +561,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int num_blocks = !paged_KV ? 0 : k.size(0); const int num_blocks = !paged_KV ? 0 : k.size(0);
const int page_block_size = !paged_KV ? 1 : k.size(1); const int page_block_size = !paged_KV ? 1 : k.size(1);
TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
if (is_causal) { window_size_right = 0; } if (is_causal) { window_size_right = 0; }
......
...@@ -1543,7 +1543,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): ...@@ -1543,7 +1543,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
], ],
) )
# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged # TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) @pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256, 512])
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_causal( def test_flash_attn_varlen_causal(
seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment