Commit c05b8570 authored by skrider's avatar skrider
Browse files

allow small page sizes in flash api

parent 36916777
......@@ -1285,7 +1285,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int num_blocks = !paged_KV ? 0 : kcache.size(0);
const int page_block_size = !paged_KV ? 1 : kcache.size(1);
TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
const int num_heads_k = kcache.size(2);
const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
......
......@@ -328,12 +328,11 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const
constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
constexpr int kBlockN = Kernel_traits::kBlockN;
// base row of thread's slice relative to the block
const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
// base col of thread's slice relative to the entire tensor
const int global_row_offset_cur = block_row_offset + n_block * kBlockN;
const int global_row_offset_next = block_row_offset + (n_block - 1) * kBlockN;
// base row of thread's slice relative to the page
const int page_offset_cur = global_row_offset_cur % page_block_size;
const int page_offset_next = global_row_offset_next % page_block_size;
......
......@@ -1833,7 +1833,7 @@ def test_flash_attn_splitkv(
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
@pytest.mark.parametrize("paged_kv_block_size", [256])
@pytest.mark.parametrize("paged_kv_block_size", [16, 256, 512])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
@pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment