Commit ac5e78a6 authored by skrider's avatar skrider
Browse files

add print statements for debugging

parent 8efeb7f5
#pragma once

#include <cstdio>

#include <cute/util/debug.hpp>

#include "block_info.h"
// Run `statement` on a single thread (thread 0 of block 0), bracketed by
// greppable [kin:start:...] / [kin:end:...] tags in the kernel log.
// Expands to a plain `if` block (no do/while) because existing call sites
// omit the trailing semicolon; avoid using it as the body of an if/else.
#define KIN_PRINT(tag, statement)          \
    if (cute::thread0()) {                 \
        printf("\n[kin:start:%s]\n", tag); \
        statement;                         \
        printf("\n[kin:end:%s]\n", tag);   \
    }
// Print a boolean value as "true"/"false" on thread 0 of block 0, bracketed
// by greppable [kin:start:...] / [kin:end:...] tags in the kernel log.
// thread0() is fully qualified so the macro does not rely on a
// `using namespace cute;` at the expansion site; BOOL is parenthesized so
// compound expressions expand correctly inside the conditional operator.
#define KIN_PRINT_BOOL(tag, BOOL)                 \
    if (cute::thread0()) {                        \
        printf("\n[kin:start:%s]\n", tag);        \
        printf("%s", (BOOL) ? "true" : "false");  \
        printf("\n[kin:end:%s]\n", tag);          \
    }
// Dump the compile-time configuration of a flash-attention Kernel_traits
// struct to stdout. Debug-only helper: call from a single thread (e.g. via
// KIN_PRINT) so output from concurrent threads does not interleave.
template <typename Kernel_traits>
__forceinline__ __device__ void
print_traits() {
    // bool members: convert to a string literal — passing a raw bool
    // through "%s" is undefined behavior.
    printf("Kernel_traits::Share_Q_K_smem : %s\n", Kernel_traits::Share_Q_K_smem ? "true" : "false");
    printf("Kernel_traits::Is_Q_in_regs : %s\n", Kernel_traits::Is_Q_in_regs ? "true" : "false");
    // int members: "%d", not "%s".
    printf("Kernel_traits::kNWarps : %d\n", Kernel_traits::kNWarps);
    printf("Kernel_traits::kNThreads : %d\n", Kernel_traits::kNThreads);
    printf("Kernel_traits::kBlockM : %d\n", Kernel_traits::kBlockM);
    printf("Kernel_traits::kBlockN : %d\n", Kernel_traits::kBlockN);
    printf("Kernel_traits::kHeadDim : %d\n", Kernel_traits::kHeadDim);
    printf("Kernel_traits::kBlockKSmem : %d\n", Kernel_traits::kBlockKSmem);
    printf("Kernel_traits::kBlockKGmem : %d\n", Kernel_traits::kBlockKGmem);
    printf("Kernel_traits::kSwizzle : %d\n", Kernel_traits::kSwizzle);
    printf("Kernel_traits::kSmemQSize : %d\n", Kernel_traits::kSmemQSize);
    printf("Kernel_traits::kSmemKVSize : %d\n", Kernel_traits::kSmemKVSize);
    printf("Kernel_traits::kSmemSize : %d\n", Kernel_traits::kSmemSize);
    printf("Kernel_traits::kGmemElemsPerLoad : %d\n", Kernel_traits::kGmemElemsPerLoad);
    // cute objects: these trait members are type aliases (not values) in
    // kernel_traits.h, so construct an instance for cute::print.
    // NOTE(review): assumes each alias is default-constructible — cute
    // Layout/TiledCopy/Copy_Atom/MMA_Atom types are; confirm for custom traits.
    printf("Kernel_traits::GmemLayoutAtom : "); print(typename Kernel_traits::GmemLayoutAtom{}); printf("\n");
    printf("Kernel_traits::GmemTiledCopyQKV : "); print(typename Kernel_traits::GmemTiledCopyQKV{}); printf("\n");
    printf("Kernel_traits::GmemTiledCopyO : "); print(typename Kernel_traits::GmemTiledCopyO{}); printf("\n");
    printf("Kernel_traits::SmemCopyAtom : "); print(typename Kernel_traits::SmemCopyAtom{}); printf("\n");
    printf("Kernel_traits::SmemCopyAtomTransposed : "); print(typename Kernel_traits::SmemCopyAtomTransposed{}); printf("\n");
    printf("Kernel_traits::MMA_Atom_Arch : "); print(typename Kernel_traits::MMA_Atom_Arch{}); printf("\n");
}
// Dump the per-batch sequence-length bookkeeping of a BlockInfo to stdout.
// Debug-only helper: call from a single thread (e.g. via KIN_PRINT) so
// output from concurrent threads does not interleave.
// NOTE(review): prints every field with %d, so all five BlockInfo members
// are assumed to be ints — confirm against block_info.h.
template<typename BlockInfo>
__forceinline__ __device__ void
print_binfo(const BlockInfo& binfo) {
// presumably cumulative (varlen) start offsets of this batch's Q/K — confirm
printf("binfo.sum_s_q : %d\n", binfo.sum_s_q);
printf("binfo.sum_s_k : %d\n", binfo.sum_s_k);
printf("binfo.actual_seqlen_q : %d\n", binfo.actual_seqlen_q);
// cached K length vs. total K length (cache + appended) — TODO confirm
printf("binfo.seqlen_k_cache : %d\n", binfo.seqlen_k_cache);
printf("binfo.actual_seqlen_k : %d\n", binfo.actual_seqlen_k);
}
......@@ -43,7 +43,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;
#if 0
#if 1
KIN_PRINT("Kernel_traits", print_traits<Kernel_traits>());
#endif
......@@ -60,17 +60,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
#if 0
// const int sum_s_q;
// const int sum_s_k;
// const int actual_seqlen_q;
// const int seqlen_k_cache;
// const int actual_seqlen_k;
KIN_PRINT("binfo.sum_s_q", printf("%d", binfo.sum_s_q))
KIN_PRINT("binfo.sum_s_k", printf("%d", binfo.sum_s_k))
KIN_PRINT("binfo.actual_seqlen_q", printf("%d", binfo.actual_seqlen_q))
KIN_PRINT("binfo.seqlen_k_cache", printf("%d", binfo.seqlen_k_cache))
KIN_PRINT("binfo.actual_seqlen_k", printf("%d", binfo.actual_seqlen_k))
#if 1
KIN_PRINT("binfo", print_binfo(binfo))
#endif
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
......@@ -153,22 +144,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
typename Kernel_traits::SmemLayoutKV{});
#if 1
KIN_PRINT("sK.layout()", print(sK.layout()))
KIN_PRINT("gK.layout()", print(gK.layout()))
KIN_PRINT("Share_Q_K_smem", printf("%d", Kernel_traits::Share_Q_K_smem))
#endif
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
#if 1
KIN_PRINT("sV.layout()", print(sV.layout()))
KIN_PRINT("sVt.layout()", print(sVt.layout()))
KIN_PRINT("sVtNoSwizzle.layout()", print(sVtNoSwizzle.layout()))
KIN_PRINT("Share_Q_K_smem", printf("%d", Kernel_traits::Share_Q_K_smem))
#endif
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
......@@ -180,7 +167,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
#if 1
KIN_PRINT("tKgK.layout()", print(tKgK.layout()))
KIN_PRINT("tKsK.layout()", print(tKsK.layout()))
......@@ -191,7 +177,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
#if 1
KIN_PRINT("tSrQ.layout()", print(tSrQ.layout()))
KIN_PRINT("tSrK.layout()", print(tSrK.layout()))
......@@ -200,7 +185,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tSgS = thr_mma.partition_C(gP);
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
#if 1
KIN_PRINT("acc_o.layout()", print(acc_o.layout()))
#endif
......@@ -211,10 +195,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
#if 0
KIN_PRINT("fail", smem_thr_copy_Q.print_all());
#endif
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
#if 1
KIN_PRINT("smem_thr_copy_Q.print_all()", smem_thr_copy_Q.print_all())
KIN_PRINT("tSsQ.layout()", print(tSsQ.layout()))
#endif
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
......@@ -222,7 +208,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
# if 1
#if 1
KIN_PRINT("tSsK.layout()", print(tSsK.layout()))
#endif
......@@ -261,15 +247,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
#if 1
KIN_PRINT("tQcQ.layout()", print(tQcQ.layout()))
KIN_PRINT("tKVcKV.layout()", print(tKVcKV.layout()))
#endif
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
#if 1
KIN_PRINT("tQcQ.layout()", print(tQcQ.layout()))
KIN_PRINT("tKVcKV.layout()", print(tKVcKV.layout()))
KIN_PRINT("tQpQ.layout()", print(tQpQ.layout()))
KIN_PRINT("tKVpKV.layout()", print(tKVpKV.layout()))
#endif
......@@ -552,6 +536,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;
#if 1
KIN_PRINT("Kernel_traits", print_traits<Kernel_traits>())
KIN_PRINT_BOOL("Is_causal", Is_causal)
KIN_PRINT_BOOL("Is_local", Is_local)
KIN_PRINT_BOOL("Has_alibi", Has_alibi)
KIN_PRINT_BOOL("Is_even_MN", Is_even_MN)
KIN_PRINT_BOOL("Is_even_K", Is_even_K)
KIN_PRINT_BOOL("Split", Split)
KIN_PRINT_BOOL("Append_KV", Append_KV)
#endif
using GmemTiledCopyO = std::conditional_t<
!Split,
......@@ -564,6 +558,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
#if 1
KIN_PRINT("binfo", print_binfo(binfo))
#endif
const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
const int n_block_min = !Is_local
......@@ -645,13 +642,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{}));
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
#if 1
KIN_PRINT("sK.layout()", print(sK.layout()))
KIN_PRINT("gK.layout()", print(gK.layout()))
KIN_PRINT("sV.layout()", print(sV.layout()))
KIN_PRINT("sVt.layout()", print(sVt.layout()))
KIN_PRINT("sVtNoSwizzle.layout()", print(sVtNoSwizzle.layout()))
#endif
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
......@@ -662,14 +665,25 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
#if 1
KIN_PRINT("tKgK.layout()", print(tKgK.layout()))
KIN_PRINT("tKsK.layout()", print(tKsK.layout()))
#endif
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
#if 1
KIN_PRINT("tSrQ.layout()", print(tSrQ.layout()))
KIN_PRINT("tSrK.layout()", print(tSrK.layout()))
#endif
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
#if 1
KIN_PRINT("acc_o.layout()", print(acc_o.layout()))
#endif
//
// Copy Atom retiling
......@@ -678,10 +692,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
#if 1
KIN_PRINT("tSsQ.layout()", print(tSsQ.layout()))
#endif
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
#if 1
KIN_PRINT("tSsK.layout()", print(tSsK.layout()))
#endif
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
......@@ -697,6 +717,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Construct identity layout for sQ and sK
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
#if 1
KIN_PRINT("cQ.layout()", print(cQ.layout()))
KIN_PRINT("cKV.layout()", print(cKV.layout()))
#endif
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
......@@ -705,6 +729,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
#if 1
KIN_PRINT("tQcQ.layout()", print(tQcQ.layout()))
KIN_PRINT("tKVcKV.layout()", print(tKVcKV.layout()))
KIN_PRINT("tQpQ.layout()", print(tQpQ.layout()))
KIN_PRINT("tKVpKV.layout()", print(tKVpKV.layout()))
#endif
// Set predicates for k bounds
if (!Is_even_K) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment