Commit c29f7313 authored by skrider's avatar skrider Committed by Woosuk Kwon
Browse files

tidy flash_fwd_kernel

parent 3f2484ee
......@@ -566,7 +566,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
: (bidh / params.h_h_k_ratio) * params.k_head_stride; // block addresses are later resolved per-thread
const index_t row_offset_k__shadow = block_table[(n_block_max - 1) * kBlockN / params.page_block_size] * params.k_batch_stride + (((n_block_max - 1) * kBlockN) % params.page_block_size) * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = block_table == nullptr
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
......@@ -580,9 +579,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
Tensor gK__shadow = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k__shadow),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
......@@ -602,7 +598,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKgK__shadow = gmem_thr_copy_KV.partition_S(gK__shadow); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
......@@ -754,14 +749,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
if (n_block > n_block_copy_min) {
// const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
// const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
// const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
// const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
// const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
// const int offset_diff = block_table_offset_next - block_table_offset_cur;
// tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
// tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
......@@ -854,11 +841,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else {
// const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
// const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
// const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
// const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
// tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment