from .gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mfma_macro_generator import (
    MatrixCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_fragment
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify


class GemmMFMA(GemmBase):
    """GEMM lowering backend built on AMD Matrix Core (MFMA) intrinsics.

    Both public entry points dispatch on the memory scope of the A and B
    operands — shared memory vs. register fragment — via the
    ``is_gemm_{ss,sr,rs,rr}`` predicates, and delegate the per-scope
    details to a :class:`MatrixCoreIntrinEmitter`.
    """

    def infer_layout(self, target: Target, thread_nums: int) -> dict:
        """Infer the memory layouts for the A, B and C buffers.

        Args:
            target: The TVM compilation target (forwarded to the warp
                partition policy).
            thread_nums: Number of threads in the block, used to split
                the M/N tile into warps.

        Returns:
            A dict mapping each buffer (``self.A``, ``self.B``,
            ``self.C``) to its layout: a swizzled layout for
            shared-memory operands, an MFMA load layout for fragment
            operands, and an MFMA store layout for the accumulator C.

        Raises:
            ValueError: If the A/B scope combination is not one of
                ss/sr/rs/rr.
        """
        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
                                                            False)
        # Per-warp tile sizes along M and N after the warp partition.
        warp_row_tiles = int(self.M // m_warp)
        warp_col_tiles = int(self.N // n_warp)
        # NOTE: this emitter is configured identically to the one in
        # `lower`, except that no thread_var is bound here — layout
        # inference does not need a concrete thread binding.
        mfma_emitter = MatrixCoreIntrinEmitter(
            a_dtype=self.in_dtype,
            b_dtype=self.in_dtype,
            accum_dtype=self.accum_dtype,
            a_transposed=self.trans_A,
            b_transposed=self.trans_B,
            block_row_warps=m_warp,
            block_col_warps=n_warp,
            warp_row_tiles=warp_row_tiles,
            warp_col_tiles=warp_col_tiles,
            chunk=self.chunk,
        )
        if self.is_gemm_ss():
            # Both operands in shared memory: swizzle both to avoid bank
            # conflicts; C always uses the MFMA store layout.
            return {
                self.A: make_swizzled_layout(self.A),
                self.B: make_swizzled_layout(self.B),
                self.C: mfma_emitter.make_mfma_store_layout(self.C),
            }
        elif self.is_gemm_sr():
            # A shared, B already a register fragment.
            return {
                self.A: make_swizzled_layout(self.A),
                self.B: mfma_emitter.make_mfma_load_layout(self.B, matrix="B"),
                self.C: mfma_emitter.make_mfma_store_layout(self.C),
            }
        elif self.is_gemm_rs():
            # A already a register fragment, B shared.
            return {
                self.A: mfma_emitter.make_mfma_load_layout(self.A, matrix="A"),
                self.B: make_swizzled_layout(self.B),
                self.C: mfma_emitter.make_mfma_store_layout(self.C),
            }
        elif self.is_gemm_rr():
            # Both operands already register fragments.
            return {
                self.A: mfma_emitter.make_mfma_load_layout(self.A, matrix="A"),
                self.B: mfma_emitter.make_mfma_load_layout(self.B, matrix="B"),
                self.C: mfma_emitter.make_mfma_store_layout(self.C),
            }
        else:
            raise ValueError(
                f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")

    def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
        """Lower the GEMM into a simplified TIR macro.

        Builds a ``@T.prim_func`` inner macro matching the operand
        scopes (loading shared operands into fragments as needed, then
        issuing MFMA ops into ``self.C``) and returns it after
        simplification.

        Args:
            layout_map: Buffer-to-layout mapping. NOTE(review): not
                referenced in this body — presumably consumed by the
                caller or kept for interface symmetry; confirm.
            target: The TVM compilation target.
            thread_nums: Number of threads in the block.
            thread_var: The TIR variable bound to the thread index,
                handed to the intrinsic emitter.

        Returns:
            The simplified inner macro (result of ``_Simplify`` with
            let-inlining enabled).

        Raises:
            ValueError: If the A/B scope combination is not one of
                ss/sr/rs/rr.
        """
        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
                                                            False)
        warp_row_tiles = int(self.M // m_warp)
        warp_col_tiles = int(self.N // n_warp)
        # Same configuration as in `infer_layout`, plus the concrete
        # thread binding needed for code generation.
        mfma_emitter = MatrixCoreIntrinEmitter(
            a_dtype=self.in_dtype,
            b_dtype=self.in_dtype,
            accum_dtype=self.accum_dtype,
            a_transposed=self.trans_A,
            b_transposed=self.trans_B,
            block_row_warps=m_warp,
            block_col_warps=n_warp,
            warp_row_tiles=warp_row_tiles,
            warp_col_tiles=warp_col_tiles,
            chunk=self.chunk,
            thread_var=thread_var,
        )

        # Bind emitter-derived sizes to locals so the @T.prim_func
        # closures below capture plain values.
        in_dtype = self.in_dtype
        warp_rows = mfma_emitter.warp_rows
        warp_cols = mfma_emitter.warp_cols
        local_size_a = mfma_emitter.local_size_a
        local_size_b = mfma_emitter.local_size_b
        block_K = mfma_emitter.chunk
        micro_size_k = mfma_emitter.micro_size_k

        # Aliases reflecting the expected scope of each operand; for the
        # r* cases below the same buffers are rebound as *_local instead.
        A_shared = self.A
        B_shared = self.B
        C_local = self.C

        # The K-loop below iterates block_K // micro_size_k times, so a
        # chunk smaller than one micro tile would produce an empty loop.
        assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"

        if self.is_gemm_ss():

            @T.prim_func
            def _gemm_ssr() -> None:
                """
                The inner macro that loads data from shared buffers A_shared and
                B_shared into local fragments, then issues Matrix Core mfma ops,
                accumulating into C_local.
                """
                A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
                B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Load A into fragment
                    mfma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Load B into fragment
                    mfma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mfma_emitter.mfma(A_local, B_local, C_local, ki)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            return _Simplify(_gemm_ssr, inline_let=True)
        elif self.is_gemm_sr():
            # B is already a register fragment; only A needs loading.
            B_local = self.B

            @T.prim_func
            def _gemm_srr() -> None:
                """
                The inner macro that loads data from the shared buffer A_shared
                into a local fragment (B is already resident in B_local), then
                issues Matrix Core mfma ops, accumulating into C_local.
                """
                A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Load A into fragment
                    mfma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mfma_emitter.mfma(A_local, B_local, C_local, ki)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            return _Simplify(_gemm_srr, inline_let=True)
        elif self.is_gemm_rs():
            # A is already a register fragment; only B needs loading.
            A_local = self.A

            @T.prim_func
            def _gemm_rsr() -> None:
                """
                The inner macro that loads data from the shared buffer B_shared
                into a local fragment (A is already resident in A_local), then
                issues Matrix Core mfma ops, accumulating into C_local.
                """
                B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Load B into fragment
                    mfma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mfma_emitter.mfma(A_local, B_local, C_local, ki)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            return _Simplify(_gemm_rsr, inline_let=True)
        elif self.is_gemm_rr():
            # Both operands are already register fragments; no loads needed.
            A_local = self.A
            B_local = self.B

            @T.prim_func
            def _gemm_rsr() -> None:
                """
                The inner macro that consumes the already-resident fragments
                A_local and B_local and issues Matrix Core mfma ops,
                accumulating into C_local.
                """
                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Perform Matrix Multiplication
                    mfma_emitter.mfma(A_local, B_local, C_local, ki)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            return _Simplify(_gemm_rsr, inline_let=True)
        else:
            raise ValueError(
                f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")

    def is_gemm_ss(self) -> bool:
        # A in shared memory, B in shared memory.
        return is_shared(self.A) and is_shared(self.B)

    def is_gemm_sr(self) -> bool:
        # A in shared memory, B in a register fragment.
        return is_shared(self.A) and is_fragment(self.B)

    def is_gemm_rs(self) -> bool:
        # A in a register fragment, B in shared memory.
        return is_fragment(self.A) and is_shared(self.B)

    def is_gemm_rr(self) -> bool:
        # A in a register fragment, B in a register fragment.
        return is_fragment(self.A) and is_fragment(self.B)