Commit 63b35c93 authored by skrider's avatar skrider Committed by Woosuk Kwon
Browse files

Compiles for all head dimensions (h) except 128

parent c29f7313
...@@ -597,10 +597,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -597,10 +597,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) Tensor tKgK_ = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_KV.partition_D(sK); Tensor tKsK_ = gmem_thr_copy_KV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVgV_ = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_KV.partition_D(sV); Tensor tVsV_ = gmem_thr_copy_KV.partition_D(sV);
Tensor tKgK = make_tensor(tKgK_.data(), unsqueeze<2>(layout<0>(tKgK_.layout())));
Tensor tKsK = make_tensor(tKsK_.data(), unsqueeze<2>(layout<0>(tKsK_.layout())));
Tensor tVgV = make_tensor(tVgV_.data(), unsqueeze<2>(layout<0>(tVgV_.layout())));
Tensor tVsV = make_tensor(tVsV_.data(), unsqueeze<2>(layout<0>(tVsV_.layout())));
if (block_table != nullptr) { if (block_table != nullptr) {
tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size, tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
...@@ -708,8 +713,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -708,8 +713,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
+ row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
Shape<Int<kBlockN>, Int<kHeadDim>>{}, Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.vnew_row_stride, _1{})); make_stride(params.vnew_row_stride, _1{}));
Tensor tKgKnew = gmem_thr_copy_KV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV_new;
Tensor tVgVnew = gmem_thr_copy_KV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) auto gmem_thr_copy_KV_new = gmem_tiled_copy_KV_new.get_thread_slice(tidx);
Tensor tKgKnew_ = gmem_thr_copy_KV_new.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
Tensor tVgVnew_ = gmem_thr_copy_KV_new.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
auto tKgKnew = make_tensor(tKgKnew_.data(), unsqueeze<2>(layout<0>(tKgKnew_.layout())));
auto tVgVnew = make_tensor(tVgVnew_.data(), unsqueeze<2>(layout<0>(tVgVnew_.layout())));
const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
auto tKgK_data = tKgK.data(); auto tKgK_data = tKgK.data();
......
...@@ -323,7 +323,6 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const ...@@ -323,7 +323,6 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const
const int* block_table, const int page_stride, const int row_stride) { const int* block_table, const int page_stride, const int row_stride) {
constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow; constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread; constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kBlockN = Kernel_traits::kBlockN;
const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread; const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
...@@ -345,6 +344,14 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const ...@@ -345,6 +344,14 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// Return a copy of layout `l` with an extra mode inserted at position N whose
// shape is 1 and whose stride is 0, i.e. a broadcastable singleton dimension.
// The input layout is unchanged; all other modes keep their shape and stride.
template <int N, class Shape, class Stride>
__forceinline__ __device__ constexpr auto unsqueeze(Layout<Shape, Stride> l) {
    auto padded_shape  = insert<N>(l.shape(),  Int<1>{});   // size-1 mode at N
    auto padded_stride = insert<N>(l.stride(), Int<0>{});   // stride 0: broadcast
    return make_layout(padded_shape, padded_stride);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true, template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3> typename Engine2, typename Layout2, typename Engine3, typename Layout3>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment