Commit 58611cd9 authored by skrider's avatar skrider Committed by Woosuk Kwon
Browse files

Resolve page offsets absolutely, not relatively

parent 10b6f3a8
...@@ -609,9 +609,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -609,9 +609,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout())); Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout()));
if (block_table != nullptr) { if (block_table != nullptr) {
tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size, tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride); block_table, params.k_batch_stride, params.k_row_stride);
tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size, tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride); block_table, params.v_batch_stride, params.v_row_stride);
} }
...@@ -769,9 +769,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -769,9 +769,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else { } else {
if (n_block > n_block_copy_min) { if (n_block > n_block_copy_min) {
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride); block_table, params.v_batch_stride, params.v_row_stride);
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride); block_table, params.k_batch_stride, params.k_row_stride);
} }
} }
...@@ -865,7 +865,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -865,7 +865,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) { if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else { } else {
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size, tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride); block_table, params.v_batch_stride, params.v_row_stride);
} }
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
...@@ -897,7 +897,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -897,7 +897,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) { if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else { } else {
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride); block_table, params.k_batch_stride, params.k_row_stride);
} }
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
...@@ -937,7 +937,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -937,7 +937,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) { if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else { } else {
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size, tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride); block_table, params.v_batch_stride, params.v_row_stride);
} }
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
...@@ -955,7 +955,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -955,7 +955,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) { if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else { } else {
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride); block_table, params.k_batch_stride, params.k_row_stride);
} }
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
......
...@@ -292,11 +292,11 @@ void cp_async_wait() { ...@@ -292,11 +292,11 @@ void cp_async_wait() {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// resolves initial base offset of a slice of a paged kv copy from gmem. // resolves offset of a slice of a paged kv copy from gmem.
// assumes that the tensor has already been positioned at the correct head. // assumes that the tensor has already been positioned at the correct head.
template <typename Kernel_traits> template <typename Kernel_traits>
__forceinline__ __device__ __forceinline__ __device__
int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size, int resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride) { const int* block_table, const int page_stride, const int row_stride) {
constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow; constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread; constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
...@@ -313,34 +313,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons ...@@ -313,34 +313,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons
+ page_offset * row_stride + page_offset * row_stride
+ col_offset; + col_offset;
} }
////////////////////////////////////////////////////////////////////////////////////////////////////
// Computes the relative offset that moves a thread's slice of a paged KV copy
// from block n_block to block n_block - 1. Callers add the returned value to
// the current data pointer of the gmem tensor.
//
// The result is the difference of the two block_table entries scaled by
// page_stride, plus the difference of the two within-page row positions scaled
// by row_stride.
//
// NOTE(review): every product below is evaluated in int; presumably the
// strides and table entries are small enough not to overflow 32 bits --
// confirm against the callers.
template <typename Kernel_traits>
__forceinline__ __device__
int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size,
                                        const int* block_table, const int page_stride, const int row_stride) {
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
    constexpr int kRowsPerThread = Kernel_traits::kGmemRowsPerThread;

    // Row offset of this thread within a block of kBlockN rows.
    const int row_in_block = tidx / kThreadsPerRow * kRowsPerThread;

    // Global row index at the current block (src) and at the block moved to (dst).
    const int src_row = row_in_block + n_block * kBlockN;
    const int dst_row = row_in_block + (n_block - 1) * kBlockN;

    // block_table is indexed by the virtual page (row / page size) of each row.
    const int src_page = block_table[src_row / page_block_size];
    const int dst_page = block_table[dst_row / page_block_size];

    const int page_delta = dst_page - src_page;
    const int row_delta = dst_row % page_block_size - src_row % page_block_size;

    return page_delta * page_stride + row_delta * row_stride;
}
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment