Commit a4049ac8 authored by skrider, committed by Woosuk Kwon
Browse files

add print statements for debugging

parent 23e8fa5a
#include <cute/util/debug.hpp>
// KIN_PRINT(tag, statement)
//
// Debug helper: runs `statement` bracketed by "[kin:start:<tag>]" and
// "[kin:end:<tag>]" marker lines, but only on the thread for which
// cute::thread0() is true (presumably thread 0 of block 0 -- confirm against
// cute/util/debug.hpp), so the output is emitted once instead of once per thread.
//
//   tag        C string spliced into the marker lines via "%s".
//   statement  any expression/statement to execute between the markers,
//              typically a cute::print(...) or printf(...) call. The parameter
//              is variadic so statements containing top-level commas
//              (e.g. a call to a template with several arguments) are accepted.
//
// The expansion is written as `if (!cond) {} else { ... }` rather than a bare
// `if (cond) { ... }` so that an `else` following an invocation inside an
// unbraced `if` can never bind to the macro's hidden `if`. This form is valid
// both with and without a trailing semicolon at the call site, which keeps
// every existing invocation compiling unchanged.
#define KIN_PRINT(tag, ...) \
    if (!cute::thread0()) { \
    } else { \
        printf("[kin:start:%s]\n", tag); \
        __VA_ARGS__; \
        printf("\n[kin:end:%s]\n", tag); \
    }
// Prints the compile-time configuration exposed by a Kernel_traits type:
// its boolean flags, its integer tile/size constants, and the CuTe layout /
// copy / MMA types it declares. One "name : value" line is written per member.
//
// Template parameter:
//   Kernel_traits  type providing every static member and nested type that is
//                  referenced below.
//
// Marked __host__ __device__ because it is invoked from inside a __device__
// kernel body (through KIN_PRINT); without a device qualifier that call cannot
// compile. The function does no thread filtering itself, so callers are
// expected to guard it (e.g. with KIN_PRINT) to avoid one copy of the output
// per thread.
template<typename Kernel_traits>
__host__ __device__ void
print_traits() {
    // bool
    // Passing a bool to "%s" is undefined behaviour (it would be read as a
    // char*), so each flag is mapped to a string literal explicitly.
    printf("Kernel_traits::Share_Q_K_smem : %s\n", Kernel_traits::Share_Q_K_smem ? "true" : "false");
    printf("Kernel_traits::Is_Q_in_regs : %s\n", Kernel_traits::Is_Q_in_regs ? "true" : "false");
    // int
    // Integers need "%d", not "%s". The static_cast pins the vararg type to
    // int regardless of how the constant is declared, and turns the static
    // member into a prvalue so it is not odr-used.
    printf("Kernel_traits::kNWarps : %d\n", static_cast<int>(Kernel_traits::kNWarps));
    printf("Kernel_traits::kNThreads : %d\n", static_cast<int>(Kernel_traits::kNThreads));
    printf("Kernel_traits::kBlockM : %d\n", static_cast<int>(Kernel_traits::kBlockM));
    printf("Kernel_traits::kBlockN : %d\n", static_cast<int>(Kernel_traits::kBlockN));
    printf("Kernel_traits::kHeadDim : %d\n", static_cast<int>(Kernel_traits::kHeadDim));
    printf("Kernel_traits::kBlockKSmem : %d\n", static_cast<int>(Kernel_traits::kBlockKSmem));
    printf("Kernel_traits::kBlockKGmem : %d\n", static_cast<int>(Kernel_traits::kBlockKGmem));
    printf("Kernel_traits::kSwizzle : %d\n", static_cast<int>(Kernel_traits::kSwizzle));
    printf("Kernel_traits::kSmemQSize : %d\n", static_cast<int>(Kernel_traits::kSmemQSize));
    printf("Kernel_traits::kSmemKVSize : %d\n", static_cast<int>(Kernel_traits::kSmemKVSize));
    printf("Kernel_traits::kSmemSize : %d\n", static_cast<int>(Kernel_traits::kSmemSize));
    printf("Kernel_traits::kGmemElemsPerLoad : %d\n", static_cast<int>(Kernel_traits::kGmemElemsPerLoad));
    // cute object
    // These members are types, not values, so a type name cannot be handed to
    // print() directly; a value-initialised temporary is printed instead.
    // NOTE(review): assumes each of these types is default-constructible and
    // has a print() overload reachable via ADL -- confirm for MMA_Atom_Arch.
    printf("Kernel_traits::GmemLayoutAtom : "); print(typename Kernel_traits::GmemLayoutAtom{}); printf("\n");
    printf("Kernel_traits::GmemTiledCopyQKV : "); print(typename Kernel_traits::GmemTiledCopyQKV{}); printf("\n");
    printf("Kernel_traits::GmemTiledCopyO : "); print(typename Kernel_traits::GmemTiledCopyO{}); printf("\n");
    printf("Kernel_traits::SmemCopyAtom : "); print(typename Kernel_traits::SmemCopyAtom{}); printf("\n");
    printf("Kernel_traits::SmemCopyAtomTransposed : "); print(typename Kernel_traits::SmemCopyAtomTransposed{}); printf("\n");
    printf("Kernel_traits::MMA_Atom_Arch : "); print(typename Kernel_traits::MMA_Atom_Arch{}); printf("\n");
}
......@@ -18,6 +18,8 @@
#include "dropout.h"
#include "rotary.h"
#include "debug.h"
namespace flash {
using namespace cute;
......@@ -41,6 +43,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;
#if 0
KIN_PRINT("Kernel_traits", print_traits<Kernel_traits>());
#endif
auto seed_offset = at::cuda::philox::unpack(params.philox_args);
flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
......@@ -55,6 +60,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
#if 0
// const int sum_s_q;
// const int sum_s_k;
// const int actual_seqlen_q;
// const int seqlen_k_cache;
// const int actual_seqlen_k;
KIN_PRINT("binfo.sum_s_q", printf("%d", binfo.sum_s_q))
KIN_PRINT("binfo.sum_s_k", printf("%d", binfo.sum_s_k))
KIN_PRINT("binfo.actual_seqlen_q", printf("%d", binfo.actual_seqlen_q))
KIN_PRINT("binfo.seqlen_k_cache", printf("%d", binfo.seqlen_k_cache))
KIN_PRINT("binfo.actual_seqlen_k", printf("%d", binfo.actual_seqlen_k))
#endif
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
......@@ -136,10 +153,24 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
typename Kernel_traits::SmemLayoutKV{});
#if 1
KIN_PRINT("sK.layout()", print(sK.layout()))
KIN_PRINT("gK.layout()", print(gK.layout()))
KIN_PRINT("Share_Q_K_smem", printf("%d", Kernel_traits::Share_Q_K_smem))
#endif
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
#if 1
KIN_PRINT("sV.layout()", print(sV.layout()))
KIN_PRINT("sVt.layout()", print(sVt.layout()))
KIN_PRINT("sVtNoSwizzle.layout()", print(sVtNoSwizzle.layout()))
KIN_PRINT("Share_Q_K_smem", printf("%d", Kernel_traits::Share_Q_K_smem))
#endif
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
......@@ -150,16 +181,30 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
#if 1
KIN_PRINT("tKgK.layout()", print(tKgK.layout()))
KIN_PRINT("tKsK.layout()", print(tKsK.layout()))
#endif
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
#if 1
KIN_PRINT("tSrQ.layout()", print(tSrQ.layout()))
KIN_PRINT("tSrK.layout()", print(tSrK.layout()))
#endif
Tensor tSgS = thr_mma.partition_C(gP);
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
#if 1
KIN_PRINT("acc_o.layout()", print(acc_o.layout()))
#endif
//
// Copy Atom retiling
//
......@@ -168,11 +213,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
#if 1
KIN_PRINT("smem_thr_copy_Q.print_all()", smem_thr_copy_Q.print_all())
KIN_PRINT("tSsQ.layout()", print(tSsQ.layout()))
#endif
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
# if 1
KIN_PRINT("tSsK.layout()", print(tSsK.layout()))
#endif
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
......@@ -189,6 +241,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Construct identity layout for sQ and sK
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
#if 1
KIN_PRINT("cQ.layout()", print(cQ.layout()))
KIN_PRINT("cKV.layout()", print(cKV.layout()))
#endif
// Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K)
// if (cute::thread0()) {
// print(tScQ.layout()); printf("\n");
......@@ -205,10 +261,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
#if 1
KIN_PRINT("tQcQ.layout()", print(tQcQ.layout()))
KIN_PRINT("tKVcKV.layout()", print(tKVcKV.layout()))
#endif
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
#if 1
KIN_PRINT("tQpQ.layout()", print(tQpQ.layout()))
KIN_PRINT("tKVpKV.layout()", print(tKVpKV.layout()))
#endif
// Set predicates for k bounds
if (!Is_even_K) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment