# test_tilelang_debug_print.py
# type: ignore

import tilelang
import tilelang.testing
import tilelang.language as T


def debug_print_buffer(M=16, N=16, dtype="float16"):
    """Compile and run a kernel that prints an (M, N) shared buffer.

    The ``T.print(shared_buf)`` call is not guarded by any condition in this
    kernel. The buffer is never written here, so the printed values are
    whatever the allocation happens to contain; the caller only relies on
    compilation and ``profiler.run_once()`` completing without raising.

    Args:
        M: Number of rows of the printed buffer (and of ``Q``).
        N: Number of columns of the printed buffer (and of ``Q``).
        dtype: Element type of the printed buffer and of ``Q``.
    """

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        # Q is not referenced in the body; it only gives the kernel a signature.
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)
            T.print(shared_buf)

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_buffer():
    """Print a 16 x 16 shared buffer for each of several element types."""
    # Same calls, in the same order, as three explicit invocations.
    for dtype in ("float", "float16", "uint8"):
        debug_print_buffer(16, 16, dtype=dtype)


def debug_print_buffer_conditional(M=16, N=16, dtype="float16"):
    """Compile and run a kernel that prints a shared buffer inside a branch.

    The print only executes when ``bx == 0 and by == 0 and bz == 0``, so this
    exercises ``T.print`` of a whole buffer under a data-dependent condition.
    The buffer is never written by the kernel; only successful compilation
    and ``profiler.run_once()`` are relied upon.

    Args:
        M: Number of rows of the printed buffer (and of ``Q``).
        N: Number of columns of the printed buffer (and of ``Q``).
        dtype: Element type of the buffer and of ``Q``. Defaults to
            ``"float16"``, the value this helper previously hard-coded.
    """

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        # Q is not referenced in the body; it only gives the kernel a signature.
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)

            if bx == 0 and by == 0 and bz == 0:
                T.print(shared_buf)

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_buffer_conditional():
    """Conditional shared-buffer print on a 16 x 16 buffer."""
    debug_print_buffer_conditional(M=16, N=16)


def debug_print_value_conditional(M=16, N=16, dtype="float16"):
    """Compile and run a kernel that prints a scalar expression from one thread.

    Each kernel instance prints ``bx + by + bz``, but only where the value
    returned by ``T.get_thread_binding()`` equals 0. That value is presumably
    the thread index within the block; confirm against the TileLang docs.
    Only successful compilation and ``profiler.run_once()`` are relied upon.

    Args:
        M: Number of rows of ``Q`` (``Q`` is not used by the kernel body).
        N: Number of columns of ``Q``.
        dtype: Element type of ``Q``. Defaults to ``"float16"``, the value
            this helper previously hard-coded.
    """

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            tid = T.get_thread_binding()
            # Guard so the scalar is printed once per (bx, by, bz), not per thread.
            if tid == 0:
                T.print(bx + by + bz)

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_value_conditional():
    """Scalar print guarded by a thread-id check, 16 x 16 parameter tensor."""
    debug_print_value_conditional(M=16, N=16)


def debug_print_register_files(M=16, N=16, dtype="float16"):
    """Compile and run a kernel that prints a fragment buffer element by element.

    The buffer comes from ``T.alloc_fragment``. The name of this helper
    suggests it is register-resident; that is not established here. Every
    ``(i, j)`` element is printed from inside a ``T.Parallel(M, N)`` loop.
    The buffer is never written by the kernel; only successful compilation
    and ``profiler.run_once()`` are relied upon.

    Args:
        M: Number of rows of the fragment buffer (and of ``Q``).
        N: Number of columns of the fragment buffer (and of ``Q``).
        dtype: Element type of the buffer and of ``Q``. Defaults to
            ``"float16"``, the value this helper previously hard-coded.
    """

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        # Q is not referenced in the body; it only gives the kernel a signature.
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            register_buf = T.alloc_fragment([M, N], dtype)
            for i, j in T.Parallel(M, N):
                T.print(register_buf[i, j])

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_register_files():
    """Element-wise print of a 16 x 16 fragment buffer."""
    debug_print_register_files(M=16, N=16)


def debug_print_msg(M=16, N=16, dtype="float16"):
    """Compile and run a kernel that prints a scalar together with a message.

    This mirrors ``debug_print_value_conditional`` but passes
    ``msg="hello world"`` to ``T.print``, exercising the message argument.
    Only successful compilation and ``profiler.run_once()`` are relied upon.

    Args:
        M: Number of rows of ``Q`` (``Q`` is not used by the kernel body).
        N: Number of columns of ``Q``.
        dtype: Element type of ``Q``. Defaults to ``"float16"``, the value
            this helper previously hard-coded.
    """

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            tid = T.get_thread_binding()
            # Guard so the message is printed once per (bx, by, bz), not per thread.
            if tid == 0:
                T.print(bx + by + bz, msg="hello world")

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_msg():
    """Scalar print with an attached message, 16 x 16 parameter tensor."""
    debug_print_msg(M=16, N=16)


if __name__ == "__main__":
    # When executed directly (not via pytest), hand off to tilelang's test entry point.
    tilelang.testing.main()