Commit e524c2ca authored by skrider, committed by Woosuk Kwon

allow small page sizes in flash api

parent b1c18ca1
@@ -1285,7 +1285,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
     const int num_blocks = !paged_KV ? 0 : kcache.size(0);
     const int page_block_size = !paged_KV ? 1 : kcache.size(1);
-    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
+    TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
     const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
     const int num_heads_k = kcache.size(2);
     const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
...
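
The relaxed check above is the user-facing effect of this commit: a paged KV cache is now accepted whenever its block size is a multiple of 16 rather than 256. Below is a minimal usage sketch, not part of the commit, of a call that the new check permits; it assumes the flash-attn Python entry point flash_attn_with_kvcache together with its block_table and cache_seqlens arguments, and every size in it is an arbitrary illustration value.

import torch
from flash_attn import flash_attn_with_kvcache

batch, seqlen_q, nheads, d = 2, 1, 8, 128
page_block_size = 16                    # only needs to be a multiple of 16 now
num_blocks, blocks_per_seq = 64, 8      # arbitrary sizes for this sketch

q = torch.randn(batch, seqlen_q, nheads, d, dtype=torch.float16, device="cuda")
# paged layout: (num_blocks, page_block_size, nheads_k, head_dim)
k_cache = torch.randn(num_blocks, page_block_size, nheads, d, dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)

# block_table[b, i] = physical block backing the i-th logical block of sequence b
block_table = torch.randperm(num_blocks, dtype=torch.int32, device="cuda")
block_table = block_table[: batch * blocks_per_seq].reshape(batch, blocks_per_seq)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    causal=True,
)
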
@@ -328,12 +328,11 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const
     constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
     constexpr int kBlockN = Kernel_traits::kBlockN;
     // base row of thread's slice relative to the block
     const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
     // base col of thread's slice relative to the entire tensor
     const int global_row_offset_cur = block_row_offset + n_block * kBlockN;
     const int global_row_offset_next = block_row_offset + (n_block - 1) * kBlockN;
     // base row of thread's slice relative to the page
     const int page_offset_cur = global_row_offset_cur % page_block_size;
     const int page_offset_next = global_row_offset_next % page_block_size;
...
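
For context on the kernel hunk above: each thread takes its row in the virtual, contiguous K/V sequence, splits it into a page index and an offset within the page with a div/mod by page_block_size, and resolves the physical page through block_table. The sketch below restates that index math in plain Python; the helper name and the flattened physical-row return value are illustrative assumptions rather than kernel code.

def virtual_row_to_physical(global_row, page_block_size, block_table):
    virtual_page_idx = global_row // page_block_size  # which logical page the row falls in
    page_offset = global_row % page_block_size        # row within that page
    physical_block = block_table[virtual_page_idx]    # physical block holding that page
    return physical_block * page_block_size + page_offset

# With page_block_size = 16, rows 0..15 live in block_table[0],
# rows 16..31 in block_table[1], and so on.
block_table = [7, 2, 5]
assert virtual_row_to_physical(0, 16, block_table) == 7 * 16
assert virtual_row_to_physical(17, 16, block_table) == 2 * 16 + 1
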
@@ -1833,7 +1833,7 @@ def test_flash_attn_splitkv(
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
 # @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
-@pytest.mark.parametrize("paged_kv_block_size", [256])
+@pytest.mark.parametrize("paged_kv_block_size", [16, 256, 512])
 # @pytest.mark.parametrize("has_batch_idx", [False, True])
 @pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
...