# test_tilelang_debug_print.py
# type: ignore

import tilelang
import tilelang.testing
import tilelang.language as T


def debug_print_buffer(M=16, N=16, dtype="float16"):
    """Compile and run a kernel that calls ``T.print`` on a shared buffer.

    The buffer is allocated with ``T.alloc_shared`` and is never written
    before it is printed, so the printed values carry no meaning. This
    helper only checks that printing a buffer of ``dtype`` compiles for the
    "cuda" target and runs once without raising.

    Args:
        M: Number of rows of the printed buffer (and of ``Q``).
        N: Number of columns of the printed buffer (and of ``Q``).
        dtype: Element type string, e.g. "float16" or "int8".
    """

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        # Q is not read by the kernel body; it only gives the prim_func a
        # tensor argument.
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)
            T.print(shared_buf)

    # NOTE(review): only this helper pins execution_backend="tvm_ffi"; the
    # other helpers in this file use the default backend -- confirm intended.
    jit_kernel = tilelang.compile(program, target="cuda", execution_backend="tvm_ffi")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_buffer():
    """Print a shared buffer once for each dtype listed below.

    Each dtype triggers a separate compile and single run through
    ``debug_print_buffer``; the order matches the original call sequence.
    """
    dtypes = (
        "bool",
        "int8",
        "int16",
        "int32",
        "int64",
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "float16",
        "float32",
        "float64",
        "bfloat16",
        "float8_e4m3",
        "float8_e4m3fn",
        "float8_e4m3fnuz",
        "float8_e5m2",
        "float8_e5m2fnuz",
    )
    for dtype in dtypes:
        debug_print_buffer(dtype=dtype)


def debug_print_buffer_conditional(M=16, N=16, dtype="float16"):
    """Compile and run a kernel that prints a shared buffer under a condition.

    ``T.print`` is only reached when ``bx``, ``by`` and ``bz`` are all 0. The
    buffer is never written before printing, so only the conditional print
    path is exercised, not the printed values.

    Args:
        M: Number of rows of the printed buffer (and of ``Q``).
        N: Number of columns of the printed buffer (and of ``Q``).
        dtype: Element type string; defaults to "float16" as before.
    """

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        # Q is not read by the kernel body; it only gives the prim_func a
        # tensor argument.
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)

            # Restrict printing to the (0, 0, 0) launch index rather than
            # every (bx, by, bz) combination.
            if bx == 0 and by == 0 and bz == 0:
                T.print(shared_buf)

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_buffer_conditional():
    """Smoke-test the conditional shared-buffer print with a 16 x 16 buffer."""
    debug_print_buffer_conditional(M=16, N=16)


def debug_print_value_conditional(M=16, N=16):
    """Compile and run a kernel that prints a scalar expression conditionally.

    The printed value is ``bx + by + bz``. The print is guarded by
    ``T.get_thread_binding() == 0``, so presumably only the first thread of
    each launch index prints (confirm against ``T.get_thread_binding``'s
    default dimension).

    Args:
        M: Number of rows of the (unused) tensor argument ``Q``.
        N: Number of columns of the (unused) tensor argument ``Q``.
    """
    dtype = "float16"

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        # Q is not read by the kernel body; it only gives the prim_func a
        # tensor argument.
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            tid = T.get_thread_binding()
            if tid == 0:
                T.print(bx + by + bz)

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_value_conditional():
    """Smoke-test printing a scalar from a single thread binding."""
    debug_print_value_conditional(M=16, N=16)


def debug_print_register_files(M=16, N=16, dtype="float16"):
    """Compile and run a kernel that prints every element of a fragment buffer.

    The buffer comes from ``T.alloc_fragment`` and each ``[i, j]`` element is
    printed inside a ``T.Parallel(M, N)`` loop. The buffer is never written
    first, so only the print path is exercised, not the values.

    Args:
        M: Number of rows of the fragment (and of ``Q``).
        N: Number of columns of the fragment (and of ``Q``).
        dtype: Element type string; defaults to "float16" as before.
    """

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        # Q is not read by the kernel body; it only gives the prim_func a
        # tensor argument.
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            register_buf = T.alloc_fragment([M, N], dtype)
            for i, j in T.Parallel(M, N):
                T.print(register_buf[i, j])

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_register_files():
    """Smoke-test element-wise printing of a 16 x 16 fragment buffer."""
    debug_print_register_files(M=16, N=16)


def debug_print_msg(M=16, N=16):
    """Compile and run a kernel that prints a value with an attached message.

    Same guard as ``debug_print_value_conditional`` (thread binding == 0),
    but passes ``msg="hello world"`` to ``T.print``.

    Args:
        M: Number of rows of the (unused) tensor argument ``Q``.
        N: Number of columns of the (unused) tensor argument ``Q``.
    """
    dtype = "float16"

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        # Q is not read by the kernel body; it only gives the prim_func a
        # tensor argument.
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            tid = T.get_thread_binding()
            if tid == 0:
                T.print(bx + by + bz, msg="hello world")

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_msg():
    """Smoke-test T.print with a custom msg prefix."""
    debug_print_msg(M=16, N=16)


if __name__ == "__main__":
    # Allow running this file directly; presumably tilelang.testing.main()
    # collects and runs the test_* functions in this module.
    tilelang.testing.main()