"tests/models/autoencoders/test_models_vae.py" did not exist on "63f767ef15fa59704272ac7320ec23b8c15de246"
Commit b1c18ca1 authored by skrider's avatar skrider Committed by Woosuk Kwon
Browse files

paged copy refactor working for page size 256

parent 446204c7
...@@ -620,19 +620,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -620,19 +620,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
block_table, params.v_batch_stride, params.v_row_stride); block_table, params.v_batch_stride, params.v_row_stride);
} }
#if 1 #if 1
KIN_PRINT([&]() { // KIN_PRINT([&]() {
for (int i = 0; i < n_block_max; i++) { // for (int i = 0; i < n_block_max; i++) {
printf("%d ", block_table[i]); // printf("%d ", block_table[i]);
} // }
}()) // }())
// if (tidx == 8) fill(tKgK, 1.f * tidx); // if (tidx == 8) fill(tKgK, 1.f * tidx);
// if (thread0()) { // if (thread0()) {
// gK.data() = tKgK.data(); // gK.data() = tKgK.data();
// } // }
KIN_PRINT(print_tensor(tKgK)) // KIN_PRINT(print_tensor(tKgK))
KIN_PRINT(print_tensor(gK)) // KIN_PRINT(print_tensor(gK))
KIN_PRINT(print_tensor(tKgK__shadow)) // KIN_PRINT(print_tensor(tKgK__shadow))
KIN_PRINT(print_tensor(gK__shadow)) // KIN_PRINT(print_tensor(gK__shadow))
#endif #endif
typename Kernel_traits::TiledMma tiled_mma; typename Kernel_traits::TiledMma tiled_mma;
...@@ -783,10 +783,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -783,10 +783,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// const int offset_diff = block_table_offset_next - block_table_offset_cur; // const int offset_diff = block_table_offset_next - block_table_offset_cur;
// tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; // tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
// tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; // tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table, tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
params.v_batch_stride, params.v_row_stride); block_table, params.v_batch_stride, params.v_row_stride);
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table, tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
params.k_batch_stride, params.k_row_stride); block_table, params.k_batch_stride, params.k_row_stride);
} }
} }
} }
...@@ -875,11 +875,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -875,11 +875,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) { if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else { } else {
const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; // const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; // const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = n_block * kBlockN / params.page_block_size; // const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; // const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; // tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
} }
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
} else { } else {
...@@ -910,8 +912,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -910,8 +912,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) { if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else { } else {
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table, tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
params.k_batch_stride, params.k_row_stride); block_table, params.k_batch_stride, params.k_row_stride);
} }
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization // This cp_async_fence needs to be in the if block, otherwise the synchronization
...@@ -950,7 +952,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -950,7 +952,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) { if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else { } else {
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride); block_table, params.v_batch_stride, params.v_row_stride);
} }
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
......
...@@ -310,7 +310,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons ...@@ -310,7 +310,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons
const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN; const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
const int page_offset = global_row_offset % page_block_size; const int page_offset = global_row_offset % page_block_size;
const int virtual_page_idx = global_row_offset / page_block_size; const int virtual_page_idx = global_row_offset / page_block_size;
KIN_PRINT(printf("%d", virtual_page_idx))
return block_table[virtual_page_idx] * page_stride return block_table[virtual_page_idx] * page_stride
+ page_offset * row_stride + page_offset * row_stride
...@@ -324,12 +323,16 @@ template <typename Kernel_traits> ...@@ -324,12 +323,16 @@ template <typename Kernel_traits>
__forceinline__ __device__ __forceinline__ __device__
int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size, int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride) { const int* block_table, const int page_stride, const int row_stride) {
return 0; constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
constexpr int kBlockN = Kernel_traits::kBlockN;
// base row of thread's slice relative to the block // base row of thread's slice relative to the block
const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread; const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
// base col of thread's slice relative to the entire tensor // base col of thread's slice relative to the entire tensor
const int global_row_offset_cur = block_row_offset + n_block * Kernel_traits::kBlockN; const int global_row_offset_cur = block_row_offset + n_block * kBlockN;
const int global_row_offset_next = block_row_offset + (n_block - 1) * Kernel_traits::kBlockN; const int global_row_offset_next = block_row_offset + (n_block - 1) * kBlockN;
// base row of thread's slice relative to the page // base row of thread's slice relative to the page
const int page_offset_cur = global_row_offset_cur % page_block_size; const int page_offset_cur = global_row_offset_cur % page_block_size;
const int page_offset_next = global_row_offset_next % page_block_size; const int page_offset_next = global_row_offset_next % page_block_size;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment