/*!
 * \file tl/op/gemm.cc
 *
 * Define gemm operator.
 */

#include "gemm.h"

#include "builtin.h"
// NOTE(review): the targets of the four angle-bracket includes below were lost
// in the copy under review. These are the TVM TIR headers this translation
// unit appears to rely on -- confirm against the build before merging.
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include <cmath>
#include <cstdint>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "../target/utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Decompose an integer into its prime factors by trial division.
 * \param x The value to factor.
 * \return The prime factors in non-decreasing order, with multiplicity.
 *         Empty when x <= 1.
 */
static std::vector<int> toPrimeFactors(int x) {
  int i = 2;
  std::vector<int> result;
  while (x > 1) {
    if (x % i == 0) {
      x /= i;
      result.push_back(i);
    } else {
      i++;
    }
  }
  return result;
}

/*!
 * \brief Read an expression that is required to be an integer constant.
 *
 * as_const_int returns a null pointer when the expression is not a constant;
 * this helper reports that case instead of dereferencing the null pointer.
 *
 * \param expr The expression to read.
 * \param what Short description of the expression, used in the diagnostic.
 * \return The constant value of expr.
 */
static int64_t ConstIntOrFail(const PrimExpr &expr, const char *what) {
  const int64_t *value = as_const_int(expr);
  ICHECK(value != nullptr)
      << what << " must be a constant integer, but got " << expr;
  return *value;
}

/*!
 * \brief Build a Gemm operator from the argument list of the call.
 *
 * Argument layout as read below:
 *   [0..2]  access pointers of A, B and C
 *   [3],[4] trans_A, trans_B
 *   [5..7]  M, N, K
 *   [8]     warp partition policy (cast to GemmWarpPolicy)
 *   [9]     clear_accum
 *   [10]    optional kPack, must be 1 or 2
 *   [11]    optional wg_wait
 *
 * \param args The call arguments.
 * \param vmap Map from buffer data variables to buffers, used to resolve the
 *             buffers behind the three access pointers.
 */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
  ICHECK(args.size() >= 10)
      << "gemm expects at least 10 arguments, but got " << args.size();
  Aptr = args[0];
  Bptr = args[1];
  Cptr = args[2];
  A = vmap[GetVarFromAccessPtr(Aptr)];
  B = vmap[GetVarFromAccessPtr(Bptr)];
  C = vmap[GetVarFromAccessPtr(Cptr)];
  trans_A = args[3].as<Bool>().value();
  trans_B = args[4].as<Bool>().value();
  M = args[5].as<IntImm>().value()->value;
  N = args[6].as<IntImm>().value()->value;
  K = args[7].as<IntImm>().value()->value;
  policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
  clear_accum = args[9].as<Bool>().value();
  if (args.size() > 10) {
    kPack = args[10].as<IntImm>().value()->value;
    ICHECK(kPack == 1 || kPack == 2) << "kPack must be 1 or 2";
  }
  if (args.size() > 11) {
    wg_wait = args[11].as<IntImm>().value()->value;
  }
}

/*!
 * \brief Decide how the warps of a block are split across the M and N axes.
 *
 * \param num_warps Number of warps available in the block.
 * \param target The compilation target.
 * \param maybe_hopper_wgmma Whether the caller considers the warp-group MMA
 *        path eligible; it is only taken when the target is Hopper, M >= 64
 *        and num_warps is a multiple of 4.
 * \return The pair {m_warp, n_warp}.
 */
