Commit a3e06cd5 authored by skrider's avatar skrider Committed by Woosuk Kwon
Browse files

rearrange initial offset computation

parent f67a6edf
...@@ -624,13 +624,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -624,13 +624,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
const index_t row_offset_k = block_table == nullptr const index_t row_offset_k = block_table == nullptr
? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
: init_thread_kv_page_slice_offset<Kernel_traits>(tidx, bidh / params.h_h_k_ratio, n_block_max, params.page_block_size, block_table, : (bidh / params.h_h_k_ratio) * params.k_head_stride; // block addresses are later resolved per-thread
params.k_batch_stride, params.k_row_stride, params.k_head_stride);
const index_t row_offset_v = block_table == nullptr const index_t row_offset_v = block_table == nullptr
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: init_thread_kv_page_slice_offset<Kernel_traits>(tidx, bidh / params.h_h_k_ratio, n_block_max, params.page_block_size, block_table, : (bidh / params.h_h_k_ratio) * params.v_head_stride;
params.v_batch_stride, params.v_row_stride, params.v_head_stride);
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q), Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{}, Shape<Int<kBlockM>, Int<kHeadDim>>{},
...@@ -667,6 +665,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -667,6 +665,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor tKsK = gmem_thr_copy_KV.partition_D(sK); Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVgV = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_KV.partition_D(sV); Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
if (block_table != nullptr) {
tKgK.data() = gV.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
}
#if 1 #if 1
KIN_PRINT(print(tKgK.layout())) KIN_PRINT(print(tKgK.layout()))
KIN_PRINT(print(tKsK.layout())) KIN_PRINT(print(tKsK.layout()))
...@@ -850,9 +856,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -850,9 +856,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// const int offset_diff = block_table_offset_next - block_table_offset_cur; // const int offset_diff = block_table_offset_next - block_table_offset_cur;
// tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; // tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
// tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; // tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
tVgV.data() = tVgV.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table, tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
params.v_batch_stride, params.v_row_stride); params.v_batch_stride, params.v_row_stride);
tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table, tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
params.k_batch_stride, params.k_row_stride); params.k_batch_stride, params.k_row_stride);
} }
} }
...@@ -977,7 +983,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -977,7 +983,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) { if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else { } else {
tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table, tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
params.k_batch_stride, params.k_row_stride); params.k_batch_stride, params.k_row_stride);
} }
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
...@@ -1017,7 +1023,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -1017,7 +1023,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) { if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else { } else {
tVgV.data() = tVgV.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride); block_table, params.v_batch_stride, params.v_row_stride);
} }
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
...@@ -1035,7 +1041,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -1035,7 +1041,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) { if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else { } else {
tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride); block_table, params.k_batch_stride, params.k_row_stride);
} }
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
......
...@@ -292,11 +292,12 @@ void cp_async_wait() { ...@@ -292,11 +292,12 @@ void cp_async_wait() {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// resolves initial base address of a slice of a paged kv copy from gmem // resolves initial base offset of a slice of a paged kv copy from gmem.
// assumes that the tensor has already been positioned at the correct head.
template <typename Kernel_traits> template <typename Kernel_traits>
__forceinline__ __device__ __forceinline__ __device__
int init_thread_kv_page_slice_offset(const int tidx, const int hidx, const int n_block_max, const int page_block_size, int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride, const int head_stride) { const int* block_table, const int page_stride, const int row_stride) {
// base col of thread's slice relative to the block // base col of thread's slice relative to the block
const int col_offset = tidx % Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemElemsPerLoad; const int col_offset = tidx % Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemElemsPerLoad;
// base row of thread's slice relative to the block // base row of thread's slice relative to the block
...@@ -310,7 +311,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int hidx, const int n ...@@ -310,7 +311,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int hidx, const int n
return block_table[virtual_page_idx] * page_stride return block_table[virtual_page_idx] * page_stride
+ page_offset * row_stride + page_offset * row_stride
+ hidx * head_stride
+ col_offset; + col_offset;
} }
...@@ -321,6 +321,7 @@ template <typename Kernel_traits> ...@@ -321,6 +321,7 @@ template <typename Kernel_traits>
__forceinline__ __device__ __forceinline__ __device__
int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size, int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride) { const int* block_table, const int page_stride, const int row_stride) {
return 0;
// base row of thread's slice relative to the block // base row of thread's slice relative to the block
const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread; const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread;
// base col of thread's slice relative to the entire tensor // base col of thread's slice relative to the entire tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment