"driver/driver.hip.cpp" did not exist on "67c6f73ffe0dc06659757c8e28901187394de77b"
Unverified Commit 4a9cb470 authored by Lei Wang, committed by GitHub

[SM70] Refactor and minor fix for SM70 (#1195)

* [Feature] Add support for SM70 tensor core MMA instructions

- Introduced a new intrinsic, `ptx_mma_sm70`, for Volta GPUs, enabling the m16n16k4 shape with FP16 inputs and FP16/FP32 accumulation (see the usage sketch after this list).
- Added `GemmMMASm70` class for handling GEMM operations specific to SM70 architecture.
- Implemented layout functions for Volta swizzled layouts and updated existing GEMM layout inference logic.
- Updated `requirements-dev.txt` to include `apache-tvm-ffi` dependency.
- Added correctness evaluation script for testing GEMM operations on SM70.
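For orientation, here is a minimal sketch of how the new intrinsic is invoked from TileLang, following the argument order documented by the `ptx_mma_sm70` wrapper added in this commit; the fragment buffers `A_local`, `B_local`, `C_local` and the enclosing kernel are assumed for illustration and are not part of this change:

# Illustrative only: fragment buffers and offsets are placeholders.
T.ptx_mma_sm70(
    "m16n16k4",       # warp-level tile shape supported on SM70
    "row", "col",     # A layout, B layout
    "fp16", "fp16",   # A dtype, B dtype
    "fp32",           # accumulator dtype (fp16 or fp32)
    A_local.data, 0,  # multiplicand A fragment and element offset
    B_local.data, 0,  # multiplicand B fragment and element offset
    C_local.data, 0,  # accumulator fragment and element offset
)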

* [Refactor] Update formatting and installation commands in scripts

- Modified `format.sh` to install `pre-commit` and `clang-tidy` with the `--user` flag for user-specific installations.
- Improved readability in `correctness_evaluation_sm70.py` by adjusting the formatting of pytest parameters.
- Cleaned up spacing and formatting in various C++ source files for better consistency and readability.
- Removed unnecessary comments and improved layout function definitions in `mma_sm70_layout.py` and `mma_sm70_macro_generator.py` for clarity.
- Ensured consistent formatting in layout initialization and swizzle functions.

* typo fix
parent 11456de2
......@@ -85,7 +85,7 @@ export PIP_USER=0
# If pre-commit is not installed, install it.
if ! python3 -m pre_commit --version &>/dev/null; then
python3 -m pip install pre-commit
python3 -m pip install pre-commit --user
fi
echo 'tile-lang pre-commit: Check Start'
......@@ -115,7 +115,7 @@ echo 'tile-lang clang-tidy: Check Start'
if [[ -x "$(command -v run-clang-tidy)" ]]; then
# Check if clang-tidy is available
if [[ ! -x "$(command -v clang-tidy)" ]]; then
python3 -m pip install --upgrade --requirements "${ROOT}/requirements-lint.txt"
python3 -m pip install --upgrade --requirements "${ROOT}/requirements-lint.txt" --user
fi
# Get clang-tidy version
CLANG_TIDY_VERSION="$(clang-tidy --version | head -n1 | awk '{print $4}')"
......
# pytest maint/gemm_v2/correctness_evaluation_sm70.py -n 32
import pytest
from tilelang import tvm as tvm
import tilelang.testing
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
# T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def _compile_and_check(
program,
trans_A,
trans_B,
in_dtype,
out_dtype,
):
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
})
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
print("assert_allclose")
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
def matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(A_shared, A_frag)
T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
# T.gemm(A_frag, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_rs(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
M_VALUES = [64, 128]
N_VALUES = [16, 32, 64, 128]
K_VALUES = [16, 32, 64]
FALSE_TRUE_CASES = ([
pytest.param(
k,
"float16",
"float16",
"float16",
id=f"K{k}-float16-float16-float16",
) for k in K_VALUES
] + [
pytest.param(
k,
"float16",
"float16",
"float32",
id=f"K{k}-float16-float16-float32",
) for k in K_VALUES
])
def _ensure_torch_dtypes(*dtype_names):
import torch
for name in set(dtype_names):
if not hasattr(torch, name):
pytest.skip(f"Torch does not expose dtype {name}")
def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128)
def run_gemm_rs_false_false(m, n, k):
run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
TRANS_CASES = [
pytest.param(False, False, id="nn"),
pytest.param(False, True, id="nt"),
pytest.param(True, False, id="tn"),
pytest.param(True, True, id="tt"),
]
@pytest.fixture(scope="module", autouse=True)
def _setup_tilelang_environment():
tilelang.disable_cache()
tilelang.testing.set_random_seed(42)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
import torch
required_torch_attrs = {
in_dtype,
out_dtype,
accum_dtype,
}
for attr in required_torch_attrs:
if not hasattr(torch, attr):
pytest.skip(f"Torch does not expose dtype {attr}")
run_gemm(
m,
n,
k * 3,
False,
True,
in_dtype,
out_dtype,
accum_dtype,
m,
n,
k,
2,
128,
)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_false_false(m, n, k):
run_gemm(
m,
n,
k * 3,
False,
False,
"float16",
"float16",
"float16",
m,
n,
k,
2,
128,
)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
_ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype)
run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_false_false(m, n, k):
_ensure_torch_dtypes("float16")
run_gemm_rs_false_false(m, n, k)
if __name__ == "__main__":
tilelang.testing.main()
# # Test Pass
# for m in [64, 128]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
# for m in [64, 128]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64]:
# print(f"======================= Test {m} {n} {k} False False =============================")
# run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
......@@ -63,7 +63,7 @@ N = 16384
K = 16384
block_M = 128
block_N = 128
block_K = 64
block_K = 32
# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
......
# Requirements to run local build with `--no-build-isolation` or other developments
apache-tvm-ffi~=0.1.0
build
cmake>=3.26
cython>=3.0.0
......
......@@ -577,11 +577,11 @@ Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) {
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
bool k_inner) {
if (k_inner)
if (k_inner && continuous % 32 == 0 && stride % 32 == 0)
return MakeGemmVoltaABLayoutCrosswise(stride, continuous);
if (is_a && continuous % 64 == 0)
if (is_a && continuous % 64 == 0 && stride % 4 == 0)
return MakeGemmVoltaALayoutCongruous(stride, continuous);
if (!is_a && continuous % 64 == 0)
if (!is_a && continuous % 64 == 0 && stride % 4 == 0)
return MakeGemmVoltaBLayoutCongruous(stride, continuous);
return makeGemmABLayoutPadded(stride, continuous, 16);
}
......
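To make the guard conditions in the hunk above easier to scan, here is a small Python paraphrase of the selection order; it mirrors only the control flow, and the string return values stand in for the Layout objects the real C++ function builds:

def select_volta_ab_layout(stride, continuous, is_a, k_inner):
    # Crosswise (K-inner) layout now also requires 32-aligned stride and continuous.
    if k_inner and continuous % 32 == 0 and stride % 32 == 0:
        return "crosswise"
    # Congruous layouts additionally require the stride to be divisible by 4.
    if is_a and continuous % 64 == 0 and stride % 4 == 0:
        return "congruous_a"
    if not is_a and continuous % 64 == 0 and stride % 4 == 0:
        return "congruous_b"
    # Otherwise fall back to the padded layout (padding factor 16).
    return "padded_16"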
......@@ -540,6 +540,11 @@ TVM_FFI_STATIC_INIT_BLOCK() {
element_size, k_inner);
}
})
.def("tl.make_volta_swizzled_layout",
[](int stride, int mat_continuous, bool is_a, bool k_inner) {
return makeGemmVoltaABLayout(stride, mat_continuous, is_a,
k_inner);
})
.def("tl.make_wgmma_swizzled_layout",
[](int stride, int mat_continuous, int continuity, int element_size,
bool k_inner) {
......
......@@ -175,6 +175,11 @@ TIR_DEFINE_TL_BUILTIN(ptx_deallocate_tensor_memory)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_mma_sm70)
.set_num_inputs(12)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
......
......@@ -275,6 +275,17 @@ TVM_DLL const Op &ptx_init_tensor_memory();
*/
TVM_DLL const Op &ptx_deallocate_tensor_memory();
/*!
* \brief tvm intrinsic for ptx tensor core mma instructions on SM70.
*
* void ptx_mma_sm70(StringImm shape, StringImm A_layout, StringImm B_layout,
* StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
* Var multiplicand_a, Expr a_index,
* Var multiplicand_b, Expr b_index,
* Var accumulator, Expr c_index);
*/
TVM_DLL const Op &ptx_mma_sm70();
/*!
* \brief tvm intrinsics for ldmatrix
*
......
......@@ -144,7 +144,10 @@ std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
int m_warp = 1, n_warp = 1;
constexpr int kMPerWarp = 16; // Rows processed by a single warp
constexpr int kNPerWarp = 8; // Columns processed by a single warp
int kNPerWarp = 8; // Columns processed by a single warp
if (TargetIsVolta(target)) {
kNPerWarp = 16;
}
ICHECK(M % kMPerWarp == 0)
<< "M must be divisible by " << kMPerWarp << ", but got " << M;
ICHECK(N % kNPerWarp == 0)
......
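As a quick way to read the Volta-specific change above: the per-warp tile on SM70 becomes 16x16 instead of 16x8, so N must now be divisible by 16 on that target. A short paraphrase of the divisibility checks (the partition search itself is omitted):

def warp_tile_constraints(M, N, is_volta):
    m_per_warp = 16                      # rows processed by a single warp
    n_per_warp = 16 if is_volta else 8   # SM70 MMA covers a 16x16 warp tile
    assert M % m_per_warp == 0, f"M must be divisible by {m_per_warp}"
    assert N % n_per_warp == 0, f"N must be divisible by {n_per_warp}"
    return m_per_warp, n_per_warp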
......@@ -269,6 +269,9 @@ std::string CodeGenTileLangCUDA::Finish() {
if (need_tcgen05mma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/tcgen05mma.h>\n";
}
if (need_mma_sm70_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/mma_sm70.h>\n";
}
if (need_tcgen05_common_h_) {
decl_stream << "#include <tl_templates/cuda/tcgen_05.h>\n";
}
......@@ -1789,6 +1792,71 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
replacer.register_rule("(C_ptr)", c_ref);
replacer.register_rule("(C_offset)", c_bias);
this->stream << replacer.rewrite(mma_call);
} else if (op->op.same_as(tl::ptx_mma_sm70())) {
// arg 0: shape: mXnXkX
// arg 1: A layout: row/col
// arg 2: B layout: row/col
// arg 3: A precision: fp16
// arg 4: B precision: fp16
// arg 5: C precision: fp16, fp32
// arg 6: A multiplicand
// arg 7: A multiplicand index
// arg 8: B multiplicand
// arg 9: B multiplicand index
// arg 10: C accumulator
// arg 11: C accumulator index
ICHECK_EQ(op->args.size(), 12U);
std::string shape = Downcast<StringImm>(op->args[0])->value;
std::string A_layout = Downcast<StringImm>(op->args[1])->value;
std::string B_layout = Downcast<StringImm>(op->args[2])->value;
std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
std::string a_ref = this->PrintExpr(op->args[6]);
std::string a_bias = this->PrintExpr(op->args[7]);
std::string b_ref = this->PrintExpr(op->args[8]);
std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]);
auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
need_mma_sm70_instruction_h_ = true;
this->PrintIndent();
std::string mma_call =
"tl::mma_sync_sm70<(AType), (BType), (CType), (M), (N), (K), (TransA), "
"(TransB)>(reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), "
"reinterpret_cast<const (ARegType)*>((A_ptr) + (A_offset)), "
"reinterpret_cast<const (BRegType)*>((B_ptr) + (B_offset)));\n";
tl::codegen::Replacer replacer;
replacer.register_rule("(AType)",
tl::codegen::ptx::DTypeEnumToString(dtype_a_enum));
replacer.register_rule("(BType)",
tl::codegen::ptx::DTypeEnumToString(dtype_b_enum));
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(M)", std::to_string(m));
replacer.register_rule("(N)", std::to_string(n));
replacer.register_rule("(K)", std::to_string(k));
replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true");
replacer.register_rule("(TransB)", B_layout == "row" ? "false" : "true");
replacer.register_rule("(ARegType)",
tl::codegen::GetMMARegisterType(dtype_a_enum));
replacer.register_rule("(BRegType)",
tl::codegen::GetMMARegisterType(dtype_b_enum));
replacer.register_rule("(CRegType)",
tl::codegen::GetMMARegisterType(dtype_c_enum));
replacer.register_rule("(A_ptr)", a_ref);
replacer.register_rule("(A_offset)", a_bias);
replacer.register_rule("(B_ptr)", b_ref);
replacer.register_rule("(B_offset)", b_bias);
replacer.register_rule("(C_ptr)", c_ref);
replacer.register_rule("(C_offset)", c_bias);
this->stream << replacer.rewrite(mma_call);
} else if (op->op.same_as(builtin::ptx_mma_sp())) {
// arg 0: shape: mXnXkX
// arg 1: A layout: row/col
......
......@@ -114,6 +114,8 @@ private:
bool need_wgmma_instruction_h_{false};
// whether need tl tcgen05mma instruction header
bool need_tcgen05mma_instruction_h_{false};
// whether need tl mma_sm70 instruction header
bool need_mma_sm70_instruction_h_{false};
// whether need tcgen_05 common header
bool need_tcgen05_common_h_{false};
// whether need cast_smem_ptr_to_int helper function
......
#pragma once
#include "../common.h"
#include <type_traits>
#include <utility>
namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif
namespace detail {
// SM70 MMA Instruction Traits and Implementations
// SM70 supports m16n16k4 (m8n8k4 instruction at warp level) with FP16/FP32
// accumulation
// Base template for SM70 MMA implementation
template <DataType AType, DataType BType, DataType CType, bool TransA,
bool TransB>
struct MmaSm70Impl {
// Default: unsupported configuration
static constexpr bool kSupported = false;
static TL_DEVICE void exec(void *, const void *, const void *, const void *) {
static_assert(always_false_v<std::integral_constant<bool, TransA>>,
"tl::mma_sync_sm70: unsupported configuration");
}
};
// FP16 inputs, FP16 accumulation - col.col (TransA=true, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
true, true> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP16 accumulation - col.row (TransA=true, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
true, false> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP16 accumulation - row.col (TransA=false, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
false, true> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP16 accumulation - row.row (TransA=false, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
false, false> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP32 accumulation - col.col (TransA=true, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
true, true> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// FP16 inputs, FP32 accumulation - col.row (TransA=true, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
true, false> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// FP16 inputs, FP32 accumulation - row.col (TransA=false, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
false, true> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// FP16 inputs, FP32 accumulation - row.row (TransA=false, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
false, false> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// Helper to extract register types
template <class Impl> struct MmaSm70ImplTraits {
using DReg = std::remove_extent_t<typename Impl::DRegisters>;
using AReg = std::remove_extent_t<typename Impl::ARegisters>;
using BReg = std::remove_extent_t<typename Impl::BRegisters>;
using CReg = std::remove_extent_t<typename Impl::CRegisters>;
static constexpr int kDRegs = std::extent_v<typename Impl::DRegisters>;
static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
static constexpr int kBRegs = std::extent_v<typename Impl::BRegisters>;
static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
};
// Dispatcher for SM70 MMA operations
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB>
struct MmaSm70Dispatcher {
using CRegType = void;
using ARegType = void;
using BRegType = void;
static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *,
const CRegType *) {
static_assert(always_false_v<std::integral_constant<int, M>>,
"tl::mma_sync_sm70: unsupported configuration. "
"SM70 only supports m16n16k4 with FP16 inputs and FP16/FP32 "
"accumulation.");
}
};
// Helper to call fma with unpacked register arrays
template <class Impl, size_t... DIdx, size_t... AIdx, size_t... BIdx,
size_t... CIdx>
TL_DEVICE void
call_fma_impl_sm70(typename MmaSm70ImplTraits<Impl>::DReg *d,
const typename MmaSm70ImplTraits<Impl>::AReg *a,
const typename MmaSm70ImplTraits<Impl>::BReg *b,
const typename MmaSm70ImplTraits<Impl>::CReg *c,
std::index_sequence<DIdx...>, std::index_sequence<AIdx...>,
std::index_sequence<BIdx...>, std::index_sequence<CIdx...>) {
Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...);
}
template <class Impl>
TL_DEVICE void call_fma_sm70(typename MmaSm70ImplTraits<Impl>::DReg *d,
const typename MmaSm70ImplTraits<Impl>::AReg *a,
const typename MmaSm70ImplTraits<Impl>::BReg *b,
const typename MmaSm70ImplTraits<Impl>::CReg *c) {
call_fma_impl_sm70<Impl>(
d, a, b, c, std::make_index_sequence<MmaSm70ImplTraits<Impl>::kDRegs>{},
std::make_index_sequence<MmaSm70ImplTraits<Impl>::kARegs>{},
std::make_index_sequence<MmaSm70ImplTraits<Impl>::kBRegs>{},
std::make_index_sequence<MmaSm70ImplTraits<Impl>::kCRegs>{});
}
// Define dispatchers for all supported SM70 configurations
// Note: m8n8k4 instruction computes m16n16k4 at warp level
#define TL_DEFINE_MMA_SM70_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, \
TransAValue, TransBValue) \
template <> \
struct MmaSm70Dispatcher<DataType::ATypeEnum, DataType::BTypeEnum, \
DataType::CTypeEnum, 16, 16, 4, TransAValue, \
TransBValue> { \
using Impl = MmaSm70Impl<DataType::ATypeEnum, DataType::BTypeEnum, \
DataType::CTypeEnum, TransAValue, TransBValue>; \
using Traits = MmaSm70ImplTraits<Impl>; \
using CRegType = typename Traits::DReg; \
using ARegType = typename Traits::AReg; \
using BRegType = typename Traits::BReg; \
static_assert( \
std::is_same_v<typename Traits::DReg, typename Traits::CReg>, \
"tl::mma_sync_sm70 requires matching accumulator/output regs"); \
static TL_DEVICE void exec(CRegType *d, const ARegType *a, \
const BRegType *b, const CRegType *c) { \
call_fma_sm70<Impl>(d, a, b, c); \
} \
};
// FP16 inputs with FP16 accumulation (all layout combinations)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, true, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, true, false)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, false, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, false, false)
// FP16 inputs with FP32 accumulation (all layout combinations)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, true, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, true, false)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, false, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, false, false)
#undef TL_DEFINE_MMA_SM70_DISPATCHER
} // namespace detail
/// SM70 MMA synchronous instruction wrapper
/// Supports m16n16k4 shape (m8n8k4 instruction at warp level) with FP16 inputs
/// and FP16/FP32 accumulation
///
/// @tparam AType Input A data type (kFloat16)
/// @tparam BType Input B data type (kFloat16)
/// @tparam CType Accumulator/output data type (kFloat16 or kFloat32)
/// @tparam M Matrix M dimension (16)
/// @tparam N Matrix N dimension (16)
/// @tparam K Matrix K dimension (4)
/// @tparam TransA Whether A is transposed (false=row-major, true=col-major)
/// @tparam TransB Whether B is transposed (false=row-major, true=col-major)
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB>
TL_DEVICE void mma_sync_sm70(
typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K, TransA,
TransB>::CRegType *c,
const typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K,
TransA, TransB>::ARegType *a,
const typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K,
TransA, TransB>::BRegType *b) {
using Dispatcher =
detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K, TransA, TransB>;
static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
"tl::mma_sync_sm70: unsupported configuration. "
"SM70 only supports m16n16k4 with FP16 inputs.");
Dispatcher::exec(c, a, b, c);
}
} // namespace tl
from __future__ import annotations
def shared_16x4_to_mma_a_32x4_layout(row, col, rep):
tid = (row % 4) + 16 * ((row // 4) % 2) + 4 * (row // 8) + 8 * rep
local_id = col
return tid, local_id
def shared_4x16_to_mma_b_32x4_layout(row, col, rep):
thread_id = row + 8 * col // 4 + 4 * rep
local_id = col % 4
return thread_id, local_id
def shared_16x4_to_mma_b_32x4_layout_trans(row, col, rep):
thread_id = row % 4 + 4 * rep + 8 * ((row % 8) // 4) + 16 * (row // 8)
local_id = col
return thread_id, local_id
def mma_32x8_to_shared_16x16_layout_fp32(thread_id, local_id):
row = (thread_id % 2) + (
(local_id // 2 % 2) * 2) + 4 * (thread_id // 16) + (thread_id % 16 // 4) % 2 * 8
col = (thread_id % 4 // 2) * 2 + (thread_id % 16 // 8) * 4 + (local_id %
2) + (local_id // 4) * 8
return row, col
def mma_32x8_to_shared_16x16_layout_fp16(thread_id, local_id):
row = (thread_id % 4) + (thread_id // 16) * 4 + (thread_id % 8) // 4 * 8
col = local_id % 4 + ((thread_id % 16) // 8) * 4 + (local_id // 4) * 8
return row, col
def mma_load_a_32x4_to_shared_16x4_layout(thread_id, local_id):
row = (thread_id % 4) + (4 * (((thread_id // 16 + thread_id % 16 // 4 * 2)) % 4))
col = local_id
return row, col
def mma_load_b_32x4_to_shared_16x4_layout_trans(thread_id, local_id):
row = (thread_id % 4) + 8 * (thread_id // 16) + 4 * ((thread_id // 8) % 2)
col = local_id
return row, col
def mma_load_b_32x4_to_shared_4x16_layout(thread_id, local_id):
row = thread_id % 4
col = local_id + (4 * (thread_id // 8))
return row, col
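These index maps are easiest to sanity-check by enumerating them. The snippet below, assuming it is appended to (or imports) the module above, prints which thread and register each element of the 16x4 A fragment lands in for rep=0:

if __name__ == "__main__":
    # Dump the shared (16x4) -> (thread, register) map for the A fragment.
    for row in range(16):
        mapped = [shared_16x4_to_mma_a_32x4_layout(row, col, 0) for col in range(4)]
        print(f"row {row:2d}: " + "  ".join(f"t{t:02d}/r{l}" for t, l in mapped))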
This diff is collapsed.
......@@ -708,3 +708,102 @@ def tcgen05_mma_arrive(mbar_ptr):
Pointer to the mbarrier object in shared memory (e.g., Barrier*).
"""
return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr)
def ptx_mma_sm70(
shape,
A_layout,
B_layout,
A_dtype,
B_dtype,
C_dtype,
multiplicand_a,
a_index,
multiplicand_b,
b_index,
accumulator,
c_index,
):
"""TVM intrinsic for ptx tensor core mma instructions on SM70 (Volta).
This intrinsic provides SM70-specific MMA operations that support m16n16k4 shape
with FP16 inputs and FP16/FP32 accumulation.
Parameters
----------
shape : str
The shape of mma fragment (e.g., "m16n16k4").
A_layout : str
The layout of multiplicand fragment A ("row" or "col").
B_layout : str
The layout of multiplicand fragment B ("row" or "col").
A_dtype : str
The data type of multiplicand fragment A (typically "fp16").
B_dtype : str
The data type of multiplicand fragment B (typically "fp16").
C_dtype : str
The data type of accumulator fragment C ("fp16" or "fp32").
multiplicand_a : Var
The multiplicand fragment A variable.
a_index : Expr
The index of multiplicand fragment A.
multiplicand_b : Var
The multiplicand fragment B variable.
b_index : Expr
The index of multiplicand fragment B.
accumulator : Var
The accumulator fragment C variable.
c_index : Expr
The index of accumulator fragment C.
Returns
-------
call : PrimExpr
The call expression.
Examples
--------
>>> T.ptx_mma_sm70(
... "float16",
... "m16n16k4",
... "row",
... "col",
... "fp16",
... "fp16",
... "fp16",
... A_local.data,
... 0,
... B_local.data,
... 0,
... C_local.data,
... 0,
... )
"""
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.ptx_mma_sm70"),
shape,
A_layout,
B_layout,
A_dtype,
B_dtype,
C_dtype,
multiplicand_a,
a_index,
multiplicand_b,
b_index,
accumulator,
c_index,
)
......@@ -5,6 +5,7 @@ from .layout import Layout # noqa: F401
from .fragment import Fragment # noqa: F401
from .swizzle import (
make_swizzled_layout, # noqa: F401
make_volta_swizzled_layout, # noqa: F401
make_wgmma_swizzled_layout, # noqa: F401
make_tcgen05mma_swizzled_layout, # noqa: F401
make_full_bank_swizzled_layout, # noqa: F401
......
......@@ -18,6 +18,17 @@ def make_swizzled_layout(buffer: tvm.tir.Buffer, k_major: bool = True, allow_pad
)
# for Volta Intrinsics
def make_volta_swizzled_layout(buffer: tvm.tir.Buffer, is_a: bool = True, k_inner: bool = True):
assert len(buffer.shape) == 2
return _ffi_api.make_volta_swizzled_layout(
int(buffer.shape[0]),
int(buffer.shape[1]),
is_a,
k_inner,
)
# for WGMMA Intrinsics
def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
continuity: int = None,
......
......@@ -7,10 +7,12 @@ from tvm.runtime import Scriptable
import tvm_ffi
from tilelang.ir import GemmWarpPolicy
from .gemm_mma import GemmMMA
from .gemm_mma_sm70 import GemmMMASm70
from .gemm_wgmma import GemmWGMMA
from .gemm_tcgen05 import GemmTCGEN5
from .gemm_mfma import GemmMFMA
from tilelang import _ffi_api
from tilelang.utils.target import target_is_volta
@tvm_ffi.register_global_func("tl.gemm_py.infer_layout")
......@@ -79,13 +81,13 @@ class GemmPy(Node, Scriptable):
def infer_layout(self, target: Target, thread_nums: int):
"""Infer the layout for the GEMM operation based on target architecture."""
gemm_inst = self._select_gemm_instruction(thread_nums, target)
impl_class = self._get_implementation_class(gemm_inst)
impl_class = self._get_implementation_class(gemm_inst, target)
return impl_class(self).infer_layout(target, thread_nums)
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
"""Lower the GEMM operation to TIR statements based on target architecture."""
gemm_inst = self._select_gemm_instruction(thread_nums, target)
impl_class = self._get_implementation_class(gemm_inst)
impl_class = self._get_implementation_class(gemm_inst, target)
return impl_class(self).lower(layout_map, target, thread_nums, thread_var)
def _select_gemm_instruction(self, thread_nums: int, target: Target) -> GemmInst:
......@@ -106,7 +108,7 @@ class GemmPy(Node, Scriptable):
"""
return GemmInst(_ffi_api.GemmPyGemmInst(self, int(thread_nums), target))
def _get_implementation_class(self, gemm_inst: GemmInst):
def _get_implementation_class(self, gemm_inst: GemmInst, target: Target):
"""Get the appropriate implementation class for the given GEMM instruction.
Args:
......@@ -120,6 +122,8 @@ class GemmPy(Node, Scriptable):
ValueError: If the instruction type is unknown
"""
if gemm_inst.is_mma():
if target_is_volta(target):
return GemmMMASm70
return GemmMMA
elif gemm_inst.is_wgmma():
return GemmWGMMA
......
# for Volta GPUs, which use legacy MMA instructions
from .gemm_base import GemmBase
from tilelang.layout import make_volta_swizzled_layout
from tilelang.intrinsics.mma_sm70_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_fragment
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify
class GemmMMASm70(GemmBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
)
a_is_k_major = not self.trans_A
b_is_k_major = self.trans_B
if self.is_gemm_ss():
return {
self.A: make_volta_swizzled_layout(self.A, is_a=True, k_inner=a_is_k_major),
self.B: make_volta_swizzled_layout(self.B, is_a=False, k_inner=b_is_k_major),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_rs():
return {
self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
self.B: make_volta_swizzled_layout(self.B, is_a=False, k_inner=b_is_k_major),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
thread_var=thread_var,
)
in_dtype = self.in_dtype
warp_rows = mma_emitter.warp_rows
warp_cols = mma_emitter.warp_cols
local_size_a = mma_emitter.local_size_a
local_size_b = mma_emitter.local_size_b
block_K = mma_emitter.chunk
micro_size_k = mma_emitter.micro_size_k
A_shared = self.A
B_shared = self.B
C_local = self.C
assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
if self.is_gemm_ss():
@T.prim_func
def _gemm_ssr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_rs():
A_local = self.A
@T.prim_func
def _gemm_rsr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
def is_gemm_sr(self) -> bool:
return is_shared(self.A) and is_fragment(self.B)
def is_gemm_rs(self) -> bool:
return is_fragment(self.A) and is_shared(self.B)
def is_gemm_rr(self) -> bool:
return is_fragment(self.A) and is_fragment(self.B)