debug.h 2.49 KB
Newer Older
skrider's avatar
skrider committed
1
#include <cute/util/debug.hpp>
skrider's avatar
skrider committed
2
3
4
#include "block_info.h"

#pragma once
skrider's avatar
skrider committed
5

skrider's avatar
skrider committed
6
// Debug helper: executes `statement` only where cute::thread0() returns true,
// bracketed by "[kin:start:<text>]" and "[kin:end:<text>]" marker lines, where
// <text> is the stringified macro argument.
//
// The expansion is shaped `if (!cond) {} else { ... }` rather than a bare
// `if (cond) { ... }` so that an `else` written after the macro at the call
// site can never attach to the macro's own `if`. A trailing semicolon at the
// call site remains optional, as before.
//
// thread0() is qualified with cute:: so the macro does not depend on the call
// site having a `using namespace cute` in effect. `statement` itself is still
// looked up in the caller's scope.
#define KIN_PRINT(statement) \
    if (!cute::thread0()) { \
    } else { \
        printf("\n[kin:start:%s]\n", #statement); \
        statement; \
        printf("\n[kin:end:%s]\n", #statement); \
    }

skrider's avatar
skrider committed
13
// Debug helper: where cute::thread0() returns true, prints "true" or "false"
// for the expression `BOOL`, bracketed by "[kin:start:<text>]" and
// "[kin:end:<text>]" marker lines, where <text> is the stringified argument.
//
// The argument is parenthesized before being used as the condition of `?:`,
// so arguments that themselves contain low-precedence operators (a nested
// conditional, an assignment) are evaluated as a whole.
//
// The `if (!cond) {} else { ... }` shape keeps a following `else` at the call
// site from attaching to the macro's own `if`; a trailing semicolon at the
// call site remains optional. thread0() is qualified with cute:: so the macro
// does not rely on a `using namespace cute` at the call site.
#define KIN_PRINT_BOOL(BOOL) \
    if (!cute::thread0()) { \
    } else { \
        printf("\n[kin:start:%s]\n", #BOOL); \
        printf("%s", (BOOL) ? "true" : "false"); \
        printf("\n[kin:end:%s]\n", #BOOL); \
    }

skrider's avatar
skrider committed
20
// Prints selected static members of `Kernel_traits` with device-side printf.
//
// Output, in order:
//   - two boolean members, rendered as "true"/"false";
//   - thirteen members formatted with %d (they must therefore be passable to
//     printf as int);
//   - a default-constructed GmemLayoutAtom and GmemTiledCopyQKV, rendered via
//     cute::print.
//
// There is no thread guard in this function: every thread that calls it
// prints. Call it from inside a single-thread guard (for example KIN_PRINT)
// to get one copy of the output.
template<typename Kernel_traits>
__forceinline__ __device__ void
print_traits() {
    // bool
    printf("Kernel_traits::Share_Q_K_smem    : %s\n", Kernel_traits::Share_Q_K_smem ? "true" : "false");
    printf("Kernel_traits::Is_Q_in_regs      : %s\n", Kernel_traits::Is_Q_in_regs ? "true" : "false");

    // int
    printf("Kernel_traits::kNWarps           : %d\n", Kernel_traits::kNWarps );
    printf("Kernel_traits::kNThreads         : %d\n", Kernel_traits::kNThreads );
    printf("Kernel_traits::kBlockM           : %d\n", Kernel_traits::kBlockM );
    printf("Kernel_traits::kBlockN           : %d\n", Kernel_traits::kBlockN );
    printf("Kernel_traits::kHeadDim          : %d\n", Kernel_traits::kHeadDim );
    printf("Kernel_traits::kBlockKSmem       : %d\n", Kernel_traits::kBlockKSmem );
    printf("Kernel_traits::kBlockKGmem       : %d\n", Kernel_traits::kBlockKGmem );
    printf("Kernel_traits::kSwizzle          : %d\n", Kernel_traits::kSwizzle );
    printf("Kernel_traits::kSmemQSize        : %d\n", Kernel_traits::kSmemQSize );
    printf("Kernel_traits::kSmemKVSize       : %d\n", Kernel_traits::kSmemKVSize );
    printf("Kernel_traits::kSmemSize         : %d\n", Kernel_traits::kSmemSize );
    printf("Kernel_traits::kGmemRowsPerThread: %d\n", Kernel_traits::kGmemRowsPerThread );
    printf("Kernel_traits::kGmemElemsPerLoad : %d\n", Kernel_traits::kGmemElemsPerLoad );

    // cute object
    // `typename` is required here: these are dependent qualified names that
    // denote types, and without the keyword the expression is parsed as a
    // call to a static member function.
    // NOTE(review): assumes both members are type aliases in Kernel_traits;
    // confirm against the traits definition.
    printf("Kernel_traits::GmemLayoutAtom    : ");
    cute::print(typename Kernel_traits::GmemLayoutAtom());
    printf("\n");
    printf("Kernel_traits::GmemTiledCopyQKV  :\n");
    cute::print(typename Kernel_traits::GmemTiledCopyQKV());
    printf("\n");
}

skrider's avatar
skrider committed
52
53
54
55
56
57
58
59
60
// Prints five members of `binfo` (sum_s_q, sum_s_k, actual_seqlen_q,
// seqlen_k_cache, actual_seqlen_k), one per line, each formatted with %d.
//
// No thread guard is applied here: every thread that calls this function
// emits all five lines.
template<typename BlockInfo>
__forceinline__ __device__ void
print_binfo(const BlockInfo& binfo) {
    // Bind each field once so the print statements below read as a table.
    const auto& sum_q        = binfo.sum_s_q;
    const auto& sum_k        = binfo.sum_s_k;
    const auto& actual_len_q = binfo.actual_seqlen_q;
    const auto& cache_len_k  = binfo.seqlen_k_cache;
    const auto& actual_len_k = binfo.actual_seqlen_k;

    printf("binfo.sum_s_q           : %d\n", sum_q);
    printf("binfo.sum_s_k           : %d\n", sum_k);
    printf("binfo.actual_seqlen_q   : %d\n", actual_len_q);
    printf("binfo.seqlen_k_cache    : %d\n", cache_len_k);
    printf("binfo.actual_seqlen_k   : %d\n", actual_len_k);
}