Commit a3e06cd5 authored by skrider, committed by Woosuk Kwon

rearrange initial offset computation

parent f67a6edf
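
In short: on the paged-KV path, row_offset_k / row_offset_v no longer resolve the full per-thread address up front. They now carry only the head component, and the page/row components are added later, once the per-thread gmem views (tKgK / tVgV) exist. A small host-side sketch with made-up strides and a hypothetical block table, showing that the rearrangement only changes where the head component is added, not the final address:

#include <cassert>
#include <cstdio>

int main() {
    // hypothetical page table (virtual page -> physical page) and strides
    const int block_table[] = {7, 3, 9};
    const int page_stride = 1 << 20, row_stride = 256, head_stride = 64;
    const int virtual_page_idx = 1, page_offset = 5;  // thread's row within its page
    const int hidx = 2, col_offset = 16;              // head index, thread's base column

    // before this commit: init_thread_kv_page_slice_offset folded the head in
    const int old_offset = block_table[virtual_page_idx] * page_stride
                         + page_offset * row_stride
                         + hidx * head_stride
                         + col_offset;

    // after: row_offset_* keeps only the head component ...
    const int row_offset_k = hidx * head_stride;  // "block addresses are later resolved per-thread"
    // ... and the helper returns the page/row/column part
    const int helper = block_table[virtual_page_idx] * page_stride
                     + page_offset * row_stride
                     + col_offset;

    assert(old_offset == row_offset_k + helper);  // same final address either way
    printf("total offset: %d\n", old_offset);
    return 0;
}
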
@@ -624,13 +624,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     const index_t row_offset_k = block_table == nullptr
         ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
           + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
-        : init_thread_kv_page_slice_offset<Kernel_traits>(tidx, bidh / params.h_h_k_ratio, n_block_max, params.page_block_size, block_table,
-                                                          params.k_batch_stride, params.k_row_stride, params.k_head_stride);
+        : (bidh / params.h_h_k_ratio) * params.k_head_stride; // block addresses are later resolved per-thread
     const index_t row_offset_v = block_table == nullptr
         ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
           + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
-        : init_thread_kv_page_slice_offset<Kernel_traits>(tidx, bidh / params.h_h_k_ratio, n_block_max, params.page_block_size, block_table,
-                                                          params.v_batch_stride, params.v_row_stride, params.v_head_stride);
+        : (bidh / params.h_h_k_ratio) * params.v_head_stride;
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -667,6 +665,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
     Tensor tVgV = gmem_thr_copy_KV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
     Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
+    if (block_table != nullptr) {
+        tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
+                                                                                         block_table, params.k_batch_stride, params.k_row_stride);
+        tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
+                                                                                         block_table, params.v_batch_stride, params.v_row_stride);
+    }
+#if 1
+    KIN_PRINT(print(tKgK.layout()))
+    KIN_PRINT(print(tKsK.layout()))
@@ -850,9 +856,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             // const int offset_diff = block_table_offset_next - block_table_offset_cur;
             // tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
             // tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
-            tVgV.data() = tVgV.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
+            tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
                                                                                            params.v_batch_stride, params.v_row_stride);
-            tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
+            tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
                                                                                            params.k_batch_stride, params.k_row_stride);
         }
     }
@@ -977,7 +983,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         if (block_table == nullptr) {
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
         } else {
-            tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
+            tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
                                                                                            params.k_batch_stride, params.k_row_stride);
         }
         flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
@@ -1017,7 +1023,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         if (block_table == nullptr) {
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
         } else {
-            tVgV.data() = tVgV.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
+            tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
                                                                                            block_table, params.v_batch_stride, params.v_row_stride);
         }
         flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
@@ -1035,7 +1041,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         if (block_table == nullptr) {
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
         } else {
-            tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
+            tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
                                                                                            block_table, params.k_batch_stride, params.k_row_stride);
         }
         flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
......
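
The hunks above all sit in the backward loop over n_block. With a contiguous KV cache the pointers simply rewind by kBlockN * row_stride per step; with a paged cache each block's base must be re-resolved through the block table, which is what flash::advance_thread_kv_page_slice_offset is for. A standalone sketch with made-up sizes, illustrating why the paged offset does not move by a constant amount across page boundaries:

#include <cstdio>

int main() {
    const int kBlockN = 4, row_stride = 8, page_block_size = 8;
    const int block_table[] = {5, 2};  // virtual page -> physical page, made up
    const int page_stride = page_block_size * row_stride;

    for (int n_block = 3; n_block >= 0; --n_block) {
        const int row = n_block * kBlockN;  // first row of this block
        // contiguous cache: offset is linear in the row
        const int contiguous = row * row_stride;
        // paged cache: route through the block table
        const int paged = block_table[row / page_block_size] * page_stride
                        + (row % page_block_size) * row_stride;
        printf("n_block %d: contiguous %4d, paged %4d\n", n_block, contiguous, paged);
    }
    return 0;
}
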
@@ -292,11 +292,12 @@ void cp_async_wait() {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-// resolves initial base address of a slice of a paged kv copy from gmem
+// resolves initial base offset of a slice of a paged kv copy from gmem.
+// assumes that the tensor has already been positioned at the correct head.
 template <typename Kernel_traits>
 __forceinline__ __device__
-int init_thread_kv_page_slice_offset(const int tidx, const int hidx, const int n_block_max, const int page_block_size,
-                                     const int* block_table, const int page_stride, const int row_stride, const int head_stride) {
+int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
+                                     const int* block_table, const int page_stride, const int row_stride) {
     // base col of thread's slice relative to the block
     const int col_offset = tidx % Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemElemsPerLoad;
     // base row of thread's slice relative to the block
@@ -310,7 +311,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int hidx, const int n
     return block_table[virtual_page_idx] * page_stride
         + page_offset * row_stride
-        + hidx * head_stride
         + col_offset;
 }
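
The middle of init_thread_kv_page_slice_offset is elided in the hunk above; it presumably derives virtual_page_idx and page_offset from the thread's global row, which for the initial position starts at block n_block_max - 1. A plausible host-side reconstruction under that assumption; the trait values (kGmemThreadsPerRow, kGmemElemsPerLoad, kGmemRowsPerThread, kBlockN) are stand-ins, not the kernel's actual configuration:

#include <cstdio>

constexpr int kGmemThreadsPerRow = 8;   // stand-in trait values
constexpr int kGmemElemsPerLoad  = 8;
constexpr int kGmemRowsPerThread = 4;
constexpr int kBlockN            = 32;

int init_thread_kv_page_slice_offset(int tidx, int n_block_max, int page_block_size,
                                     const int* block_table, int page_stride, int row_stride) {
    // base col of thread's slice relative to the block
    const int col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
    // base row of thread's slice relative to the block
    const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
    // assumed: global row of the slice, starting from the last block
    const int global_row = (n_block_max - 1) * kBlockN + block_row_offset;
    const int virtual_page_idx = global_row / page_block_size;  // assumed
    const int page_offset      = global_row % page_block_size;  // assumed
    return block_table[virtual_page_idx] * page_stride
         + page_offset * row_stride
         + col_offset;
}

int main() {
    const int block_table[] = {3, 1, 4, 1, 5};  // made-up page table
    printf("%d\n", init_thread_kv_page_slice_offset(/*tidx=*/10, /*n_block_max=*/2,
                                                    /*page_block_size=*/16, block_table,
                                                    /*page_stride=*/4096, /*row_stride=*/256));
    return 0;
}
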
@@ -321,6 +321,7 @@ template <typename Kernel_traits>
 __forceinline__ __device__
 int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size,
                                         const int* block_table, const int page_stride, const int row_stride) {
+    return 0;
     // base row of thread's slice relative to the block
     const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread;
     // base col of thread's slice relative to the entire tensor
......
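
Note that as committed, advance_thread_kv_page_slice_offset hits return 0; before any of the code below it, so the per-step advance is a no-op in this revision. Going by the commented-out table_diff / offset_diff lines in the first file, the intended delta is plausibly something like the following sketch; the stepping direction (toward a smaller n_block) and the row derivation are assumptions:

#include <cstdio>

constexpr int kGmemThreadsPerRow = 8;   // stand-in trait values
constexpr int kGmemRowsPerThread = 4;
constexpr int kBlockN            = 32;

int advance_thread_kv_page_slice_offset(int tidx, int n_block, int page_block_size,
                                        const int* block_table, int page_stride, int row_stride) {
    // base row of thread's slice relative to the block
    const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
    // assumed: the loop steps down from block n_block + 1 to block n_block
    const int row_cur  = block_row_offset + (n_block + 1) * kBlockN;
    const int row_next = block_row_offset + n_block * kBlockN;
    const int table_diff  = block_table[row_next / page_block_size]
                          - block_table[row_cur / page_block_size];
    const int offset_diff = row_next % page_block_size - row_cur % page_block_size;
    // pointer delta: change of physical page plus change of row within the page
    return table_diff * page_stride + offset_diff * row_stride;
}

int main() {
    const int block_table[] = {3, 1, 4};  // made-up page table
    printf("%d\n", advance_thread_kv_page_slice_offset(/*tidx=*/0, /*n_block=*/0,
                                                       /*page_block_size=*/32, block_table,
                                                       /*page_stride=*/8192, /*row_stride=*/256));
    return 0;
}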