import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics.utils import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (TensorCoreIntrinEmitter)

tilelang.testing.set_random_seed(0)


def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
    shape = shared_buf.shape

    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)


def tl_matmul_macro(
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    assert in_dtype in [
        "float16",
        "int8",
    ], "Currently only float16 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    if out_dtype == "int32":
        micro_size_k = 32

    # This is a debug config
    block_row_warps = 1
    block_col_warps = 1
    warp_row_tiles = 16
    warp_col_tiles = 16
    chunk = 32 if in_dtype == "float16" else 64
    shared_scope = "shared.dyn"

    # Pipeline Stage
    stage = 2

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    M = tvm.te.var("m")

    A_shape = (M, K)
    B_shape = (N, K)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )

    warp_size = 32
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MMA Wrapper to Auto Generate Code for MMA
    mma_emitter = TensorCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=False,
        b_transposed=True,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
    )

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            })

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=stage):

                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

                # Load B into shared memory
                for j, k in T.Parallel(block_N, block_K):
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):

                    # Load A into fragment
                    mma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Load B into fragment
                    mma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mma_emitter.mma(A_local, B_local, C_local)

            # Perform STMatrix
            mma_emitter.stmatrix(
                C_local,
                C_shared,
            )
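            # stmatrix moves each warp's accumulator fragment from registers into the
            # blocked C_shared layout; the epilogue loop below flattens that layout
            # back into the row-major global C tensor.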
            # Store shared into global
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = C_shared[
                    i // micro_size_x,
                    j // micro_size_y,
                    i % micro_size_x,
                    j % micro_size_y,
                ]

    return main


def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
    matmul = tl_matmul_macro(N, K, in_dtype, out_dtype, accum_dtype)

    kernel = tilelang.compile(matmul, out_idx=[2])
    src_code = kernel.get_kernel_source()

    # src_code is the generated cuda source
    assert src_code is not None

    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))

    C = kernel(A, B)

    # Get Reference Result
    ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


def tl_matmul_block(
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    M = tvm.te.var("m")
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
             C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def assert_tl_matmul_block_correctness(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    program = tl_matmul_block(
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    kernel = tilelang.compile(program, out_idx=[2])

    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    C = kernel(A, B)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


def tl_matmul_block_all_dynamic(
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    M = tvm.te.var("m")
    N = tvm.te.var("n")
    K = tvm.te.var("k")
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
             C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
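            # The kernel body below matches tl_matmul_block's; the only difference is
            # that M, N, and K are all symbolic (tvm.te.var) in this variant.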
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def assert_tl_matmul_block_all_dynamic_correctness(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    program = tl_matmul_block_all_dynamic(
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    kernel = tilelang.compile(program)

    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
    dynamic_alignment=8,
):
    program = tl_matmul_block_all_dynamic(
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    pass_configs = {
        tilelang.PassConfigKey.TL_DISABLE_DYNAMIC_TAIL_SPLIT: dynamic_alignment != 0,
        tilelang.PassConfigKey.TL_DYNAMIC_ALIGNMENT: dynamic_alignment
    }
    if M % 64 == 0 or N % 64 == 0 or K % 64 != 0:
        # workaround for hopper tma lower pass
        pass_configs[tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER] = True
        pass_configs[tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED] = True
    kernel = tilelang.compile(program, pass_configs=pass_configs)

    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    kernel(A, B, C)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


def test_assert_tl_matmul_macro():
    assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
    assert_tl_matmul_macro_correctness(66, 128, 128, "float16", "float16", "float16")
    assert_tl_matmul_macro_correctness(32, 128, 128, "float16", "float16", "float16")


def test_assert_tl_matmul_block():
    assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16",
                                       "float16", 64, 64, 32)
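    # M = 67 and M = 36 are not multiples of block_M = 64, so the two calls below
    # exercise the partial final tile along the dynamic M dimension.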
    assert_tl_matmul_block_correctness(67, 128, 128, False, False, "float16", "float16",
                                       "float16", 64, 64, 32)
    assert_tl_matmul_block_correctness(36, 128, 128, False, False, "float16", "float16",
                                       "float16", 64, 64, 32)


def test_assert_tl_matmul_block_all_dynamic():
    assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16",
                                                   "float16", "float16", 64, 64, 32)
    assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16",
                                                   "float16", "float16", 64, 64, 32)
    assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16",
                                                   "float16", "float16", 64, 64, 32)


def test_assert_tl_matmul_block_all_dynamic_with_pass_config():
    assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
        128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32,
        dynamic_alignment=8)
    assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
        64, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32,
        dynamic_alignment=8)
    assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
        64, 128, 60, False, False, "float16", "float16", "float16", 64, 64, 32,
        dynamic_alignment=4)
    # Tail split is enabled with dynamic alignment 0
    assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
        64, 128, 64, False, False, "float16", "float16", "float16", 64, 64, 32,
        dynamic_alignment=0)


if __name__ == "__main__":
    # tilelang.testing.main()
    assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
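    # Illustrative only: the block-level checks can also be run directly from here.
    # These mirror the pytest cases above and likewise require a CUDA GPU:
    # assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16",
    #                                    "float16", 64, 64, 32)
    # assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16",
    #                                                "float16", "float16", 64, 64, 32)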