# test_tilelang_debug_print.py
# type: ignore

import tilelang
import tilelang.testing
import tilelang.language as T


def debug_print_buffer(M=16, N=16, dtype="float16"):
    """Compile and run a kernel that ``T.print``s an entire shared buffer.

    The kernel launches a 4 x 4 x 2 grid with ``128 * 2`` threads per block.
    Each block allocates an ``M`` x ``N`` buffer with ``T.alloc_shared`` and
    prints it. The buffer is never written here, so the printed values are
    whatever the allocation happens to hold. Nothing is asserted: the check
    is that printing a buffer of ``dtype`` compiles and runs.

    Args:
        M: Row count of the printed buffer (and of the unused ``Q`` argument).
        N: Column count of the printed buffer (and of the unused ``Q`` argument).
        dtype: Element type string passed to ``T.Tensor`` and ``T.alloc_shared``.
    """

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        # Q is never read; it only gives the prim_func a tensor parameter.
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)
            T.print(shared_buf)

    # Unlike the other helpers in this file, this one selects the tvm_ffi
    # execution backend explicitly.
    jit_kernel = tilelang.compile(program, target="cuda", execution_backend="tvm_ffi")
    # NOTE(review): presumably launches the compiled kernel once; requires a CUDA device.
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_buffer():
    """Run ``debug_print_buffer`` once for each dtype below, in order."""
    dtypes = (
        # Boolean and signed integers.
        'bool',
        'int8',
        'int16',
        'int32',
        'int64',
        # Unsigned integers.
        'uint8',
        'uint16',
        'uint32',
        'uint64',
        # Standard floating point.
        'float16',
        'float32',
        'float64',
        'bfloat16',
        # 8-bit floating point variants.
        'float8_e4m3',
        'float8_e4m3fn',
        'float8_e4m3fnuz',
        'float8_e5m2',
        'float8_e5m2fnuz',
    )
    for dtype in dtypes:
        debug_print_buffer(dtype=dtype)


def debug_print_buffer_conditional(M=16, N=16, dtype="float16"):
    """Compile and run a kernel that prints a shared buffer from one block only.

    The kernel launches a 4 x 4 x 2 grid with ``128 * 2`` threads per block.
    Every block allocates an ``M`` x ``N`` buffer with ``T.alloc_shared``, but
    the ``T.print`` is guarded so that only block ``(0, 0, 0)`` reaches it. The
    buffer is never written, so its printed contents are arbitrary.

    Args:
        M: Row count of the printed buffer (and of the unused ``Q`` argument).
        N: Column count of the printed buffer (and of the unused ``Q`` argument).
        dtype: Element type string; defaults to ``"float16"``, the value
            previously hard-coded here.
    """

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)

            # Restrict output to the first block of the grid.
            if bx == 0 and by == 0 and bz == 0:
                T.print(shared_buf)

    jit_kernel = tilelang.compile(program, target="cuda")
    # NOTE(review): presumably launches the compiled kernel once; requires a CUDA device.
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_buffer_conditional():
    """Print a 16 x 16 shared buffer from block (0, 0, 0) only."""
    debug_print_buffer_conditional(M=16, N=16)


def debug_print_value_conditional(M=16, N=16):
    """Compile and run a kernel that prints a scalar expression from one thread.

    The kernel launches a 4 x 4 x 2 grid with ``128 * 2`` threads per block and
    prints ``bx + by + bz``. The print is guarded by ``tid == 0``, where
    ``tid = T.get_thread_binding()``. Presumably that leaves one printing
    thread per block (TODO confirm which binding ``get_thread_binding`` returns).

    Args:
        M: Row count of the unused ``Q`` argument.
        N: Column count of the unused ``Q`` argument.
    """
    dtype = "float16"

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            tid = T.get_thread_binding()
            if tid == 0:
                T.print(bx + by + bz)

    jit_kernel = tilelang.compile(program, target="cuda")
    # NOTE(review): presumably launches the compiled kernel once; requires a CUDA device.
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_value_conditional():
    """Print ``bx + by + bz`` from the ``tid == 0`` thread of each block."""
    debug_print_value_conditional(M=16, N=16)


def debug_print_register_files(M=16, N=16, dtype="float16"):
    """Compile and run a kernel that prints a fragment buffer element by element.

    The kernel launches a 4 x 4 x 2 grid with ``128 * 2`` threads per block.
    Each block allocates an ``M`` x ``N`` buffer with ``T.alloc_fragment`` and
    calls ``T.print`` on every element inside a ``T.Parallel(M, N)`` loop. The
    buffer is never written, so the printed values are arbitrary.

    Args:
        M: Row count of the printed fragment (and of the unused ``Q`` argument).
        N: Column count of the printed fragment (and of the unused ``Q`` argument).
        dtype: Element type string; defaults to ``"float16"``, the value
            previously hard-coded here.
    """

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            register_buf = T.alloc_fragment([M, N], dtype)
            # Print element-wise rather than the whole buffer at once.
            for i, j in T.Parallel(M, N):
                T.print(register_buf[i, j])

    jit_kernel = tilelang.compile(program, target="cuda")
    # NOTE(review): presumably launches the compiled kernel once; requires a CUDA device.
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_register_files():
    """Print every element of a 16 x 16 fragment buffer."""
    debug_print_register_files(M=16, N=16)


def debug_print_msg(M=16, N=16):
    """Compile and run a kernel that prints a scalar together with a message.

    This is the same kernel as ``debug_print_value_conditional``, except that
    ``T.print`` is also given ``msg="hello world"``. The value ``bx + by + bz``
    is printed only where ``T.get_thread_binding() == 0``.

    Args:
        M: Row count of the unused ``Q`` argument.
        N: Column count of the unused ``Q`` argument.
    """
    dtype = "float16"

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            tid = T.get_thread_binding()
            if tid == 0:
                T.print(bx + by + bz, msg="hello world")

    jit_kernel = tilelang.compile(program, target="cuda")
    # NOTE(review): presumably launches the compiled kernel once; requires a CUDA device.
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_msg():
    """Print ``bx + by + bz`` with an attached ``msg`` string."""
    debug_print_msg(M=16, N=16)


if __name__ == "__main__":
    # When executed directly as a script, delegate to tilelang's testing entry point.
    tilelang.testing.main()