Commit 14b190bc authored by skrider

reshape gmem copy

parent ac5e78a6
@@ -3,18 +3,18 @@
 #pragma once
-#define KIN_PRINT(tag, statement) \
+#define KIN_PRINT(statement) \
     if (thread0()) { \
-        printf("\n[kin:start:%s]\n", tag); \
+        printf("\n[kin:start:%s]\n", #statement); \
         statement; \
-        printf("\n[kin:end:%s]\n", tag); \
+        printf("\n[kin:end:%s]\n", #statement); \
     }
-#define KIN_PRINT_BOOL(tag, BOOL) \
+#define KIN_PRINT_BOOL(BOOL) \
     if (thread0()) { \
-        printf("\n[kin:start:%s]\n", tag); \
+        printf("\n[kin:start:%s]\n", #BOOL); \
         printf("%s", BOOL ? "true" : "false"); \
-        printf("\n[kin:end:%s]\n", tag); \
+        printf("\n[kin:end:%s]\n", #BOOL); \
     }
 template<typename Kernel_traits>
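The rewritten macros drop the explicit `tag` argument and instead stringize the wrapped expression with the preprocessor `#` operator, so the log tag can never drift out of sync with the code it labels. A minimal host-side sketch of the same idiom (the `do { ... } while (0)` wrapper and the `main` driver are illustrative, not part of this commit):

```cpp
#include <cstdio>

// Stringize the expression itself so the log tag always matches the code.
#define KIN_PRINT(statement)                      \
    do {                                          \
        printf("\n[kin:start:%s]\n", #statement); \
        statement;                                \
        printf("\n[kin:end:%s]\n", #statement);   \
    } while (0)

int main() {
    int x = 41;
    // Prints [kin:start:printf("%d", x + 1)], then 42, then the matching end tag.
    KIN_PRINT(printf("%d", x + 1));
    return 0;
}
```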
@@ -36,7 +36,17 @@ print_traits() {
     printf("Kernel_traits::kSmemQSize         : %d\n", Kernel_traits::kSmemQSize);
     printf("Kernel_traits::kSmemKVSize        : %d\n", Kernel_traits::kSmemKVSize);
     printf("Kernel_traits::kSmemSize          : %d\n", Kernel_traits::kSmemSize);
+    printf("Kernel_traits::kGmemRowsPerThread : %d\n", Kernel_traits::kGmemRowsPerThread);
     printf("Kernel_traits::kGmemElemsPerLoad  : %d\n", Kernel_traits::kGmemElemsPerLoad);
+    // cute objects
+    printf("Kernel_traits::GmemLayoutAtom     : ");
+    cute::print(Kernel_traits::GmemLayoutAtom());
+    printf("\n");
+    printf("Kernel_traits::GmemTiledCopyQKV   :\n");
+    cute::print(Kernel_traits::GmemTiledCopyQKV());
+    printf("\n");
 }
 template<typename BlockInfo>
...
@@ -44,7 +44,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
     constexpr int kNWarps = Kernel_traits::kNWarps;
 #if 1
-    KIN_PRINT("Kernel_traits", print_traits<Kernel_traits>());
+    KIN_PRINT(print_traits<Kernel_traits>());
 #endif
     auto seed_offset = at::cuda::philox::unpack(params.philox_args);
@@ -61,7 +61,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
 #if 1
-    KIN_PRINT("binfo", print_binfo(binfo))
+    KIN_PRINT(print_binfo(binfo))
 #endif
     const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
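For intuition on the `n_block_min` context line above: it is the index of the first K/V block that can still fall inside the sliding window of some query row in this M-block. With hypothetical values (not from this commit) kBlockM = kBlockN = 64, m_block = 4, actual_seqlen_q = actual_seqlen_k = 512, and window_size_left = 128, it evaluates to max(0, (256 + 512 - 512 - 128) / 64) = 2, so K/V blocks 0 and 1 lie entirely left of every query's window and are skipped.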
@@ -145,17 +145,17 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
                             typename Kernel_traits::SmemLayoutKV{});
 #if 1
-    KIN_PRINT("sK.layout()", print(sK.layout()))
-    KIN_PRINT("gK.layout()", print(gK.layout()))
+    KIN_PRINT(print(sK.layout()))
+    KIN_PRINT(print(gK.layout()))
 #endif
     Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
     Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
 #if 1
-    KIN_PRINT("sV.layout()", print(sV.layout()))
-    KIN_PRINT("sVt.layout()", print(sVt.layout()))
-    KIN_PRINT("sVtNoSwizzle.layout()", print(sVtNoSwizzle.layout()))
+    KIN_PRINT(print(sV.layout()))
+    KIN_PRINT(print(sVt.layout()))
+    KIN_PRINT(print(sVtNoSwizzle.layout()))
 #endif
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
@@ -168,8 +168,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
 #if 1
-    KIN_PRINT("tKgK.layout()", print(tKgK.layout()))
-    KIN_PRINT("tKsK.layout()", print(tKsK.layout()))
+    KIN_PRINT(print(tKgK.layout()))
+    KIN_PRINT(print(tKsK.layout()))
 #endif
     typename Kernel_traits::TiledMma tiled_mma;
@@ -178,15 +178,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tSrK = thr_mma.partition_fragment_B(sK);             // (MMA,MMA_N,MMA_K)
     Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);  // (MMA, MMA_K,MMA_N)
 #if 1
-    KIN_PRINT("tSrQ.layout()", print(tSrQ.layout()))
-    KIN_PRINT("tSrK.layout()", print(tSrK.layout()))
+    KIN_PRINT(print(tSrQ.layout()))
+    KIN_PRINT(print(tSrK.layout()))
 #endif
     Tensor tSgS = thr_mma.partition_C(gP);
     Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K
 #if 1
-    KIN_PRINT("acc_o.layout()", print(acc_o.layout()))
+    KIN_PRINT(print(acc_o.layout()))
 #endif
     //
@@ -196,12 +196,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
     auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
 #if 0
-    KIN_PRINT("fail", smem_thr_copy_Q.print_all());
+    KIN_PRINT(smem_thr_copy_Q.print_all());
 #endif
     // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
     Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
 #if 1
-    KIN_PRINT("tSsQ.layout()", print(tSsQ.layout()))
+    KIN_PRINT(print(tSsQ.layout()))
 #endif
     // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
@@ -209,7 +209,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
     Tensor tSsK = smem_thr_copy_K.partition_S(sK);
 #if 1
-    KIN_PRINT("tSsK.layout()", print(tSsK.layout()))
+    KIN_PRINT(print(tSsK.layout()))
 #endif
     auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
@@ -228,8 +228,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));   // (BLK_M,BLK_K) -> (blk_m,blk_k)
     Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
 #if 1
-    KIN_PRINT("cQ.layout()", print(cQ.layout()))
-    KIN_PRINT("cKV.layout()", print(cKV.layout()))
+    KIN_PRINT(print(cQ.layout()))
+    KIN_PRINT(print(cKV.layout()))
 #endif
     // Tensor tScQ = thr_mma.partition_A(cQ);  // (MMA,MMA_M,MMA_K)
     // if (cute::thread0()) {
@@ -252,10 +252,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
     Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
 #if 1
-    KIN_PRINT("tQcQ.layout()", print(tQcQ.layout()))
-    KIN_PRINT("tKVcKV.layout()", print(tKVcKV.layout()))
-    KIN_PRINT("tQpQ.layout()", print(tQpQ.layout()))
-    KIN_PRINT("tKVpKV.layout()", print(tKVpKV.layout()))
+    KIN_PRINT(print(tQcQ.layout()))
+    KIN_PRINT(print(tKVcKV.layout()))
+    KIN_PRINT(print(tQpQ.layout()))
+    KIN_PRINT(print(tKVpKV.layout()))
 #endif
     // Set predicates for k bounds
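The `cQ`/`cKV` identity tensors in this hunk hold their own coordinates rather than data; partitioning them with the same thread slice as the real tensors tells each thread which (row, col) of the tile it touches, and the `tQpQ`/`tKVpKV` predicate vectors are then filled by comparing those coordinates against the valid extent. A standalone sketch of the idea with toy shapes (the 4x8 tile and the 5-column bound are made up for illustration):

```cpp
#include <cstdio>
#include <cute/tensor.hpp>
using namespace cute;

int main() {
    // Identity tensor: entry (m, k) evaluates to the coordinate (m, k) itself.
    Tensor cQ = make_identity_tensor(make_shape(Int<4>{}, Int<8>{}));

    // Predicate vector over the k dimension, mirroring the kernel's tQpQ.
    Tensor pQ = make_tensor<bool>(make_shape(Int<8>{}));
    for (int k = 0; k < size(pQ); ++k) {
        // Keep only columns whose index is in bounds (the bound of 5 is a toy value).
        pQ(k) = get<1>(cQ(0, k)) < 5;
    }

    print(cQ(2, 3));  // prints the coordinate (2,3), not stored data
    printf("\n");
    for (int k = 0; k < size(pQ); ++k) { printf("%d", int(pQ(k))); }  // 11111000
    printf("\n");
    return 0;
}
```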
@@ -537,14 +537,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
     constexpr int kNWarps = Kernel_traits::kNWarps;
 #if 1
-    KIN_PRINT("Kernel_traits", print_traits<Kernel_traits>())
-    KIN_PRINT_BOOL("Is_causal", Is_causal)
-    KIN_PRINT_BOOL("Is_local", Is_local)
-    KIN_PRINT_BOOL("Has_alibi", Has_alibi)
-    KIN_PRINT_BOOL("Is_even_MN", Is_even_MN)
-    KIN_PRINT_BOOL("Is_even_K", Is_even_K)
-    KIN_PRINT_BOOL("Split", Split)
-    KIN_PRINT_BOOL("Append_KV", Append_KV)
+    KIN_PRINT(print_traits<Kernel_traits>())
+    KIN_PRINT_BOOL(Is_causal)
+    KIN_PRINT_BOOL(Is_local)
+    KIN_PRINT_BOOL(Has_alibi)
+    KIN_PRINT_BOOL(Is_even_MN)
+    KIN_PRINT_BOOL(Is_even_K)
+    KIN_PRINT_BOOL(Split)
+    KIN_PRINT_BOOL(Append_KV)
 #endif
     using GmemTiledCopyO = std::conditional_t<
@@ -559,7 +559,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
 #if 1
-    KIN_PRINT("binfo", print_binfo(binfo))
+    KIN_PRINT(print_binfo(binfo))
 #endif
     const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
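The `n_blocks_per_split` context line is a nested ceiling division: ceil(seqlen_k / kBlockN) gives the total number of K/V blocks, and a second round-up spreads them over the splits. With illustrative numbers (not from this commit), seqlen_k = 300 and kBlockN = 128 give ceil(300/128) = 3 blocks; for num_n_splits = 2, each split takes ceil(3/2) = 2 blocks, so the last split covers only 1.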
@@ -649,25 +649,34 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
     Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
 #if 1
-    KIN_PRINT("sK.layout()", print(sK.layout()))
-    KIN_PRINT("gK.layout()", print(gK.layout()))
-    KIN_PRINT("sV.layout()", print(sV.layout()))
-    KIN_PRINT("sVt.layout()", print(sVt.layout()))
-    KIN_PRINT("sVtNoSwizzle.layout()", print(sVtNoSwizzle.layout()))
+    KIN_PRINT(print(sK.layout()))
+    KIN_PRINT(print(gK.layout()))
+    KIN_PRINT(print(sV.layout()))
+    KIN_PRINT(print(sVt.layout()))
+    KIN_PRINT(print(sVtNoSwizzle.layout()))
 #endif
-    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
-    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q;
+    auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV;
+    auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
-    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
-    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
-    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
-    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
-    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
-    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
+    Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
+    Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
+    Tensor tKgK = gmem_thr_copy_KV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
+    Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
+    Tensor tVgV = gmem_thr_copy_KV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
+    Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
 #if 1
-    KIN_PRINT("tKgK.layout()", print(tKgK.layout()))
-    KIN_PRINT("tKsK.layout()", print(tKsK.layout()))
+    KIN_PRINT(print(tKgK.layout()))
+    KIN_PRINT(print(tKsK.layout()))
+#endif
+#if 1
+    fill(tVgV, 1.f * ((Element) tidx));
+    __syncthreads();
+    KIN_PRINT(print_tensor(gV))
 #endif
     typename Kernel_traits::TiledMma tiled_mma;
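The second `#if 1` block added above is a partition-visualization probe for the new paged copy: every thread stamps its own `tidx` into the slice of `gV` it owns through `tVgV`, and after the `__syncthreads()` thread 0 prints the whole `gV` tile, so the value visible at each position shows which thread `GmemTiledCopyQKVPaged` assigned to it and whether each thread's elements really form one contiguous tile. Note that it writes to global memory, so it is strictly a debug aid.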
@@ -676,13 +685,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor tSrK = thr_mma.partition_fragment_B(sK);             // (MMA,MMA_N,MMA_K)
     Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);  // (MMA, MMA_K,MMA_N)
 #if 1
-    KIN_PRINT("tSrQ.layout()", print(tSrQ.layout()))
-    KIN_PRINT("tSrK.layout()", print(tSrK.layout()))
+    KIN_PRINT(print(tSrQ.layout()))
+    KIN_PRINT(print(tSrK.layout()))
 #endif
     Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K
 #if 1
-    KIN_PRINT("acc_o.layout()", print(acc_o.layout()))
+    KIN_PRINT(print(acc_o.layout()))
 #endif
     //
@@ -693,14 +702,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
     Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
 #if 1
-    KIN_PRINT("tSsQ.layout()", print(tSsQ.layout()))
+    KIN_PRINT(print(tSsQ.layout()))
 #endif
     auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
     auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
     Tensor tSsK = smem_thr_copy_K.partition_S(sK);
 #if 1
-    KIN_PRINT("tSsK.layout()", print(tSsK.layout()))
+    KIN_PRINT(print(tSsK.layout()))
 #endif
     auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
@@ -718,22 +727,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));   // (BLK_M,BLK_K) -> (blk_m,blk_k)
     Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
 #if 1
-    KIN_PRINT("cQ.layout()", print(cQ.layout()))
-    KIN_PRINT("cKV.layout()", print(cKV.layout()))
+    KIN_PRINT(print(cQ.layout()))
+    KIN_PRINT(print(cKV.layout()))
 #endif
     // Repeat the partitioning with identity layouts
-    Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);    // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
-    Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+    Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);      // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+    Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);  // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
     // Allocate predicate tensors for k
     Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
     Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
 #if 1
-    KIN_PRINT("tQcQ.layout()", print(tQcQ.layout()))
-    KIN_PRINT("tKVcKV.layout()", print(tKVcKV.layout()))
-    KIN_PRINT("tQpQ.layout()", print(tQpQ.layout()))
-    KIN_PRINT("tKVpKV.layout()", print(tKVpKV.layout()))
+    KIN_PRINT(print(tQcQ.layout()))
+    KIN_PRINT(print(tKVcKV.layout()))
+    KIN_PRINT(print(tQpQ.layout()))
+    KIN_PRINT(print(tKVpKV.layout()))
 #endif
     // Set predicates for k bounds
@@ -792,8 +801,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                                               + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                 make_stride(params.vnew_row_stride, _1{}));
-    Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
-    Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
+    Tensor tKgKnew = gmem_thr_copy_KV.partition_S(gKnew);   // (KCPY, KCPY_N, KCPY_K)
+    Tensor tVgVnew = gmem_thr_copy_KV.partition_S(gVnew);   // (VCPY, VCPY_N, VCPY_K)
     const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
     auto tKgK_data = tKgK.data();
@@ -853,7 +862,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Read Q from gmem to smem, optionally apply rotary embedding.
     if (!Append_KV || params.rotary_dim == 0) {
         // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-        flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+        flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
                                            binfo.actual_seqlen_q - m_block * kBlockM);
     } else {
         const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
@@ -890,7 +899,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV,
                                        binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();
@@ -935,11 +944,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
             tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
         }
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
     } else {
         // Clear the smem tiles to account for predicated off loads
         flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-            gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+            gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
         );
     }
     cute::cp_async_fence();
@@ -970,7 +979,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                 const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
                 tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
             }
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -1013,7 +1022,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                 const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
                 tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
             }
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
             cute::cp_async_fence();
             flash::gemm(
@@ -1034,7 +1043,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                 const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
                 tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
             }
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
...
@@ -131,6 +131,17 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
+    // How many rows each thread has to fetch from.
+    static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);
+    // Here we assign a contiguous tile to each thread, rather than a 1x8 row every
+    // (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread
+    // do not cross a page boundary. This way, each thread need only fetch 1 page index per
+    // mainloop iteration. Rudimentary testing shows no slowdown.
+    using GmemTiledCopyQKVPaged = decltype(
+        make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
+                        GmemLayoutAtom{},
+                        Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));
     using GmemTiledCopyO = decltype(
         make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                         GmemLayoutAtom{},
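To make the traits arithmetic concrete, here is the row math under one plausible configuration; the numbers below (128 threads, 128-row N-blocks, 8 threads per row) are assumptions for illustration, not values pinned by this commit:

```cpp
#include <cstdio>

// Assumed values for a common fp16 configuration (illustrative only).
constexpr int kNThreads          = 128;
constexpr int kBlockN            = 128;
constexpr int kGmemThreadsPerRow = 8;  // 8 threads x 8 elems = 64-wide rows

// Each pass of the thread layout covers kNThreads / kGmemThreadsPerRow rows.
constexpr int kRowsPerPass       = kNThreads / kGmemThreadsPerRow;  // 16
constexpr int kGmemRowsPerThread = kBlockN / kRowsPerPass;          // 8

int main() {
    // Old val layout (Shape<_1, _8>): thread t reads row t/8, then t/8 + 16,
    // t/8 + 32, ... -- 8 rows scattered across the 128-row tile, which can
    // straddle several KV-cache pages and cost one page lookup per row.
    // New val layout (Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>):
    // thread t reads kGmemRowsPerThread consecutive rows as one contiguous
    // tile, so it resolves a single page-table entry per mainloop iteration.
    printf("rows per pass: %d, rows per thread: %d\n",
           kRowsPerPass, kGmemRowsPerThread);
    return 0;
}
```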