import tilelang.testing
from tilelang import carver
from tilelang.carver.roller import PrimFuncNode, OutputNode, Edge
from tilelang.carver.arch import auto_infer_current_arch
from tvm import te


def _tensorized_matmul(M, N, K, arch):
    """Build a float16 matmul PrimFunc and run it through the carver tensorizer.

    The computation is ``C[i, j] = sum_k A[i, k] * B[j, k]`` with ``A`` of
    shape ``(M, K)`` and ``B`` of shape ``(N, K)``.

    Args:
        M: Number of rows of ``A`` and ``C``.
        N: Number of rows of ``B`` and number of columns of ``C``.
        K: Extent of the reduction axis.
        arch: Architecture object; only its ``target`` attribute is read here.

    Returns:
        The ``(tensorized_func, tags)`` pair produced by
        ``carver.utils.get_tensorized_func_and_tags``.
    """
    A = te.placeholder((M, K), name='A', dtype='float16')
    B = te.placeholder((N, K), name='B', dtype='float16')

    # Describe the matrix multiplication in TE.
    k = te.reduce_axis((0, K), name='k')
    C = te.compute(
        (M, N),
        lambda i, j: te.sum(A[i, k].astype('float16') * B[j, k].astype('float16'), axis=[k]),
        name='C')

    func = te.create_prim_func((A, B, C))
    return carver.utils.get_tensorized_func_and_tags(func, arch.target)


def _emit_and_check(policy, topk):
    """Emit up to ``topk`` hints from ``policy``, print each, and require at least one.

    Args:
        policy: Object exposing ``emit_config(topk=...)``.
        topk: Value forwarded to ``emit_config``.

    Returns:
        The hints returned by ``policy.emit_config``.

    Raises:
        AssertionError: If no hints were emitted.
    """
    hints = policy.emit_config(topk=topk)
    for hint in hints:
        print(hint)
    assert len(hints) > 0, "Hints length is zero"
    return hints


def run_general_matmul_emit_configs(M, N, K, topk: int = 20):
    """Emit hints for one matmul through both ``TensorCorePolicy`` constructors.

    The same tensorized function is fed first to
    ``TensorCorePolicy.from_prim_func`` and then, wrapped in a
    ``PrimFuncNode``/``OutputNode`` pair, to
    ``TensorCorePolicy.from_output_nodes``. Each path must yield at least
    one hint.

    Args:
        M: Number of rows of ``A`` and ``C``.
        N: Number of rows of ``B`` and number of columns of ``C``.
        K: Extent of the reduction axis.
        topk: Value passed to ``emit_config`` on both paths.

    Raises:
        AssertionError: If either policy emits no hints.
    """
    arch = auto_infer_current_arch()
    tensorized_func, tags = _tensorized_matmul(M, N, K, arch)
    print(tags)

    # Path 1: build the policy directly from the PrimFunc.
    policy = carver.TensorCorePolicy.from_prim_func(
        func=tensorized_func, arch=arch, tags=tags, name="matmul_0")
    _emit_and_check(policy, topk)

    # Path 2: build the policy from an explicit single-node graph.
    prim_func_node = PrimFuncNode(tensorized_func, name="matmul_1")
    output_nodes = [OutputNode(prim_func_node)]
    policy = carver.TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=tags)
    # Honour the caller's ``topk`` here as well; this call previously used a
    # hard-coded 10 regardless of the argument.
    _emit_and_check(policy, topk)


def test_general_matmul_emit_configs():
    """Run the single-matmul hint emission on a 128x128x128 problem."""
    run_general_matmul_emit_configs(128, 128, 128)


def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20):
    """Emit hints for two matmul nodes connected by an edge.

    Two ``PrimFuncNode`` instances are created from the same tensorized
    function, linked with an ``Edge``, and the second node is used as the
    graph output for ``TensorCorePolicy.from_output_nodes``.

    Args:
        M: Number of rows of ``A`` and ``C``.
        N: Number of rows of ``B`` and number of columns of ``C``.
        K: Extent of the reduction axis.
        topk: Value passed to ``emit_config``.

    Raises:
        AssertionError: If the policy emits no hints.
    """
    arch = auto_infer_current_arch()
    tensorized_func, tags = _tensorized_matmul(M, N, K, arch)
    print(tags)

    node_0 = PrimFuncNode(tensorized_func, name="matmul_0")
    node_1 = PrimFuncNode(tensorized_func, name="matmul_1")

    # NOTE(review): the trailing (0, 0) presumably selects output 0 of node_0
    # and input 0 of node_1 -- confirm against Edge's signature.
    edge = Edge(node_0, node_1, 0, 0)
    # The edge is registered by hand on both endpoints: appended to the
    # producer's outgoing edges and stored in the consumer's input slot 0.
    node_0._out_edges.append(edge)
    node_1.set_inputs(0, edge)

    output_nodes = [OutputNode(node_1)]
    policy = carver.TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=tags)
    _emit_and_check(policy, topk)


def test_general_matmul_matmul_emit_configs():
    """Run the chained-matmul hint emission on a 128x128x128 problem."""
    run_general_matmul_matmul_emit_configs(128, 128, 128)


if __name__ == "__main__":
    tilelang.testing.main()