# test_tilelang_debug_print.py
# type: ignore

import tilelang
import tilelang.testing
import tilelang.language as T


def debug_print_buffer(M=16, N=16, dtype=T.float16):
    """Compile a kernel that calls ``T.print`` on a whole shared buffer and run it once.

    Args:
        M: First dimension of the input tensor and of the shared buffer.
        N: Second dimension of the input tensor and of the shared buffer.
        dtype: Element type used for both the tensor and the shared buffer.
    """

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)
            # The buffer is printed without being written to first; only the
            # print path for this dtype is exercised, not the printed values.
            T.print(shared_buf)

    # This helper is the only one in the file that pins execution_backend="tvm_ffi".
    jit_kernel = tilelang.compile(program, target="cuda", execution_backend="tvm_ffi")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_buffer():
    """Run the whole-buffer print helper once for every dtype listed below."""
    # Boolean and signed integer types.
    debug_print_buffer(dtype=T.bool)
    debug_print_buffer(dtype=T.int8)
    debug_print_buffer(dtype=T.int16)
    debug_print_buffer(dtype=T.int32)
    debug_print_buffer(dtype=T.int64)
    # Unsigned integer types.
    debug_print_buffer(dtype=T.uint8)
    debug_print_buffer(dtype=T.uint16)
    debug_print_buffer(dtype=T.uint32)
    debug_print_buffer(dtype=T.uint64)
    # Floating-point types.
    debug_print_buffer(dtype=T.float16)
    debug_print_buffer(dtype=T.float32)
    debug_print_buffer(dtype=T.float64)
    debug_print_buffer(dtype=T.bfloat16)
    # 8-bit floating-point variants; each one is exercised exactly once.
    debug_print_buffer(dtype=T.float8_e4m3fn)
    debug_print_buffer(dtype=T.float8_e4m3fnuz)
    debug_print_buffer(dtype=T.float8_e5m2)
    debug_print_buffer(dtype=T.float8_e5m2fnuz)


def debug_print_buffer_conditional(M=16, N=16):
    """Compile a kernel that prints a shared buffer under a condition and run it once.

    Args:
        M: First dimension of the input tensor and of the shared buffer.
        N: Second dimension of the input tensor and of the shared buffer.
    """
    dtype = T.float16

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)

            # Print only where all three kernel indices are zero.
            if bx == 0 and by == 0 and bz == 0:
                T.print(shared_buf)

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_buffer_conditional():
    """Exercise the conditional shared-buffer print helper with a 16 x 16 shape."""
    debug_print_buffer_conditional(M=16, N=16)


def debug_print_value_conditional(M=16, N=16):
    """Compile a kernel that prints a scalar expression under a condition and run it once.

    Args:
        M: First dimension of the input tensor.
        N: Second dimension of the input tensor.
    """
    dtype = T.float16

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            tid = T.get_thread_binding()
            # Only the thread whose binding equals 0 prints the index sum.
            if tid == 0:
                T.print(bx + by + bz)

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_value_conditional():
    """Exercise the conditional scalar print helper with a 16 x 16 shape."""
    debug_print_value_conditional(M=16, N=16)


def debug_print_register_files(M=16, N=16):
    """Compile a kernel that prints each element of a fragment buffer and run it once.

    Args:
        M: First dimension of the input tensor and of the fragment buffer.
        N: Second dimension of the input tensor and of the fragment buffer.
    """
    dtype = T.float16

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            register_buf = T.alloc_fragment([M, N], dtype)
            # Elements are printed one at a time inside a T.Parallel loop,
            # rather than passing the whole buffer to T.print.
            for i, j in T.Parallel(M, N):
                T.print(register_buf[i, j])

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_register_files():
    """Exercise the per-element fragment print helper with a 16 x 16 shape."""
    debug_print_register_files(M=16, N=16)


def debug_print_msg(M=16, N=16):
    """Compile a kernel that prints a scalar with a custom ``msg`` and run it once.

    Args:
        M: First dimension of the input tensor.
        N: Second dimension of the input tensor.
    """
    dtype = T.float16

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            tid = T.get_thread_binding()
            # Only the thread whose binding equals 0 prints; the msg keyword
            # is what distinguishes this case from the plain scalar print.
            if tid == 0:
                T.print(bx + by + bz, msg="hello world")

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_msg():
    """Exercise the print-with-message helper with a 16 x 16 shape."""
    debug_print_msg(M=16, N=16)


# When this file is executed directly as a script, hand control to
# tilelang.testing.main().
if __name__ == "__main__":
    tilelang.testing.main()