std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
                                               bool maybe_hopper_wgmma) const {
  int m_warp = 1, n_warp = 1;
  bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
                     (this->M >= 64) && (num_warps % 4 == 0);
  if (allow_wgmma) {
    ICHECK(num_warps % 4 == 0) << "Use Warp Group MMA requires 128*N threads.";
    // On the warp-group path all warps go to a single axis.
    if (this->policy == GemmWarpPolicy::kFullRow ||
        this->policy == GemmWarpPolicy::kSquare) {
      m_warp = num_warps;
      n_warp = 1;
    } else if (this->policy == GemmWarpPolicy::kFullCol) {
      m_warp = 1;
      n_warp = num_warps;
    } else {
      ICHECK(0) << "Unknown GemmWarpPolicy";
    }
    return {m_warp, n_warp};
  }

  if (this->policy == GemmWarpPolicy::kFullRow) {
    // Try to partition M first
    m_warp = num_warps;
    n_warp = 1;
    // If M cannot be evenly divided by m_warp*16, try to split remaining warps
    // to N
    if (this->M % (m_warp * 16) != 0) {
      // Calculate how many warps we can use for M
      int max_m_warps = this->M / 16;
      m_warp = max_m_warps;
      // M < 16 yields zero here; keep one warp on M so the division below
      // is well defined.
      if (m_warp == 0)
        m_warp = 1;
      // Use remaining warps for N
      n_warp = num_warps / m_warp;
      if (n_warp == 0)
        n_warp = 1;
    }
  } else if (this->policy == GemmWarpPolicy::kFullCol) {
    // Try to partition N first
    m_warp = 1;
    n_warp = num_warps;
    // If N cannot be evenly divided by n_warp*8, try to split remaining warps
    // to M
    if (this->N % (n_warp * 8) != 0) {
      // Calculate how many warps we can use for N
      int max_n_warps = this->N / 8;
      n_warp = max_n_warps;
      // N < 8 yields zero here; keep one warp on N so the division below
      // is well defined.
      if (n_warp == 0)
        n_warp = 1;
      // Use remaining warps for M
      m_warp = num_warps / n_warp;
      if (m_warp == 0)
        m_warp = 1;
    }
  } else if (this->policy == GemmWarpPolicy::kSquare) {
    // First calculate the maximum possible warps for each dimension
    int max_m_warps = this->M / 16; // Each warp needs at least 16 elements in M
    int max_n_warps = this->N / 8;  // Each warp needs at least 8 elements in N

    // Calculate the ideal ratio of M/N warps based on the matrix dimensions
    float ideal_ratio = 1.0f;
    if (this->N > 0) {
      ideal_ratio = static_cast<float>(this->M) / this->N;
    }

    // Try to find the best balanced partition; {1, 1} is kept when no
    // candidate satisfies the constraints.
    int best_m = 1;
    int best_n = 1;
    float best_balance = std::numeric_limits<float>::max();

    // Try all possible combinations that satisfy the constraints
    for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
      int n = num_warps / m;
      if (n > max_n_warps)
        continue;
      // Only exact factorizations of num_warps are considered.
      if (m * n != num_warps)
        continue;

      // Calculate how balanced this partition is
      float m_per_warp = static_cast<float>(this->M) / (m * 16);
      float n_per_warp = static_cast<float>(this->N) / (n * 8);
      float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);

      if (balance < best_balance) {
        best_balance = balance;
        best_m = m;
        best_n = n;
      }
    }

    m_warp = best_m;
    n_warp = best_n;
  } else {
    ICHECK(0) << "Unknown GemmWarpPolicy";
  }
  return {m_warp, n_warp};
}

/*!
 * \brief Lower the gemm operator to an extern call.
 *
 * The callee name is a template instantiation string assembled from the
 * problem shape, the warp partition and the operator flags; the three access
 * pointers are forwarded as call arguments.
 *
 * \param T Lowering arguments (target, thread bounds, buffer remap).
 * \param analyzer Arithmetic analyzer (not used here).
 * \return An Evaluate statement wrapping the call_extern call.
 */
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  int warp_size = 32;
  if (TargetIsCDNA(T.target)) {
    warp_size = 64;
  }
  const int64_t block_size =
      ConstIntOrFail(T.thread_bounds->extent, "gemm thread extent");
  bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
                     (block_size / warp_size % 4 == 0);

  auto [warp_m, warp_n] =
      ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);

  std::stringstream ss;
  // The variant suffix encodes where A and B live: "r" for a local fragment
  // operand, "s" otherwise. Both operands being fragments is rejected.
  std::string op_name = "tl::gemm_ss";
  if (A.scope() == "local.fragment") {
    ICHECK(B.scope() != "local.fragment");
    op_name = "tl::gemm_rs";
  } else if (B.scope() == "local.fragment") {
    op_name = "tl::gemm_sr";
  }
  ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
  ss << warp_m << ", " << warp_n << ", ";
  ss << trans_A << ", " << trans_B;
  ss << ", " << clear_accum;
  if (TargetIsCDNA(T.target)) {
    // for cdna gemm, we need to specify kPack
    ss << ", " << kPack;
  } else if (TargetIsHopper(T.target)) {
    ss << ", " << (maybe_wgmma ? "true" : "false");
  }
  if (wg_wait != 0) {
    ss << ", " << wg_wait;
  }
  ss << ">";

  // NOTE(review): the remapped buffers below are not used afterwards. The
  // lookups are kept unchanged in case the C lookup is relied on as a
  // presence check -- confirm and remove if not.
  auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A;
  auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B;
  auto C_buffer = T.buffer_remap[C];

  Array<PrimExpr> new_args;
  new_args.push_back(StringImm(ss.str()));
  new_args.push_back(Aptr);
  new_args.push_back(Bptr);
  new_args.push_back(Cptr);
  auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
  return Evaluate(new_call);
}

/*!
 * \brief Infer the layouts of A, B and C for the current target.
 *
 * C must be a local fragment. The result is produced once; later calls
 * return an empty map. Unsupported targets or operand scopes fail an ICHECK.
 *
 * \param T Layout inference arguments (target and thread bounds).
 * \param level Inference level (not consulted here).
 * \return Layouts for the buffers handled on the current target.
 */
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (completed_)
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
  auto thread_range = T.thread_bounds;
  const int64_t block_size =
      ConstIntOrFail(thread_range->extent, "gemm thread extent");

  if (TargetIsVolta(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(block_size / warp_size, T.target);
    auto fragment =
        makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      results.Set(A, makeGemmVoltaABLayout(
                         ConstIntOrFail(A->shape[dim_A - 2], "gemm A shape"),
                         ConstIntOrFail(A->shape[dim_A - 1], "gemm A shape"),
                         true, trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      ICHECK(trans_A == false);
      auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }

    ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
    int dim_B = B->shape.size();
    results.Set(B, makeGemmVoltaABLayout(
                       ConstIntOrFail(B->shape[dim_B - 2], "gemm B shape"),
                       ConstIntOrFail(B->shape[dim_B - 1], "gemm B shape"),
                       false, trans_B ? 2 : 1));
  } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(block_size / warp_size, T.target);
    auto fragment =
        makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride =
          ConstIntOrFail(A->shape[dim_A - 2], "gemm A shape");
      const int64_t mat_continuous =
          ConstIntOrFail(A->shape[dim_A - 1], "gemm A shape");
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride =
          ConstIntOrFail(B->shape[dim_B - 2], "gemm B shape");
      const int64_t mat_continuous =
          ConstIntOrFail(B->shape[dim_B - 1], "gemm B shape");
      results.Set(B,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   B->dtype.bits(), trans_B ? 2 : 1));
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else if (TargetIsHopper(T.target)) {
    const int warp_size = 32;
    bool maybe_wgmma =
        (this->M >= 64) && (block_size / warp_size % 4 == 0);
    auto [warp_m, warp_n] =
        ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
    auto fragment =
        maybe_wgmma
            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                      C->dtype.bits())
            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));
    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      const int64_t mat_stride =
          ConstIntOrFail(A->shape[dim_A - 2], "gemm A shape");
      const int64_t mat_continuous =
          ConstIntOrFail(A->shape[dim_A - 1], "gemm A shape");
      // NOTE(review): `continuity` is computed but the layout call below
      // passes `mat_continuous`, unlike the B operand which passes its own
      // `continuity`. Behaviour is left unchanged here -- confirm which value
      // is intended when trans_A is set.
      const int64_t continuity =
          trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
      results.Set(A,
                  makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
                                   A->dtype.bits(), trans_A ? 1 : 2));
    } else {
      auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
                                        A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      const int64_t mat_stride =
          ConstIntOrFail(B->shape[dim_B - 2], "gemm B shape");
      const int64_t mat_continuous =
          ConstIntOrFail(B->shape[dim_B - 1], "gemm B shape");
      const int64_t continuity =
          trans_B ? mat_continuous : mat_continuous / warp_n;
      results.Set(B, makeGemmABLayout(mat_stride, mat_continuous, continuity,
                                      B->dtype.bits(), trans_B ? 2 : 1));
    } else {
      ICHECK(0) << "WGMMA only support B in shared.";
    }
  } else if (TargetIsCDNA(T.target)) {
    const int warp_size = 64;
    auto [warp_m, warp_n] =
        ComputeWarpPartition(block_size / warp_size, T.target);
    auto fragment =
        makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
    results.Set(C, fragment->BindThreadRange(thread_range));

    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
      int dim_A = A->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          ConstIntOrFail(A->shape[dim_A - 2], "gemm A shape"),
          ConstIntOrFail(A->shape[dim_A - 1], "gemm A shape"),
          A->dtype.bits(), kPack);
      results.Set(A, shared_layout);
    } else if (A.scope() == "local.fragment") {
      auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                            A->dtype.bits(), trans_A);
      results.Set(A, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
      int dim_B = B->shape.size();
      auto shared_layout = makeGemmABLayoutCDNA(
          ConstIntOrFail(B->shape[dim_B - 2], "gemm B shape"),
          ConstIntOrFail(B->shape[dim_B - 1], "gemm B shape"),
          B->dtype.bits(), kPack);
      results.Set(B, shared_layout);
    } else if (B.scope() == "local.fragment") {
      auto fragment =
          makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
      results.Set(B, fragment->BindThreadRange(thread_range));
    } else {
      ICHECK(0);
    }
  } else {
    ICHECK(0) << "Not supported " << T.target->str();
  }
  completed_ = true;
  return results;
}

// Register the gemm operator; the call effect is declared opaque.
TIR_REGISTER_TL_OP(Gemm, gemm)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm