Commit f67a6edf authored by skrider's avatar skrider Committed by Woosuk Kwon
Browse files

implement kv page iteration functions

parent d4da6bfc
...@@ -621,16 +621,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -621,16 +621,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// We move K and V to the last block. // We move K and V to the last block.
const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
const index_t row_offset_k = block_table == nullptr const index_t row_offset_k = block_table == nullptr
? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
: block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; : init_thread_kv_page_slice_offset<Kernel_traits>(tidx, bidh / params.h_h_k_ratio, n_block_max, params.page_block_size, block_table,
params.k_batch_stride, params.k_row_stride, params.k_head_stride);
const index_t row_offset_v = block_table == nullptr const index_t row_offset_v = block_table == nullptr
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; : init_thread_kv_page_slice_offset<Kernel_traits>(tidx, bidh / params.h_h_k_ratio, n_block_max, params.page_block_size, block_table,
params.v_batch_stride, params.v_row_stride, params.v_head_stride);
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q), Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{}, Shape<Int<kBlockM>, Int<kHeadDim>>{},
...@@ -842,14 +842,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -842,14 +842,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else { } else {
if (n_block > n_block_copy_min) { if (n_block > n_block_copy_min) {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; // const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; // const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; // const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; // const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; // const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
const int offset_diff = block_table_offset_next - block_table_offset_cur; // const int offset_diff = block_table_offset_next - block_table_offset_cur;
tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; // tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; // tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
tVgV.data() = tVgV.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
params.v_batch_stride, params.v_row_stride);
tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
params.k_batch_stride, params.k_row_stride);
} }
} }
} }
...@@ -973,11 +977,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -973,11 +977,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) { if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else { } else {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; params.k_batch_stride, params.k_row_stride);
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
} }
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization // This cp_async_fence needs to be in the if block, otherwise the synchronization
...@@ -1016,11 +1017,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -1016,11 +1017,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) { if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else { } else {
const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; tVgV.data() = tVgV.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; block_table, params.v_batch_stride, params.v_row_stride);
const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
} }
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence(); cute::cp_async_fence();
...@@ -1037,11 +1035,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -1037,11 +1035,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) { if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else { } else {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; block_table, params.k_batch_stride, params.k_row_stride);
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
} }
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization // This cp_async_fence needs to be in the if block, otherwise the synchronization
......
...@@ -292,6 +292,55 @@ void cp_async_wait() { ...@@ -292,6 +292,55 @@ void cp_async_wait() {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// resolves initial base address of a slice of a paged kv copy from gmem
// Resolves the initial global-memory offset (in elements) of this thread's
// slice of a paged KV-cache copy: the logical row of the last n_block is
// translated through the block table into a physical page plus an in-page row.
//
// tidx             thread index within the CTA
// hidx             KV head index (caller passes bidh / h_h_k_ratio)
// n_block_max      one past the last block; the copy starts at block (n_block_max - 1)
// page_block_size  number of rows held by one KV-cache page
// block_table      maps a virtual page index to a physical page index
// page_stride / row_stride / head_stride  element strides of the paged KV tensor
//
// Returns int64_t so that block_table[...] * page_stride cannot overflow
// 32-bit int on large KV caches; callers assign the result to index_t.
//
// NOTE(review): page_offset is derived from the first row of the thread's
// slice only, so this assumes a page boundary never falls inside one
// thread's kGmemRowsPerThread rows — TODO confirm against the kernel's
// page_block_size constraints.
template <typename Kernel_traits>
__forceinline__ __device__
int64_t init_thread_kv_page_slice_offset(const int tidx, const int hidx, const int n_block_max, const int page_block_size,
                                         const int* block_table, const int page_stride, const int row_stride, const int head_stride) {
    // base col of thread's slice relative to the block
    const int col_offset = tidx % Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemElemsPerLoad;
    // base row of thread's slice relative to the block
    const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread;
    // base row of thread's slice relative to the entire tensor
    const int global_row_offset = block_row_offset + (n_block_max - 1) * Kernel_traits::kBlockN;
    // base row of thread's slice relative to its page
    const int page_offset = global_row_offset % page_block_size;
    const int virtual_page_idx = global_row_offset / page_block_size;
    return static_cast<int64_t>(block_table[virtual_page_idx]) * page_stride
        + static_cast<int64_t>(page_offset) * row_stride
        + static_cast<int64_t>(hidx) * head_stride
        + col_offset;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// advances base address of a slice of a paged copy from gmem
// Computes the pointer delta (in elements) that moves this thread's slice of
// a paged KV-cache copy from block n_block back to block (n_block - 1),
// re-resolving the page through the block table when the step crosses a
// page boundary.
//
// tidx             thread index within the CTA
// n_block          the block the pointer currently sits at; the delta lands on n_block - 1
// page_block_size  number of rows held by one KV-cache page
// block_table      maps a virtual page index to a physical page index
// page_stride / row_stride  element strides of the paged KV tensor
//
// Returns int64_t so that table_diff * page_stride cannot overflow 32-bit
// int on large KV caches; the result is added to a gmem tensor pointer.
template <typename Kernel_traits>
__forceinline__ __device__
int64_t advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size,
                                            const int* block_table, const int page_stride, const int row_stride) {
    // base row of thread's slice relative to the block
    const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread;
    // base rows of thread's slice relative to the entire tensor, at the
    // current block and at the block it is stepping back to
    const int global_row_offset_cur = block_row_offset + n_block * Kernel_traits::kBlockN;
    const int global_row_offset_next = block_row_offset + (n_block - 1) * Kernel_traits::kBlockN;
    // base rows of thread's slice relative to their pages
    const int page_offset_cur = global_row_offset_cur % page_block_size;
    const int page_offset_next = global_row_offset_next % page_block_size;
    const int virtual_page_idx_cur = global_row_offset_cur / page_block_size;
    const int virtual_page_idx_next = global_row_offset_next / page_block_size;
    // change in physical page, and change in row within the page
    const int table_diff = block_table[virtual_page_idx_next] - block_table[virtual_page_idx_cur];
    const int offset_diff = page_offset_next - page_offset_cur;
    return static_cast<int64_t>(table_diff) * page_stride + static_cast<int64_t>(offset_diff) * row_stride;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true, template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3> typename Engine2, typename Layout2, typename Engine3, typename Layout3>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment