Unverified Commit 4a9cb470 authored by Lei Wang, committed by GitHub

[SM70] Refactor and minor fix for SM70 (#1195)

* [Feature] Add support for SM70 tensor core MMA instructions

- Introduced a new intrinsic, `ptx_mma_sm70`, for Volta GPUs, enabling the m16n16k4 shape with FP16 inputs and FP16/FP32 accumulation (see the usage sketch after the commit message).
- Added a `GemmMMASm70` class for handling GEMM operations specific to the SM70 architecture.
- Implemented layout functions for Volta swizzled layouts and updated existing GEMM layout inference logic.
- Updated `requirements-dev.txt` to include `apache-tvm-ffi` dependency.
- Added correctness evaluation script for testing GEMM operations on SM70.

* [Refactor] Update formatting and installation commands in scripts

- Modified `format.sh` to install `pre-commit` and `clang-tidy` with the `--user` flag for user-specific installations.
- Improved readability in `correctness_evaluation_sm70.py` by adjusting the formatting of pytest parameters.
- Cleaned up spacing and formatting in various C++ source files for better consistency and readability.
- Removed unnecessary comments and improved layout function definitions in `mma_sm70_layout.py` and `mma_sm70_macro_generator.py` for clarity.
- Ensured consistent formatting in layout initialization and swizzle functions.

* typo fix
parent 11456de2
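Below is a minimal, hedged usage sketch of the path this change enables (not part of the commit): a single-block fp16 GEMM with a transposed B operand, compiled with the same pass configuration as the correctness script further down. It assumes a Volta (SM70) GPU; the fixed sizes M=N=64, K=48 are illustrative only.

import tilelang
import tilelang.language as T


@T.prim_func
def gemm_sm70(
        A: T.Tensor((64, 48), "float16"),
        B: T.Tensor((64, 48), "float16"),  # trans_B=True, so B is laid out as (N, K)
        C: T.Tensor((64, 64), "float16"),
):
    with T.Kernel(1, 1, threads=128) as (bx, by):
        A_shared = T.alloc_shared((64, 48), "float16", scope="shared.dyn")
        B_shared = T.alloc_shared((64, 48), "float16", scope="shared.dyn")
        C_local = T.alloc_fragment((64, 64), "float16")
        T.clear(C_local)
        T.copy(A, A_shared)
        T.copy(B, B_shared)
        # Lowers to the new ptx_mma_sm70 intrinsic / tl::mma_sync_sm70 m16n16k4 path on SM70.
        T.gemm_v2(A_shared, B_shared, C_local, False, True)
        T.copy(C_local, C)


kernel = tilelang.compile(
    gemm_sm70,
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })
print(kernel.get_kernel_source())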
@@ -85,7 +85,7 @@ export PIP_USER=0
# If pre-commit is not installed, install it.
if ! python3 -m pre_commit --version &>/dev/null; then
-    python3 -m pip install pre-commit
+    python3 -m pip install pre-commit --user
fi
echo 'tile-lang pre-commit: Check Start'
@@ -115,7 +115,7 @@ echo 'tile-lang clang-tidy: Check Start'
if [[ -x "$(command -v run-clang-tidy)" ]]; then
    # Check if clang-tidy is available
    if [[ ! -x "$(command -v clang-tidy)" ]]; then
-        python3 -m pip install --upgrade --requirements "${ROOT}/requirements-lint.txt"
+        python3 -m pip install --upgrade --requirements "${ROOT}/requirements-lint.txt" --user
    fi
    # Get clang-tidy version
    CLANG_TIDY_VERSION="$(clang-tidy --version | head -n1 | awk '{print $4}')"
...
# pytest maint/gemm_v2/correctness_evaluation_sm70.py -n 32
import pytest
from tilelang import tvm as tvm
import tilelang.testing
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
# T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def _compile_and_check(
program,
trans_A,
trans_B,
in_dtype,
out_dtype,
):
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
})
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
print("assert_allclose")
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
def matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(A_shared, A_frag)
T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
# T.gemm(A_frag, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_rs(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
M_VALUES = [64, 128]
N_VALUES = [16, 32, 64, 128]
K_VALUES = [16, 32, 64]
FALSE_TRUE_CASES = ([
pytest.param(
k,
"float16",
"float16",
"float16",
id=f"K{k}-float16-float16-float16",
) for k in K_VALUES
] + [
pytest.param(
k,
"float16",
"float16",
"float32",
id=f"K{k}-float16-float16-float32",
) for k in K_VALUES
])
def _ensure_torch_dtypes(*dtype_names):
import torch
for name in set(dtype_names):
if not hasattr(torch, name):
pytest.skip(f"Torch does not expose dtype {name}")
def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128)
def run_gemm_rs_false_false(m, n, k):
run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
TRANS_CASES = [
pytest.param(False, False, id="nn"),
pytest.param(False, True, id="nt"),
pytest.param(True, False, id="tn"),
pytest.param(True, True, id="tt"),
]
@pytest.fixture(scope="module", autouse=True)
def _setup_tilelang_environment():
tilelang.disable_cache()
tilelang.testing.set_random_seed(42)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
import torch
required_torch_attrs = {
in_dtype,
out_dtype,
accum_dtype,
}
for attr in required_torch_attrs:
if not hasattr(torch, attr):
pytest.skip(f"Torch does not expose dtype {attr}")
run_gemm(
m,
n,
k * 3,
False,
True,
in_dtype,
out_dtype,
accum_dtype,
m,
n,
k,
2,
128,
)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_false_false(m, n, k):
run_gemm(
m,
n,
k * 3,
False,
False,
"float16",
"float16",
"float16",
m,
n,
k,
2,
128,
)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
_ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype)
run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_false_false(m, n, k):
_ensure_torch_dtypes("float16")
run_gemm_rs_false_false(m, n, k)
if __name__ == "__main__":
tilelang.testing.main()
# # Test Pass
# for m in [64, 128]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
# for m in [64, 128]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64]:
# print(f"======================= Test {m} {n} {k} False False =============================")
# run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
@@ -63,7 +63,7 @@ N = 16384
K = 16384
block_M = 128
block_N = 128
-block_K = 64
+block_K = 32
# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
...
# Requirements to run local build with `--no-build-isolation` or other developments
apache-tvm-ffi~=0.1.0
build
cmake>=3.26
cython>=3.0.0
...
@@ -577,11 +577,11 @@ Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) {
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
                             bool k_inner) {
-  if (k_inner)
+  if (k_inner && continuous % 32 == 0 && stride % 32 == 0)
    return MakeGemmVoltaABLayoutCrosswise(stride, continuous);
-  if (is_a && continuous % 64 == 0)
+  if (is_a && continuous % 64 == 0 && stride % 4 == 0)
    return MakeGemmVoltaALayoutCongruous(stride, continuous);
-  if (!is_a && continuous % 64 == 0)
+  if (!is_a && continuous % 64 == 0 && stride % 4 == 0)
    return MakeGemmVoltaBLayoutCongruous(stride, continuous);
  return makeGemmABLayoutPadded(stride, continuous, 16);
}
...
@@ -540,6 +540,11 @@ TVM_FFI_STATIC_INIT_BLOCK() {
                                          element_size, k_inner);
             }
           })
.def("tl.make_volta_swizzled_layout",
[](int stride, int mat_continuous, bool is_a, bool k_inner) {
return makeGemmVoltaABLayout(stride, mat_continuous, is_a,
k_inner);
})
.def("tl.make_wgmma_swizzled_layout", .def("tl.make_wgmma_swizzled_layout",
[](int stride, int mat_continuous, int continuity, int element_size, [](int stride, int mat_continuous, int continuity, int element_size,
bool k_inner) { bool k_inner) {
......
@@ -175,6 +175,11 @@ TIR_DEFINE_TL_BUILTIN(ptx_deallocate_tensor_memory)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_mma_sm70)
.set_num_inputs(13)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
...
@@ -275,6 +275,17 @@ TVM_DLL const Op &ptx_init_tensor_memory();
 */
TVM_DLL const Op &ptx_deallocate_tensor_memory();
/*!
* \brief tvm intrinsic for ptx tensor core mma instructions on SM70.
*
* void ptx_mma_sm70(StringImm shape, StringImm A_layout, StringImm B_layout,
* StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
* Var multiplicand_a, Expr a_index,
* Var multiplicand_b, Expr b_index,
* Var accumulator, Expr c_index, bool saturate);
*/
TVM_DLL const Op &ptx_mma_sm70();
/*!
 * \brief tvm intrinsics for ldmatrix
 *
...
@@ -144,7 +144,10 @@ std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
  int m_warp = 1, n_warp = 1;
  constexpr int kMPerWarp = 16; // Rows processed by a single warp
-  constexpr int kNPerWarp = 8; // Columns processed by a single warp
+  int kNPerWarp = 8; // Columns processed by a single warp
if (TargetIsVolta(target)) {
kNPerWarp = 16;
}
  ICHECK(M % kMPerWarp == 0)
      << "M must be divisible by " << kMPerWarp << ", but got " << M;
  ICHECK(N % kNPerWarp == 0)
...
@@ -269,6 +269,9 @@ std::string CodeGenTileLangCUDA::Finish() {
  if (need_tcgen05mma_instruction_h_) {
    decl_stream << "#include <tl_templates/cuda/instruction/tcgen05mma.h>\n";
  }
if (need_mma_sm70_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/mma_sm70.h>\n";
}
  if (need_tcgen05_common_h_) {
    decl_stream << "#include <tl_templates/cuda/tcgen_05.h>\n";
  }
@@ -1789,6 +1792,71 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
    replacer.register_rule("(C_ptr)", c_ref);
    replacer.register_rule("(C_offset)", c_bias);
    this->stream << replacer.rewrite(mma_call);
} else if (op->op.same_as(tl::ptx_mma_sm70())) {
// arg 0: shape: mXnXkX
// arg 1: A layout: row/col
// arg 2: B layout: row/col
// arg 3: A precision: fp16
// arg 4: B precision: fp16
// arg 5: C precision: fp16, fp32
// arg 6: A multiplicand
// arg 7: A multiplicand index
// arg 8: B multiplicand
// arg 9: B multiplicand index
// arg 10: C accumulator
// arg 11: C accumulator index
// arg 12: saturate
ICHECK_EQ(op->args.size(), 12U);
std::string shape = Downcast<StringImm>(op->args[0])->value;
std::string A_layout = Downcast<StringImm>(op->args[1])->value;
std::string B_layout = Downcast<StringImm>(op->args[2])->value;
std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
std::string a_ref = this->PrintExpr(op->args[6]);
std::string a_bias = this->PrintExpr(op->args[7]);
std::string b_ref = this->PrintExpr(op->args[8]);
std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]);
auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
need_mma_sm70_instruction_h_ = true;
this->PrintIndent();
std::string mma_call =
"tl::mma_sync_sm70<(AType), (BType), (CType), (M), (N), (K), (TransA), "
"(TransB)>(reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), "
"reinterpret_cast<const (ARegType)*>((A_ptr) + (A_offset)), "
"reinterpret_cast<const (BRegType)*>((B_ptr) + (B_offset)));\n";
tl::codegen::Replacer replacer;
replacer.register_rule("(AType)",
tl::codegen::ptx::DTypeEnumToString(dtype_a_enum));
replacer.register_rule("(BType)",
tl::codegen::ptx::DTypeEnumToString(dtype_b_enum));
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(M)", std::to_string(m));
replacer.register_rule("(N)", std::to_string(n));
replacer.register_rule("(K)", std::to_string(k));
replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true");
replacer.register_rule("(TransB)", B_layout == "row" ? "false" : "true");
replacer.register_rule("(ARegType)",
tl::codegen::GetMMARegisterType(dtype_a_enum));
replacer.register_rule("(BRegType)",
tl::codegen::GetMMARegisterType(dtype_b_enum));
replacer.register_rule("(CRegType)",
tl::codegen::GetMMARegisterType(dtype_c_enum));
replacer.register_rule("(A_ptr)", a_ref);
replacer.register_rule("(A_offset)", a_bias);
replacer.register_rule("(B_ptr)", b_ref);
replacer.register_rule("(B_offset)", b_bias);
replacer.register_rule("(C_ptr)", c_ref);
replacer.register_rule("(C_offset)", c_bias);
this->stream << replacer.rewrite(mma_call);
  } else if (op->op.same_as(builtin::ptx_mma_sp())) {
    // arg 0: shape: mXnXkX
    // arg 1: A layout: row/col
...
@@ -114,6 +114,8 @@ private:
  bool need_wgmma_instruction_h_{false};
  // whether need tl tcgen05mma instruction header
  bool need_tcgen05mma_instruction_h_{false};
// whether need tl mma_sm70 instruction header
bool need_mma_sm70_instruction_h_{false};
  // whether need tcgen_05 common header
  bool need_tcgen05_common_h_{false};
  // whether need cast_smem_ptr_to_int helper function
...
#pragma once
#include "../common.h"
#include <type_traits>
#include <utility>
namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif
namespace detail {
// SM70 MMA Instruction Traits and Implementations
// SM70 supports m16n16k4 (m8n8k4 instruction at warp level) with FP16/FP32
// accumulation
// Base template for SM70 MMA implementation
template <DataType AType, DataType BType, DataType CType, bool TransA,
bool TransB>
struct MmaSm70Impl {
// Default: unsupported configuration
static constexpr bool kSupported = false;
static TL_DEVICE void exec(void *, const void *, const void *, const void *) {
static_assert(always_false_v<std::integral_constant<bool, TransA>>,
"tl::mma_sync_sm70: unsupported configuration");
}
};
// FP16 inputs, FP16 accumulation - col.col (TransA=true, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
true, true> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP16 accumulation - col.row (TransA=true, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
true, false> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP16 accumulation - row.col (TransA=false, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
false, true> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP16 accumulation - row.row (TransA=false, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
false, false> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP32 accumulation - col.col (TransA=true, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
true, true> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// FP16 inputs, FP32 accumulation - col.row (TransA=true, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
true, false> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// FP16 inputs, FP32 accumulation - row.col (TransA=false, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
false, true> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// FP16 inputs, FP32 accumulation - row.row (TransA=false, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
false, false> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// Helper to extract register types
template <class Impl> struct MmaSm70ImplTraits {
using DReg = std::remove_extent_t<typename Impl::DRegisters>;
using AReg = std::remove_extent_t<typename Impl::ARegisters>;
using BReg = std::remove_extent_t<typename Impl::BRegisters>;
using CReg = std::remove_extent_t<typename Impl::CRegisters>;
static constexpr int kDRegs = std::extent_v<typename Impl::DRegisters>;
static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
static constexpr int kBRegs = std::extent_v<typename Impl::BRegisters>;
static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
};
// Dispatcher for SM70 MMA operations
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB>
struct MmaSm70Dispatcher {
using CRegType = void;
using ARegType = void;
using BRegType = void;
static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *,
const CRegType *) {
static_assert(always_false_v<std::integral_constant<int, M>>,
"tl::mma_sync_sm70: unsupported configuration. "
"SM70 only supports m16n16k4 with FP16 inputs and FP16/FP32 "
"accumulation.");
}
};
// Helper to call fma with unpacked register arrays
template <class Impl, size_t... DIdx, size_t... AIdx, size_t... BIdx,
size_t... CIdx>
TL_DEVICE void
call_fma_impl_sm70(typename MmaSm70ImplTraits<Impl>::DReg *d,
const typename MmaSm70ImplTraits<Impl>::AReg *a,
const typename MmaSm70ImplTraits<Impl>::BReg *b,
const typename MmaSm70ImplTraits<Impl>::CReg *c,
std::index_sequence<DIdx...>, std::index_sequence<AIdx...>,
std::index_sequence<BIdx...>, std::index_sequence<CIdx...>) {
Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...);
}
template <class Impl>
TL_DEVICE void call_fma_sm70(typename MmaSm70ImplTraits<Impl>::DReg *d,
const typename MmaSm70ImplTraits<Impl>::AReg *a,
const typename MmaSm70ImplTraits<Impl>::BReg *b,
const typename MmaSm70ImplTraits<Impl>::CReg *c) {
call_fma_impl_sm70<Impl>(
d, a, b, c, std::make_index_sequence<MmaSm70ImplTraits<Impl>::kDRegs>{},
std::make_index_sequence<MmaSm70ImplTraits<Impl>::kARegs>{},
std::make_index_sequence<MmaSm70ImplTraits<Impl>::kBRegs>{},
std::make_index_sequence<MmaSm70ImplTraits<Impl>::kCRegs>{});
}
// Define dispatchers for all supported SM70 configurations
// Note: m8n8k4 instruction computes m16n16k4 at warp level
#define TL_DEFINE_MMA_SM70_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, \
TransAValue, TransBValue) \
template <> \
struct MmaSm70Dispatcher<DataType::ATypeEnum, DataType::BTypeEnum, \
DataType::CTypeEnum, 16, 16, 4, TransAValue, \
TransBValue> { \
using Impl = MmaSm70Impl<DataType::ATypeEnum, DataType::BTypeEnum, \
DataType::CTypeEnum, TransAValue, TransBValue>; \
using Traits = MmaSm70ImplTraits<Impl>; \
using CRegType = typename Traits::DReg; \
using ARegType = typename Traits::AReg; \
using BRegType = typename Traits::BReg; \
static_assert( \
std::is_same_v<typename Traits::DReg, typename Traits::CReg>, \
"tl::mma_sync_sm70 requires matching accumulator/output regs"); \
static TL_DEVICE void exec(CRegType *d, const ARegType *a, \
const BRegType *b, const CRegType *c) { \
call_fma_sm70<Impl>(d, a, b, c); \
} \
};
// FP16 inputs with FP16 accumulation (all layout combinations)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, true, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, true, false)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, false, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, false, false)
// FP16 inputs with FP32 accumulation (all layout combinations)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, true, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, true, false)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, false, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, false, false)
#undef TL_DEFINE_MMA_SM70_DISPATCHER
} // namespace detail
/// SM70 MMA synchronous instruction wrapper
/// Supports m16n16k4 shape (m8n8k4 instruction at warp level) with FP16 inputs
/// and FP16/FP32 accumulation
///
/// @tparam AType Input A data type (kFloat16)
/// @tparam BType Input B data type (kFloat16)
/// @tparam CType Accumulator/output data type (kFloat16 or kFloat32)
/// @tparam M Matrix M dimension (16)
/// @tparam N Matrix N dimension (16)
/// @tparam K Matrix K dimension (4)
/// @tparam TransA Whether A is transposed (false=row-major, true=col-major)
/// @tparam TransB Whether B is transposed (false=row-major, true=col-major)
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB>
TL_DEVICE void mma_sync_sm70(
typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K, TransA,
TransB>::CRegType *c,
const typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K,
TransA, TransB>::ARegType *a,
const typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K,
TransA, TransB>::BRegType *b) {
using Dispatcher =
detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K, TransA, TransB>;
static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
"tl::mma_sync_sm70: unsupported configuration. "
"SM70 only supports m16n16k4 with FP16 inputs.");
Dispatcher::exec(c, a, b, c);
}
} // namespace tl
from __future__ import annotations
def shared_16x4_to_mma_a_32x4_layout(row, col, rep):
tid = (row % 4) + 16 * ((row // 4) % 2) + 4 * (row // 8) + 8 * rep
local_id = col
return tid, local_id
def shared_4x16_to_mma_b_32x4_layout(row, col, rep):
thread_id = row + 8 * col // 4 + 4 * rep
local_id = col % 4
return thread_id, local_id
def shared_16x4_to_mma_b_32x4_layout_trans(row, col, rep):
thread_id = row % 4 + 4 * rep + 8 * ((row % 8) // 4) + 16 * (row // 8)
local_id = col
return thread_id, local_id
def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id):
row = (thread_id % 2) + (
(local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8
col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id %
2) + (local_id // 4) * 8
return row, col
def mma_32x8_to_shared_16x16_layout_fp16(thread_id, local_id):
row = (thread_id % 4) + (thread_id // 16) * 4 + (thread_id % 8) // 4 * 8
col = local_id % 4 + ((thread_id % 16) // 8) * 4 + (local_id // 4) * 8
return row, col
def mma_load_a_32x4_to_shared_16x4_layout(thread_id, local_id):
row = (thread_id % 4) + (4 * (((thread_id // 16 + thread_id % 16 // 4 * 2)) % 4))
col = local_id
return row, col
def mma_load_b_32x4_to_shared_16x4_layout_trans(thread_id, local_id):
row = (thread_id % 4) + 8 * (thread_id // 16) + 4 * ((thread_id // 8) % 2)
col = local_id
return row, col
def mma_load_b_32x4_to_shared_4x16_layout(thread_id, local_id):
row = thread_id % 4
col = local_id + (4 * (thread_id // 8))
return row, col
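# --- Illustrative self-check (not part of the commit) ---
# A minimal CPU-only sketch verifying two properties of the index maps above:
# the fp16/fp32 store maps are bijections from (thread_id, local_id) onto the
# 16x16 accumulator tile, and the A-operand load map covers every cell of the
# 16x4 tile exactly twice (it is replicated across the warp, matching the
# replicate=2 fragment built by the SM70 macro generator). Runs under plain
# CPython; no GPU or TVM build is required.
def _selfcheck_sm70_layouts():
    from collections import Counter

    for store_map in (mma_32x8_to_shared_16x16_layout_fp16,
                      mma_32x8_to_shared_16x16_layout_fp32):
        cells = Counter(store_map(t, l) for t in range(32) for l in range(8))
        assert set(cells) == {(r, c) for r in range(16) for c in range(16)}
        assert all(n == 1 for n in cells.values())

    loads = Counter(
        mma_load_a_32x4_to_shared_16x4_layout(t, l) for t in range(32) for l in range(4))
    assert set(loads) == {(r, c) for r in range(16) for c in range(4)}
    assert all(n == 2 for n in loads.values())


if __name__ == "__main__":
    _selfcheck_sm70_layouts()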
from __future__ import annotations
import tilelang.language as T
from typing import Literal, Callable
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var
from tvm.runtime import convert
from tilelang.utils import is_fragment
from tilelang.intrinsics.mma_sm70_layout import (
shared_16x4_to_mma_a_32x4_layout,
shared_4x16_to_mma_b_32x4_layout,
shared_16x4_to_mma_b_32x4_layout_trans,
mma_32x8_to_shared_16x16_layout_fp32,
mma_32x8_to_shared_16x16_layout_fp16,
mma_load_a_32x4_to_shared_16x4_layout,
mma_load_b_32x4_to_shared_16x4_layout_trans,
mma_load_b_32x4_to_shared_4x16_layout,
)
lift = convert
class TensorCoreIntrinEmitter:
"""
To eliminate Python syntax within TIR Macro.
"""
M_DIM = 16
    # n_dim is kept lowercase because it is an instance-level attribute; for the
    # SM70 m16n16k4 instruction it is always 16 (see _initialize_micro_size)
n_dim = 16
WARP_SIZE = 32
HALF_WARP_SIZE = WARP_SIZE // 2
dtype_abbrv = {
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
"int8": "int8",
"int32": "int32",
"float8_e4m3": "e4m3",
"float8_e5m2": "e5m2",
}
# Represent the thread binding in the form of (tx, warp_n, warp_m)
is_m_first = False
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
block_col_warps: int = 2,
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
is_m_first: bool | None = False,
thread_var: Var | None = None,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.accum_dtype = accum_dtype
self.a_transposed = a_transposed
self.b_transposed = b_transposed
# Hint Information
self.block_row_warps = block_row_warps
self.block_col_warps = block_col_warps
self.warp_row_tiles = warp_row_tiles
self.warp_col_tiles = warp_col_tiles
self.chunk = chunk
self._initialize_k_dim(a_dtype)
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_micro_size(self.M_DIM, self.k_dim)
self._initialize_local_size(self.M_DIM, self.n_dim, self.k_dim)
self._initialize_mma_prefix(self.k_dim)
self._initialize_is_m_first(is_m_first)
self.reduce_k = reduce_k
self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
self.num_elems_per_byte = num_elems_per_byte
self.thread_var = thread_var
if self.warp_rows == 0 or self.warp_cols == 0:
raise ValueError(
f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}"
)
def _initialize_k_dim(self, a_dtype="float16"):
self.k_dim = 4
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16):
self.local_size_a = (m_dim * k_dim) // self.HALF_WARP_SIZE
self.local_size_b = (n_dim * k_dim) // self.HALF_WARP_SIZE
self.local_size_out = (m_dim * n_dim) // self.WARP_SIZE
def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
self.a_dtype_abbrv = self._get_dtype_abbrv(a_dtype)
self.b_dtype_abbrv = self._get_dtype_abbrv(b_dtype)
self.accum_dtype_abbrv = self._get_dtype_abbrv(accum_dtype)
def _get_dtype_abbrv(self, dtype: str) -> str:
try:
return self.dtype_abbrv[dtype]
except KeyError as err:
raise ValueError(f"Unsupported dtype: {dtype}") from err
def _initialize_mma_prefix(self, k_dim: int = 16):
if k_dim == 4:
# typically used for float16
self.mma_prefix = "m16n16k4"
else:
raise ValueError(f"Unsupported k_dim: {k_dim}")
def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
warp_row_tiles = self.warp_row_tiles
warp_col_tiles = self.warp_col_tiles
        assert warp_row_tiles >= 16, f"warp_row_tiles must be at least 16, got {warp_row_tiles}"
        assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}"
        assert warp_col_tiles >= 16, f"warp_col_tiles must be at least 16, got {warp_col_tiles}"
        assert warp_col_tiles % 16 == 0, f"warp_col_tiles must be divisible by 16, got {warp_col_tiles}"
self.warp_rows = warp_row_tiles // m_dim
self.n_dim = 16
self.micro_size_y = 16
self.warp_cols = warp_col_tiles // 16
self.micro_size_x = m_dim
self.micro_size_k = k_dim
def _initialize_is_m_first(self, is_m_first: bool | None = False):
if is_m_first is not None:
self.is_m_first = is_m_first
def get_thread_binding(self):
if self.thread_var is None:
current_frame = T.KernelLaunchFrame.Current()
assert current_frame is not None, "Must be called in a T.Kernel Frame"
return current_frame.get_thread_binding()
else:
return self.thread_var
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(
mma_32x8_to_shared_16x16_layout_fp32
if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16,
index_dtype="int32")
if not inverse:
return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map
def extract_thread_binding(
self,
thread_id: PrimExpr,
is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
"""
WARP_SIZE = self.WARP_SIZE
block_row_warps = self.block_row_warps
block_col_warps = self.block_col_warps
# if is_m_first is None, then use the default value
if is_m_first is None:
is_m_first = self.is_m_first
if is_m_first:
lane_id, warp_n, warp_m = (
thread_id % WARP_SIZE,
(thread_id // WARP_SIZE) % block_col_warps,
(thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
)
return lane_id, warp_n, warp_m
else:
lane_id, warp_m, warp_n = (
thread_id % WARP_SIZE,
(thread_id // WARP_SIZE) % block_row_warps,
(thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
)
return lane_id, warp_n, warp_m
def ldmatrix_a(self,
A_local_buf: Buffer,
A_shared_buf: Buffer,
ki: PrimExpr,
rk: PrimExpr | None = 0):
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
micro_size_x = self.micro_size_x
micro_size_k = self.micro_size_k
local_size_a = self.local_size_a
a_transposed = self.a_transposed
thread_binding = self.get_thread_binding()
        assert not a_transposed, "A must not be transposed"
mma_load_layout = mma_load_a_32x4_to_shared_16x4_layout
@T.macro
def _warp_ldmatrix_a(
A_local_buf,
A_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, _, warp_m = self.extract_thread_binding(thread_binding)
for i in T.serial(warp_rows):
# Assign A_shared_buf_elem
wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
for j in T.vectorized(local_size_a):
mi, mk = mma_load_layout(tx, j)
A_local_buf[i * local_size_a + j] = A_shared_buf[wi + mi, wk + mk]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
def ldmatrix_b(self,
B_local_buf: Buffer,
B_shared_buf: Buffer,
ki: PrimExpr,
rk: PrimExpr | None = 0):
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
micro_size_y = self.micro_size_y
micro_size_k = self.micro_size_k
local_size_b = self.local_size_b
b_transposed = self.b_transposed
thread_binding = self.get_thread_binding()
mma_load_layout = mma_load_b_32x4_to_shared_16x4_layout_trans if b_transposed else mma_load_b_32x4_to_shared_4x16_layout
@T.macro
def _warp_ldmatrix_b(
B_local_buf,
B_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
for i in T.serial(warp_cols):
# Assign B_shared_elem
wi, wk = (
warp_n * warp_col_tiles + i * micro_size_y,
rk * chunk + ki * micro_size_k,
)
                # load the B micro-tile (16x4 when B is transposed, 4x16 otherwise)
                # from the shared buffer into the local fragment
for j in T.vectorized(local_size_b):
if b_transposed:
mi, mk = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_shared_buf[wi + mi, wk + mk]
else:
mk, mi = mma_load_layout(tx, j)
B_local_buf[i * local_size_b + j] = B_shared_buf[wk + mk, wi + mi]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
def mma(self,
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
local_size_b = self.local_size_b
local_size_out = self.local_size_out
a_dtype_abbrv = self.a_dtype_abbrv
b_dtype_abbrv = self.b_dtype_abbrv
accum_dtype_abbrv = self.accum_dtype_abbrv
mma_prefix = self.mma_prefix
a_is_fragment = is_fragment(A_local_buf)
b_is_fragment = is_fragment(B_local_buf)
a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
a_major = "col" if self.a_transposed else "row"
b_major = "col" if self.b_transposed else "row"
@T.macro
def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
for i, j in T.grid(warp_rows, warp_cols):
T.ptx_mma_sm70(
mma_prefix,
a_major,
b_major,
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_local_buf.data,
a_local_stride + i * local_size_a,
B_local_buf.data,
b_local_stride + j * local_size_b,
C_local_buf.data,
i * warp_cols * local_size_out + j * local_size_out,
)
return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
def make_mma_load_layout(self,
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
"""
        Create a layout function describing how an MMA operand fragment
        (matrix A or B) is distributed across threads and per-thread local
        indices for the SM70 m16n16k4 tensor core tiles.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
assert matrix in ["A", "B"], "matrix should be either A or B"
matrix_is_a: bool = matrix == "A"
matrix_is_b: bool = matrix == "B"
dtype = self.a_dtype if matrix_is_a else self.b_dtype
dtype_bits = DataType(dtype).bits
transposed = self.a_transposed if matrix_is_a else self.b_transposed
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
# sr also can represent a non-transposed basic layout
# then rs also can represent a transposed basic layout
transform_func_sr_a: Callable = None
transform_func_sr_b: Callable = None
transform_func_rs_b: Callable = None
if dtype_bits == 16:
transform_func_sr_a = shared_16x4_to_mma_a_32x4_layout
transform_func_sr_b = shared_16x4_to_mma_b_32x4_layout_trans
transform_func_rs_b = shared_4x16_to_mma_b_32x4_layout
else:
raise ValueError(f"Unsupported dtype {dtype}")
is_sr_conditions = [False]
is_sr_conditions.append(matrix_is_a and not transposed)
is_sr_conditions.append(matrix_is_b and transposed)
is_sr_axis_order = any(is_sr_conditions)
# the layout of mma.sync is row.col.
# so the b matrix expected a transposed basic layout
transform_func: Callable = None
if matrix_is_a:
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
elif matrix_is_b:
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_rs_b(
i, j)
else:
raise ValueError(f"Unsupported matrix {matrix}")
assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
if matrix_is_a:
micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
else:
micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y
block_row_warps, block_col_warps = (
self.block_row_warps,
self.block_col_warps,
)
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
def forward(i: int, j: int, rep: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
"""
lane_id, local_id = inverse_mma_load_layout.map_indices([i, j, rep])
return lane_id, local_id
base_fragment = T.Fragment(
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
forward_fn=forward,
replicate=2)
warp_rows, warp_cols = self.warp_rows, self.warp_cols
chunk = self.chunk
warp_s = warp_rows if matrix_is_a else warp_cols
warp_r = chunk // micro_size_r
block_s = block_row_warps if matrix_is_a else block_col_warps
replicate = block_col_warps if matrix_is_a else block_row_warps
if is_sr_axis_order:
warp_fragment = base_fragment.repeat([warp_s, warp_r],
repeat_on_thread=False,
lower_dim_first=False)
if matrix_is_a:
block_fragment = warp_fragment.repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([block_s, 1],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
else:
warp_fragment = base_fragment.repeat([warp_r, warp_s],
repeat_on_thread=False,
lower_dim_first=True)
if matrix_is_a:
block_fragment = warp_fragment.repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True).replicate(replicate)
elif matrix_is_b:
block_fragment = warp_fragment.replicate(replicate).repeat([1, block_s],
repeat_on_thread=True,
lower_dim_first=True)
else:
raise ValueError(f"Unsupported matrix type {matrix}")
return block_fragment
def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
map fragment indices to threads and local indices.
Parameters
----------
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
Returns
-------
T.Fragment
A fragment object that describes how threads and indices
in `local_buf` are laid out.
Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
from tilelang.utils import is_fragment
shape = local_buf.shape
inverse_mma_store_layout = self.get_store_index_map(inverse=True)
assert is_fragment(local_buf), "local_buf must be a fragment"
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
local_size_out = self.local_size_out
block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
warp_rows, warp_cols = self.warp_rows, self.warp_cols
warp_size = self.WARP_SIZE
is_m_first = self.is_m_first
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a thread index according to `inverse_mma_store_layout`.
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
lane_id, _ = inverse_mma_store_layout.map_indices([mma_i, mma_j])
if is_m_first:
thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id
else:
thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
return thread_id
def forward_index(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
map them to a local index in a single thread according
to `inverse_mma_store_layout`.
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
_, local_id = inverse_mma_store_layout.map_indices([mma_i, mma_j])
return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id
return T.Fragment(
shape,
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
@@ -708,3 +708,102 @@ def tcgen05_mma_arrive(mbar_ptr):
        Pointer to the mbarrier object in shared memory (e.g., Barrier*).
    """
    return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr)
def ptx_mma_sm70(
shape,
A_layout,
B_layout,
A_dtype,
B_dtype,
C_dtype,
multiplicand_a,
a_index,
multiplicand_b,
b_index,
accumulator,
c_index,
):
"""TVM intrinsic for ptx tensor core mma instructions on SM70 (Volta).
This intrinsic provides SM70-specific MMA operations that support m16n16k4 shape
with FP16 inputs and FP16/FP32 accumulation.
Parameters
----------
shape : str
The shape of mma fragment (e.g., "m16n16k4").
A_layout : str
The layout of multiplicand fragment A ("row" or "col").
B_layout : str
The layout of multiplicand fragment B ("row" or "col").
A_dtype : str
The data type of multiplicand fragment A (typically "fp16").
B_dtype : str
The data type of multiplicand fragment B (typically "fp16").
C_dtype : str
The data type of accumulator fragment C ("fp16" or "fp32").
multiplicand_a : Var
The multiplicand fragment A variable.
a_index : Expr
The index of multiplicand fragment A.
multiplicand_b : Var
The multiplicand fragment B variable.
b_index : Expr
The index of multiplicand fragment B.
accumulator : Var
The accumulator fragment C variable.
c_index : Expr
The index of accumulator fragment C.
Returns
-------
call : PrimExpr
The call expression.
Examples
--------
>>> T.ptx_mma_sm70(
... "float16",
... "m16n16k4",
... "row",
... "col",
... "fp16",
... "fp16",
... "fp16",
... A_local.data,
... 0,
... B_local.data,
... 0,
... C_local.data,
... 0,
... )
"""
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.ptx_mma_sm70"),
shape,
A_layout,
B_layout,
A_dtype,
B_dtype,
C_dtype,
multiplicand_a,
a_index,
multiplicand_b,
b_index,
accumulator,
c_index,
)
@@ -5,6 +5,7 @@ from .layout import Layout  # noqa: F401
from .fragment import Fragment  # noqa: F401
from .swizzle import (
    make_swizzled_layout,  # noqa: F401
    make_volta_swizzled_layout,  # noqa: F401
    make_wgmma_swizzled_layout,  # noqa: F401
    make_tcgen05mma_swizzled_layout,  # noqa: F401
    make_full_bank_swizzled_layout,  # noqa: F401
...
@@ -18,6 +18,17 @@ def make_swizzled_layout(buffer: tvm.tir.Buffer, k_major: bool = True, allow_pad
    )
# for Volta Intrinsics
def make_volta_swizzled_layout(buffer: tvm.tir.Buffer, is_a: bool = True, k_inner: bool = True):
assert len(buffer.shape) == 2
return _ffi_api.make_volta_swizzled_layout(
int(buffer.shape[0]),
int(buffer.shape[1]),
is_a,
k_inner,
)
# for WGMMA Intrinsics
def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
                               continuity: int = None,
...
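# --- Illustrative usage (not part of the commit) ---
# A hedged sketch of calling the newly exported helper on a 2-D shared-memory
# tile; the tvm.tir.decl_buffer construction is an assumption for illustration
# only, since inside tile-lang kernels the layout is normally attached by
# GemmMMASm70.infer_layout (see gemm_mma_sm70.py below).
from tilelang import tvm as tvm
from tilelang.layout import make_volta_swizzled_layout

A_shared = tvm.tir.decl_buffer((128, 32), "float16", name="A_shared", scope="shared.dyn")
# A non-transposed A operand is K-major, so k_inner=True picks the crosswise
# Volta swizzle; shapes that miss the alignment checks fall back to the
# padded layout (see makeGemmVoltaABLayout above).
print(make_volta_swizzled_layout(A_shared, is_a=True, k_inner=True))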
@@ -7,10 +7,12 @@ from tvm.runtime import Scriptable
import tvm_ffi
from tilelang.ir import GemmWarpPolicy
from .gemm_mma import GemmMMA
from .gemm_mma_sm70 import GemmMMASm70
from .gemm_wgmma import GemmWGMMA
from .gemm_tcgen05 import GemmTCGEN5
from .gemm_mfma import GemmMFMA
from tilelang import _ffi_api
from tilelang.utils.target import target_is_volta
@tvm_ffi.register_global_func("tl.gemm_py.infer_layout")
@@ -79,13 +81,13 @@ class GemmPy(Node, Scriptable):
    def infer_layout(self, target: Target, thread_nums: int):
        """Infer the layout for the GEMM operation based on target architecture."""
        gemm_inst = self._select_gemm_instruction(thread_nums, target)
-        impl_class = self._get_implementation_class(gemm_inst)
+        impl_class = self._get_implementation_class(gemm_inst, target)
        return impl_class(self).infer_layout(target, thread_nums)
    def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
        """Lower the GEMM operation to TIR statements based on target architecture."""
        gemm_inst = self._select_gemm_instruction(thread_nums, target)
-        impl_class = self._get_implementation_class(gemm_inst)
+        impl_class = self._get_implementation_class(gemm_inst, target)
        return impl_class(self).lower(layout_map, target, thread_nums, thread_var)
    def _select_gemm_instruction(self, thread_nums: int, target: Target) -> GemmInst:
@@ -106,7 +108,7 @@ class GemmPy(Node, Scriptable):
        """
        return GemmInst(_ffi_api.GemmPyGemmInst(self, int(thread_nums), target))
-    def _get_implementation_class(self, gemm_inst: GemmInst):
+    def _get_implementation_class(self, gemm_inst: GemmInst, target: Target):
        """Get the appropriate implementation class for the given GEMM instruction.
        Args:
@@ -120,6 +122,8 @@ class GemmPy(Node, Scriptable):
            ValueError: If the instruction type is unknown
        """
        if gemm_inst.is_mma():
            if target_is_volta(target):
                return GemmMMASm70
            return GemmMMA
        elif gemm_inst.is_wgmma():
            return GemmWGMMA
...
# for Volta GPUs, which use legacy MMA instructions
from .gemm_base import GemmBase
from tilelang.layout import make_volta_swizzled_layout
from tilelang.intrinsics.mma_sm70_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_fragment
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify
class GemmMMASm70(GemmBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
)
a_is_k_major = not self.trans_A
b_is_k_major = self.trans_B
if self.is_gemm_ss():
return {
self.A: make_volta_swizzled_layout(self.A, is_a=True, k_inner=a_is_k_major),
self.B: make_volta_swizzled_layout(self.B, is_a=False, k_inner=b_is_k_major),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_rs():
return {
self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
self.B: make_volta_swizzled_layout(self.B, is_a=False, k_inner=b_is_k_major),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
thread_var=thread_var,
)
in_dtype = self.in_dtype
warp_rows = mma_emitter.warp_rows
warp_cols = mma_emitter.warp_cols
local_size_a = mma_emitter.local_size_a
local_size_b = mma_emitter.local_size_b
block_K = mma_emitter.chunk
micro_size_k = mma_emitter.micro_size_k
A_shared = self.A
B_shared = self.B
C_local = self.C
assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
if self.is_gemm_ss():
@T.prim_func
def _gemm_ssr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_rs():
A_local = self.A
@T.prim_func
def _gemm_rsr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
def is_gemm_sr(self) -> bool:
return is_shared(self.A) and is_fragment(self.B)
def is_gemm_rs(self) -> bool:
return is_fragment(self.A) and is_shared(self.B)
def is_gemm_rr(self) -> bool:
return is_fragment(self.A) and is_fragment(self.B)