# type: ignore
"""Tests for TileLang's device-side T.print debugging primitive across dtypes, buffer scopes, and conditional printing."""
import tilelang
import tilelang.testing
import tilelang.language as T


def debug_print_buffer(M=16, N=16, dtype="float16"):
    # Print an entire shared-memory buffer from inside the kernel.

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)
            T.print(shared_buf)

    jit_kernel = tilelang.compile(program, target="cuda", execution_backend="tvm_ffi")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_buffer():
    debug_print_buffer(dtype="bool")
    debug_print_buffer(dtype="int8")
    debug_print_buffer(dtype="int16")
    debug_print_buffer(dtype="int32")
    debug_print_buffer(dtype="int64")
    debug_print_buffer(dtype="uint8")
    debug_print_buffer(dtype="uint16")
    debug_print_buffer(dtype="uint32")
    debug_print_buffer(dtype="uint64")
    debug_print_buffer(dtype="float16")
    debug_print_buffer(dtype="float32")
    debug_print_buffer(dtype="float64")
    debug_print_buffer(dtype="bfloat16")
    debug_print_buffer(dtype="float8_e4m3")
    debug_print_buffer(dtype="float8_e4m3fn")
    debug_print_buffer(dtype="float8_e4m3fnuz")
    debug_print_buffer(dtype="float8_e5m2")
    debug_print_buffer(dtype="float8_e5m2fnuz")


def debug_print_buffer_conditional(M=16, N=16):
    # Print a shared-memory buffer only from the first thread block.
    dtype = "float16"

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)
            if bx == 0 and by == 0 and bz == 0:
                T.print(shared_buf)

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_buffer_conditional():
    debug_print_buffer_conditional(16, 16)


def debug_print_value_conditional(M=16, N=16):
    # Print a scalar value only from thread 0 of each block.
    dtype = "float16"

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            tid = T.get_thread_binding()
            if tid == 0:
                T.print(bx + by + bz)

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_value_conditional():
    debug_print_value_conditional(16, 16)


def debug_print_register_files(M=16, N=16):
    # Print individual elements of a register (fragment) buffer.
    dtype = "float16"

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            register_buf = T.alloc_fragment([M, N], dtype)
            for i, j in T.Parallel(M, N):
                T.print(register_buf[i, j])

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_register_files():
    debug_print_register_files(16, 16)


def debug_print_msg(M=16, N=16):
    # Print a scalar value with a custom message prefix.
    dtype = "float16"

    @T.prim_func
    def program(Q: T.Tensor((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            tid = T.get_thread_binding()
            if tid == 0:
                T.print(bx + by + bz, msg="hello world")

    jit_kernel = tilelang.compile(program, target="cuda")
    profiler = jit_kernel.get_profiler()
    profiler.run_once()


def test_debug_print_msg():
    debug_print_msg(16, 16)


if __name__ == "__main__":
    tilelang.testing.main()