Commit b1c18ca1 authored by skrider's avatar skrider Committed by Woosuk Kwon
Browse files

paged copy refactor working for page size 256

parent 446204c7
......@@ -620,19 +620,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
block_table, params.v_batch_stride, params.v_row_stride);
}
#if 1
KIN_PRINT([&]() {
for (int i = 0; i < n_block_max; i++) {
printf("%d ", block_table[i]);
}
}())
// KIN_PRINT([&]() {
// for (int i = 0; i < n_block_max; i++) {
// printf("%d ", block_table[i]);
// }
// }())
// if (tidx == 8) fill(tKgK, 1.f * tidx);
// if (thread0()) {
// gK.data() = tKgK.data();
// }
KIN_PRINT(print_tensor(tKgK))
KIN_PRINT(print_tensor(gK))
KIN_PRINT(print_tensor(tKgK__shadow))
KIN_PRINT(print_tensor(gK__shadow))
// KIN_PRINT(print_tensor(tKgK))
// KIN_PRINT(print_tensor(gK))
// KIN_PRINT(print_tensor(tKgK__shadow))
// KIN_PRINT(print_tensor(gK__shadow))
#endif
typename Kernel_traits::TiledMma tiled_mma;
......@@ -783,10 +783,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// const int offset_diff = block_table_offset_next - block_table_offset_cur;
// tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
// tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
params.v_batch_stride, params.v_row_stride);
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
params.k_batch_stride, params.k_row_stride);
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
}
}
}
......@@ -875,11 +875,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else {
const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
// const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
// const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
// const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
// const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
// tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
......@@ -910,8 +912,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table,
params.k_batch_stride, params.k_row_stride);
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
......@@ -950,7 +952,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else {
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
......
......@@ -310,7 +310,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons
const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
const int page_offset = global_row_offset % page_block_size;
const int virtual_page_idx = global_row_offset / page_block_size;
KIN_PRINT(printf("%d", virtual_page_idx))
return block_table[virtual_page_idx] * page_stride
+ page_offset * row_stride
......@@ -324,12 +323,16 @@ template <typename Kernel_traits>
__forceinline__ __device__
int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride) {
return 0;
constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
constexpr int kBlockN = Kernel_traits::kBlockN;
// base row of thread's slice relative to the block
const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread;
const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
// base col of thread's slice relative to the entire tensor
const int global_row_offset_cur = block_row_offset + n_block * Kernel_traits::kBlockN;
const int global_row_offset_next = block_row_offset + (n_block - 1) * Kernel_traits::kBlockN;
const int global_row_offset_cur = block_row_offset + n_block * kBlockN;
const int global_row_offset_next = block_row_offset + (n_block - 1) * kBlockN;
// base row of thread's slice relative to the page
const int page_offset_cur = global_row_offset_cur % page_block_size;
const int page_offset_next = global_row_offset_next % page_block_size;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment