Commit 69a74571 authored by qisan's avatar qisan
Browse files

feat(dcu): switch to gemm_v1 instead of gemm_v2

parent cf6e11c9
...@@ -83,6 +83,8 @@ TileLang achieves exceptional performance across a variety of computational patt ...@@ -83,6 +83,8 @@ TileLang achieves exceptional performance across a variety of computational patt
</div> </div>
## Installation ## Installation
### Method 0: Install for Hygon DCU
- [Install from Source](./docs/get_started/Installation_dcu.md)
### Method 1: Install with Pip ### Method 1: Install with Pip
The quickest way to get started is to install the latest release from PyPI: The quickest way to get started is to install the latest release from PyPI:
......
# Installation for DCU
## Building from Source
```bash
mkdir -p build
cd build
cmake .. -DUSE_ROCM=ON -DUSE_CUDA=OFF
make -j
```
```bash
export PYTHONPATH=/path/to/tilelang:$PYTHONPATH
python -c "import tilelang; print(tilelang.__version__)"
```
## Other Tips
### Missing tvm_ffi Module
If you encounter the error `ModuleNotFoundError: No module named 'tvm_ffi'`, it means the TVM foreign function interface package was not installed. This often happens when the submodules were built manually. Fix it by running:
```
# Navigate to the tvm_ffi directory
cd 3rdparty/tvm/3rdparty/tvm_ffi
# Install the package
pip install .
# Return to the project root
cd ../../../..
```
### DTK Path Configuration
If you encounter errors related to DTK path detection (e.g., hipcc not found or failure to retrieve GPU architecture), you may need to manually specify the DTK installation path in the source code.
Locate the file tilelang/contrib/rocm.py and modify the default value of the rocm_path parameter in the get_rocm_arch function (around line 231):
```
# File: tilelang/contrib/rocm.py
# Change from:
def get_rocm_arch(rocm_path="/opt/rocm"):
...
# To (for Hygon DCU environments):
def get_rocm_arch(rocm_path="/opt/dtk"):
...
```
\ No newline at end of file
...@@ -453,7 +453,7 @@ TVM_DLL const Op &tvm_mfma(); ...@@ -453,7 +453,7 @@ TVM_DLL const Op &tvm_mfma();
/*! /*!
* \brief tvm intrinsic for amd matrix core mmac instructions. * \brief tvm intrinsic for amd matrix core mmac instructions.
* *
* void tvm_mfma(StringImm shape, StringImm A_layout, StringImm B_layout, * void tvm_mmac(StringImm shape, StringImm A_layout, StringImm B_layout,
* StringImm A_dtype, StringImm B_dtype, StringImm C_dtype, * StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
* Var multiplicand_a, Expr a_index, * Var multiplicand_a, Expr a_index,
* Var multiplicand_b, Expr b_index, * Var multiplicand_b, Expr b_index,
......
...@@ -131,6 +131,8 @@ GemmInst GemmNode::getGemmInst(int block_size, Target target) const { ...@@ -131,6 +131,8 @@ GemmInst GemmNode::getGemmInst(int block_size, Target target) const {
return GemmInst::kTCGEN5MMA; return GemmInst::kTCGEN5MMA;
} else if (allowWgmma(block_size, target)) { } else if (allowWgmma(block_size, target)) {
return GemmInst::kWGMMA; return GemmInst::kWGMMA;
} else if(TargetIsDCU(target)) {
return GemmInst::KMMAC;
} else if (TargetIsCDNA(target)) { } else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA; return GemmInst::kMFMA;
} else if (TargetIsCuda(target)) { } else if (TargetIsCuda(target)) {
......
...@@ -23,7 +23,7 @@ enum class GemmWarpPolicyType : uint8_t { ...@@ -23,7 +23,7 @@ enum class GemmWarpPolicyType : uint8_t {
}; };
// Target GEMM instruction // Target GEMM instruction
enum class GemmInst : uint8_t { kMMA, kWGMMA, kTCGEN5MMA, kMFMA }; enum class GemmInst : uint8_t { kMMA, kWGMMA, kTCGEN5MMA, kMFMA, KMMAC };
class GemmWarpPolicyNode : public Object { class GemmWarpPolicyNode : public Object {
public: public:
mutable int m_warp{0}; mutable int m_warp{0};
......
...@@ -131,6 +131,8 @@ GemmInst GemmPyNode::getGemmInst(int block_size, Target target) const { ...@@ -131,6 +131,8 @@ GemmInst GemmPyNode::getGemmInst(int block_size, Target target) const {
return GemmInst::kTCGEN5MMA; return GemmInst::kTCGEN5MMA;
} else if (allow_wgmma) { } else if (allow_wgmma) {
return GemmInst::kWGMMA; return GemmInst::kWGMMA;
} else if(TargetIsDCU(target)) {
return GemmInst::KMMAC;
} else if (TargetIsCDNA(target)) { } else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA; return GemmInst::kMFMA;
} else if (TargetIsVolta(target) || TargetIsAmpere(target) || } else if (TargetIsVolta(target) || TargetIsAmpere(target) ||
......
...@@ -85,6 +85,7 @@ bool TargetIsDCU(Target target) { ...@@ -85,6 +85,7 @@ bool TargetIsDCU(Target target) {
if (!TargetIsRocm(target)) if (!TargetIsRocm(target))
return false; return false;
if (target->attrs.count("mcpu")) { if (target->attrs.count("mcpu")) {
std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
// if mcpu start with "gfx936", it is DCU // if mcpu start with "gfx936", it is DCU
return mcpu.find("gfx936") == 0; return mcpu.find("gfx936") == 0;
} }
......
...@@ -228,7 +228,7 @@ def have_matrixcore(compute_version=None): ...@@ -228,7 +228,7 @@ def have_matrixcore(compute_version=None):
@tvm_ffi.register_global_func("tvm_callback_rocm_get_arch", override=True) @tvm_ffi.register_global_func("tvm_callback_rocm_get_arch", override=True)
def get_rocm_arch(rocm_path="/opt/rocm"): def get_rocm_arch(rocm_path="/opt/dtk"):
# @tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True) # @tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True)
# def get_rocm_arch(rocm_path="/opt/dtk"): # def get_rocm_arch(rocm_path="/opt/dtk"):
"""Utility function to get the AMD GPU architecture """Utility function to get the AMD GPU architecture
......
...@@ -237,7 +237,7 @@ class Environment: ...@@ -237,7 +237,7 @@ class Environment:
# Kernel selection options # Kernel selection options
# Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1 # Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1
TILELANG_USE_GEMM_V1 = EnvVar("TILELANG_USE_GEMM_V1", "0") TILELANG_USE_GEMM_V1 = EnvVar("TILELANG_USE_GEMM_V1", "1")
# Auto-tuning settings # Auto-tuning settings
TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0") TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0")
......
def thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id):
    """Map (thread_id, local_id) of a 64-lane warp, 4 C elements per lane,
    onto (row, col) coordinates of a 16x16 accumulator tile (n-major layout).

    Lanes advance down the 16 rows; each group of 16 lanes owns a 4-wide
    column band, with ``local_id`` selecting the element inside the band.
    """
    lane_group, lane_in_group = divmod(thread_id, 16)
    row = lane_in_group
    col = lane_group * 4 + local_id
    return row, col
\ No newline at end of file
...@@ -2,10 +2,29 @@ from __future__ import annotations ...@@ -2,10 +2,29 @@ from __future__ import annotations
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang.language as T import tilelang.language as T
from tvm import DataType from tvm import DataType
from tvm.tir import PrimExpr from tvm.tir import PrimExpr, IndexMap, Buffer, Var, BufferRegion, BufferLoad
from tvm.runtime import convert from tvm.runtime import convert
from .utils import ( from .utils import (
mfma_store_index_map, mmac_store_index_map,
)
from tilelang.utils import is_fragment
from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_4x16_to_local_64x1_layout_B,
shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_A,
shared_16x32_to_local_64x8_layout_B,
shared_16x64_to_local_64x16_layout_A,
shared_16x64_to_local_64x16_layout_B,
thread_id_shared_access_64x1_to_16x4_layout_A,
thread_id_shared_access_64x1_to_4x16_layout_B,
thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B,
thread_id_shared_access_64x8_to_16x32_layout_A,
thread_id_shared_access_64x8_to_16x32_layout_B,
thread_id_shared_access_64x16_to_16x64_layout_A,
thread_id_shared_access_64x16_to_16x64_layout_B,
) )
lift = convert lift = convert
...@@ -39,9 +58,9 @@ class MatrixCoreIntrinEmitter: ...@@ -39,9 +58,9 @@ class MatrixCoreIntrinEmitter:
def __init__( def __init__(
self, self,
a_dtype: str = "float16", a_dtype: str = T.float16,
b_dtype: str = "float16", b_dtype: str = T.float16,
accum_dtype: str = "float16", accum_dtype: str = T.float16,
a_transposed: bool = False, a_transposed: bool = False,
b_transposed: bool = False, b_transposed: bool = False,
block_row_warps: int = 2, block_row_warps: int = 2,
...@@ -54,6 +73,7 @@ class MatrixCoreIntrinEmitter: ...@@ -54,6 +73,7 @@ class MatrixCoreIntrinEmitter:
k_pack: int | None = None, k_pack: int | None = None,
is_m_first: bool | None = False, is_m_first: bool | None = False,
b_preshuffle: bool | None = False, b_preshuffle: bool | None = False,
thread_var: Var | None = None,
): ):
self.a_dtype = a_dtype self.a_dtype = a_dtype
self.b_dtype = b_dtype self.b_dtype = b_dtype
...@@ -80,10 +100,11 @@ class MatrixCoreIntrinEmitter: ...@@ -80,10 +100,11 @@ class MatrixCoreIntrinEmitter:
self.reduce_k = reduce_k self.reduce_k = reduce_k
self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
self.num_elems_per_byte = num_elems_per_byte self.num_elems_per_byte = num_elems_per_byte
self.thread_var = thread_var
def _initialize_k_dim(self, a_dtype="float16"): def _initialize_k_dim(self, a_dtype=T.float16):
if isinstance(a_dtype, str): if isinstance(a_dtype, str):
if a_dtype in ["float8_e4m3fnuz", "int8"]: if a_dtype in ["float8_e4m3fnuz", T.int8]:
self.k_dim = 32 self.k_dim = 32
return return
a_dtype = DataType(a_dtype) a_dtype = DataType(a_dtype)
...@@ -132,25 +153,6 @@ class MatrixCoreIntrinEmitter: ...@@ -132,25 +153,6 @@ class MatrixCoreIntrinEmitter:
self.b_preshuffle = b_preshuffle self.b_preshuffle = b_preshuffle
def get_ldmatrix_index_map(self, is_b=False): def get_ldmatrix_index_map(self, is_b=False):
from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_4x16_to_local_64x1_layout_B,
shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_A,
shared_16x32_to_local_64x8_layout_B,
shared_16x64_to_local_64x16_layout_A,
shared_16x64_to_local_64x16_layout_B,
thread_id_shared_access_64x1_to_16x4_layout_A,
thread_id_shared_access_64x1_to_4x16_layout_B,
thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B,
thread_id_shared_access_64x8_to_16x32_layout_A,
thread_id_shared_access_64x8_to_16x32_layout_B,
thread_id_shared_access_64x16_to_16x64_layout_A,
thread_id_shared_access_64x16_to_16x64_layout_B,
)
k_dim = self.k_dim * self.k_pack k_dim = self.k_dim * self.k_pack
transposed = self.a_transposed if not is_b else self.b_transposed transposed = self.a_transposed if not is_b else self.b_transposed
if k_dim == 4: if k_dim == 4:
...@@ -199,6 +201,22 @@ class MatrixCoreIntrinEmitter: ...@@ -199,6 +201,22 @@ class MatrixCoreIntrinEmitter:
return index_map, reverse_index_map return index_map, reverse_index_map
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(mmac_store_index_map, index_dtype=T.int32)
if not inverse:
return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c])
return inverse_index_map
def get_thread_binding(self):
if self.thread_var is None:
current_frame = T.KernelLaunchFrame.Current()
assert current_frame is not None, "Must be called in a T.Kernel Frame"
return current_frame.get_thread_binding()
else:
return self.thread_var
def extract_thread_binding(self, thread_id, is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: def extract_thread_binding(self, thread_id, is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
""" """
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
...@@ -228,7 +246,7 @@ class MatrixCoreIntrinEmitter: ...@@ -228,7 +246,7 @@ class MatrixCoreIntrinEmitter:
) )
return lane_id, warp_n, warp_m return lane_id, warp_n, warp_m
def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0): def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):
warp_row_tiles = self.warp_row_tiles warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows warp_rows = self.warp_rows
chunk = self.chunk chunk = self.chunk
...@@ -237,10 +255,15 @@ class MatrixCoreIntrinEmitter: ...@@ -237,10 +255,15 @@ class MatrixCoreIntrinEmitter:
local_size_a = self.local_size_a local_size_a = self.local_size_a
k_pack = self.k_pack k_pack = self.k_pack
is_transposed = self.a_transposed is_transposed = self.a_transposed
current_frame = T.KernelLaunchFrame.Current() thread_binding = self.get_thread_binding()
thread_binding = current_frame.get_thread_binding()
_, reverse_index_map = self.get_ldmatrix_index_map(is_b=False) _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
# legalize shared buffer to region
A_region = self._legalize_to_buffer_region(A_shared_buf)
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
@T.macro @T.macro
def _warp_ldmatrix_a( def _warp_ldmatrix_a(
A_local_buf, A_local_buf,
...@@ -255,17 +278,17 @@ class MatrixCoreIntrinEmitter: ...@@ -255,17 +278,17 @@ class MatrixCoreIntrinEmitter:
for local_id in T.vectorized(k_pack * local_size_a): for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id)) row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x) l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row, r + col] A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
else: else:
for i in T.serial(warp_rows): for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a): for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id)) row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k)) l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k))
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row, r + col] A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0): def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0):
warp_col_tiles = self.warp_col_tiles warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols warp_cols = self.warp_cols
chunk = self.chunk chunk = self.chunk
...@@ -274,10 +297,15 @@ class MatrixCoreIntrinEmitter: ...@@ -274,10 +297,15 @@ class MatrixCoreIntrinEmitter:
local_size_b = self.local_size_b local_size_b = self.local_size_b
k_pack = self.k_pack k_pack = self.k_pack
is_transposed = self.b_transposed is_transposed = self.b_transposed
current_frame = T.KernelLaunchFrame.Current() thread_binding = self.get_thread_binding()
thread_binding = current_frame.get_thread_binding()
_, reverse_index_map = self.get_ldmatrix_index_map(is_b=True) _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
# legalize shared buffer to region
B_region = self._legalize_to_buffer_region(B_shared_buf)
B_buf = B_region.buffer
B_base0 = B_region.region[-2].min
B_base1 = B_region.region[-1].min
@T.macro @T.macro
def _warp_ldmatrix_b( def _warp_ldmatrix_b(
B_local_buf, B_local_buf,
...@@ -295,7 +323,7 @@ class MatrixCoreIntrinEmitter: ...@@ -295,7 +323,7 @@ class MatrixCoreIntrinEmitter:
warp_n * warp_col_tiles + j * micro_size_y, warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * (k_pack * micro_size_k), rk * chunk + ki * (k_pack * micro_size_k),
) )
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, r + col] B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]
else: else:
for j in T.serial(warp_cols): for j in T.serial(warp_cols):
...@@ -305,11 +333,11 @@ class MatrixCoreIntrinEmitter: ...@@ -305,11 +333,11 @@ class MatrixCoreIntrinEmitter:
rk * chunk + ki * (k_pack * micro_size_k), rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y, warp_n * warp_col_tiles + j * micro_size_y,
) )
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, r + col] B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
def mmac(self, A_local_buf, B_local_buf, C_local_buf): def mmac(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows warp_rows = self.warp_rows
warp_cols = self.warp_cols warp_cols = self.warp_cols
local_size_a = self.local_size_a local_size_a = self.local_size_a
...@@ -322,8 +350,13 @@ class MatrixCoreIntrinEmitter: ...@@ -322,8 +350,13 @@ class MatrixCoreIntrinEmitter:
compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}" compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}"
compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}" compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}"
a_is_fragment = is_fragment(A_local_buf)
b_is_fragment = is_fragment(B_local_buf)
a_local_stride: PrimExpr = k_inner * warp_rows * k_pack * local_size_a if a_is_fragment else 0
b_local_stride: PrimExpr = k_inner * warp_cols * k_pack * local_size_b if b_is_fragment else 0
@T.macro @T.macro
def _warp_mma(A_local_buf, B_local_buf, C_local_buf): def _warp_mmac(A_local_buf, B_local_buf, C_local_buf):
for kp, i, j in T.grid(k_pack, warp_rows, warp_cols): for kp, i, j in T.grid(k_pack, warp_rows, warp_cols):
T.tvm_mmac( T.tvm_mmac(
mmac_suffix, mmac_suffix,
...@@ -333,15 +366,15 @@ class MatrixCoreIntrinEmitter: ...@@ -333,15 +366,15 @@ class MatrixCoreIntrinEmitter:
compute_b_dtype, compute_b_dtype,
compute_out_dtype, compute_out_dtype,
A_local_buf.data, A_local_buf.data,
((j * k_pack + kp) * local_size_a) // local_size_a, (a_local_stride + (j * k_pack + kp) * local_size_a) // local_size_a,
B_local_buf.data, B_local_buf.data,
((i * k_pack + kp) * local_size_b) // local_size_b, (b_local_stride + (i * k_pack + kp) * local_size_b) // local_size_b,
C_local_buf.data, C_local_buf.data,
(i * warp_cols * local_size_out + j * local_size_out) // local_size_out, (i * warp_cols * local_size_out + j * local_size_out) // local_size_out,
dtype=compute_out_dtype, dtype=compute_out_dtype,
) )
return _warp_mma(A_local_buf, B_local_buf, C_local_buf) return _warp_mmac(A_local_buf, B_local_buf, C_local_buf)
def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None):
block_row_warps = self.block_row_warps block_row_warps = self.block_row_warps
...@@ -349,8 +382,7 @@ class MatrixCoreIntrinEmitter: ...@@ -349,8 +382,7 @@ class MatrixCoreIntrinEmitter:
warp_rows = self.warp_rows warp_rows = self.warp_rows
warp_cols = self.warp_cols warp_cols = self.warp_cols
local_size_out = self.local_size_out local_size_out = self.local_size_out
current_frame = T.KernelLaunchFrame.Current() thread_binding = self.get_thread_binding()
thread_binding = current_frame.get_thread_binding()
is_global = pid_m is not None and pid_n is not None is_global = pid_m is not None and pid_n is not None
BLOCK_M = block_row_warps * warp_rows BLOCK_M = block_row_warps * warp_rows
BLOCK_N = block_col_warps * warp_cols BLOCK_N = block_col_warps * warp_cols
...@@ -359,7 +391,7 @@ class MatrixCoreIntrinEmitter: ...@@ -359,7 +391,7 @@ class MatrixCoreIntrinEmitter:
assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D" assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D"
# STS # STS
# MMA Store must be in simulated instead of TVM Intrins # MMAC Store must be in simulated instead of TVM Intrins
# As TVM Intrins is like a hack that the threadIdx.x should be always # As TVM Intrins is like a hack that the threadIdx.x should be always
# equal to the warp_size # equal to the warp_size
@T.macro @T.macro
...@@ -367,7 +399,7 @@ class MatrixCoreIntrinEmitter: ...@@ -367,7 +399,7 @@ class MatrixCoreIntrinEmitter:
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
for i, j in T.grid(warp_rows, warp_cols): for i, j in T.grid(warp_rows, warp_cols):
for local_id in T.vectorized(local_size_out): for local_id in T.vectorized(local_size_out):
row, col = T.meta_var(mfma_store_index_map(tx, local_id)) row, col = T.meta_var(mmac_store_index_map(tx, local_id))
if C_buf_dims == 2: if C_buf_dims == 2:
C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[ C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[
j * (warp_rows * local_size_out) + i * local_size_out + local_id j * (warp_rows * local_size_out) + i * local_size_out + local_id
...@@ -382,7 +414,7 @@ class MatrixCoreIntrinEmitter: ...@@ -382,7 +414,7 @@ class MatrixCoreIntrinEmitter:
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
for i, j in T.grid(warp_rows, warp_cols): for i, j in T.grid(warp_rows, warp_cols):
for local_id in T.vectorized(local_size_out): for local_id in T.vectorized(local_size_out):
row, col = T.meta_var(mfma_store_index_map(tx, local_id)) row, col = T.meta_var(mmac_store_index_map(tx, local_id))
C_buf[ C_buf[
(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col
] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id] ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]
...@@ -393,7 +425,93 @@ class MatrixCoreIntrinEmitter: ...@@ -393,7 +425,93 @@ class MatrixCoreIntrinEmitter:
else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding) else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
) )
def make_mmac_store_layout(self, local_buf: Buffer) -> T.Fragment:
    """
    Create a layout function for storing MMAC results into a fragment buffer.

    Parameters
    ----------
    local_buf : tir.Buffer
        The local buffer representing a fragment of a matrix.

    Returns
    -------
    T.Fragment
        A fragment object describing the thread and index layout for MMAC.
    """
    from tilelang.utils import is_fragment

    shape = local_buf.shape
    # Inverse of the per-micro-tile store map: (mmac_i, mmac_j) -> (lane_id, local_id).
    inverse_mmac_store_layout = self.get_store_index_map(inverse=True)
    assert is_fragment(local_buf), "local_buf must be a fragment"
    micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
    local_size_out = self.local_size_out
    block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
    warp_rows, warp_cols = self.warp_rows, self.warp_cols
    warp_size = self.WARP_SIZE
    is_m_first = self.is_m_first

    def forward_thread(i: int, j: int) -> int:
        """
        Map fragment row `i` and column `j` to a thread index.
        """
        block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
        mmac_i, mmac_j = i % micro_size_x, j % micro_size_y
        lane_id, _ = inverse_mmac_store_layout.map_indices([mmac_i, mmac_j])
        if is_m_first:
            # Stride of one warp-row group is block_col_warps * warp_size threads.
            # (Fix: the previous code multiplied by `warp_cols` — a tile count,
            # not a thread count — which interleaved lanes of different warps;
            # mirrors the correct `else` branch below.)
            thread_id = block_i * (block_col_warps * warp_size) + block_j * warp_size + lane_id
        else:
            thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
        return thread_id

    def forward_index(i: int, j: int) -> int:
        """
        Map fragment row `i` and column `j` to a local index within a thread's registers.
        """
        warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
        mmac_i, mmac_j = i % micro_size_x, j % micro_size_y
        # Use the low-level MMAC inverse hardware mapping to get the local offset.
        _, local_id = inverse_mmac_store_layout.map_indices([mmac_i, mmac_j])
        return warp_j * (warp_rows * local_size_out) + warp_i * local_size_out + local_id

    return T.Fragment(
        shape,
        forward_thread_fn=forward_thread,
        forward_index_fn=forward_index,
    )
@staticmethod
def _legalize_to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion:
    """
    Convert Buffer/BufferRegion/BufferLoad to a BufferRegion.

    - Buffer -> full-region BufferRegion covering entire shape
    - BufferRegion -> returned as-is
    - BufferLoad -> best-effort convert via get_buffer_region_from_load;
      if scalar, fall back to 1-sized ranges at given indices
    """
    # NOTE(review): `tir`, `Range`, and `get_buffer_region_from_load` are not
    # among the imports visible at the top of this file — confirm they are in
    # module scope, otherwise every call here raises NameError at runtime.
    if isinstance(obj, BufferRegion):
        # Already a region: return unchanged.
        return obj
    if isinstance(obj, Buffer):
        # Whole-buffer access: build a [0, extent) range for every dimension.
        mins = [tir.IntImm("int32", 0) for _ in obj.shape]
        ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)]
        return BufferRegion(obj, ranges)
    if isinstance(obj, BufferLoad):
        # Try to recover the region the load addresses (ramp/strided access).
        region = get_buffer_region_from_load(obj)
        if region is not None:
            return region
        # Fallback: scalar load -> 1-sized ranges at indices
        mins = [idx for idx in obj.indices]
        ones = [tir.IntImm("int32", 1) for _ in obj.indices]
        ranges = [Range.from_min_extent(m, e) for m, e in zip(mins, ones)]
        return BufferRegion(obj.buffer, ranges)
    raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}")
class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
def __init__( def __init__(
self, self,
...@@ -413,33 +531,27 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): ...@@ -413,33 +531,27 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
is_m_first: bool | None = False, is_m_first: bool | None = False,
a_preshuffle: bool | None = False, a_preshuffle: bool | None = False,
b_preshuffle: bool | None = False, b_preshuffle: bool | None = False,
thread_var: Var | None = None,
): ):
self.a_dtype = a_dtype super().__init__(
self.b_dtype = b_dtype a_dtype=a_dtype,
self.accum_dtype = accum_dtype b_dtype=b_dtype,
self.a_transposed = a_transposed accum_dtype=accum_dtype,
self.b_transposed = b_transposed a_transposed=a_transposed,
# Hint Information b_transposed=b_transposed,
self.block_row_warps = block_row_warps block_row_warps=block_row_warps,
self.block_col_warps = block_col_warps block_col_warps=block_col_warps,
self.warp_row_tiles = warp_row_tiles warp_row_tiles=warp_row_tiles,
self.warp_col_tiles = warp_col_tiles warp_col_tiles=warp_col_tiles,
self.chunk = chunk chunk=chunk,
self._initialize_k_dim(a_dtype) reduce_k=reduce_k,
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) num_elems_per_byte=num_elems_per_byte,
self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE) k_pack=k_pack,
self._initialize_mmac_prefix(self.k_dim) is_m_first=is_m_first,
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim) thread_var=thread_var,
self._initialize_k_pack(k_pack) )
self._initialize_is_m_first(is_m_first)
self._initialize_preshuffle(a_preshuffle, b_preshuffle) self._initialize_preshuffle(a_preshuffle, b_preshuffle)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self.reduce_k = reduce_k
self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
self.num_elems_per_byte = num_elems_per_byte
def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool): def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool):
if a_preshuffle is not None: if a_preshuffle is not None:
self.a_preshuffle = a_preshuffle self.a_preshuffle = a_preshuffle
......
...@@ -90,6 +90,9 @@ def mma_store_index_map_fp64(thread_id, local_id): ...@@ -90,6 +90,9 @@ def mma_store_index_map_fp64(thread_id, local_id):
def mfma_store_index_map(thread_id, local_id): def mfma_store_index_map(thread_id, local_id):
return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id) return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id)
def mmac_store_index_map(thread_id, local_id):
    # DCU matrix-core (MMAC) C-fragment store mapping. Delegates to the same
    # 64-lane x 4-element -> 16x16 (n, m) helper used by the MFMA store map
    # above, since the accumulator tile layout is shared.
    return thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id)
def get_mma_micro_size(dtype: Literal["float16", "int8"]): def get_mma_micro_size(dtype: Literal["float16", "int8"]):
# TODO(lei): FP8 related precision support. # TODO(lei): FP8 related precision support.
......
...@@ -11,6 +11,7 @@ from .gemm_mma_sm70 import GemmMMASm70 ...@@ -11,6 +11,7 @@ from .gemm_mma_sm70 import GemmMMASm70
from .gemm_wgmma import GemmWGMMA from .gemm_wgmma import GemmWGMMA
from .gemm_tcgen05 import GemmTCGEN5 from .gemm_tcgen05 import GemmTCGEN5
from .gemm_mfma import GemmMFMA from .gemm_mfma import GemmMFMA
from .gemm_mmac import GemmMMAC
from tilelang import _ffi_api from tilelang import _ffi_api
from tilelang.utils.target import target_is_volta from tilelang.utils.target import target_is_volta
...@@ -35,6 +36,7 @@ class GemmInst(IntEnum): ...@@ -35,6 +36,7 @@ class GemmInst(IntEnum):
WGMMA = 1 WGMMA = 1
TCGEN5MMA = 2 TCGEN5MMA = 2
MFMA = 3 MFMA = 3
MMAC = 4
def is_mma(self) -> bool: def is_mma(self) -> bool:
return self == GemmInst.MMA return self == GemmInst.MMA
...@@ -47,6 +49,9 @@ class GemmInst(IntEnum): ...@@ -47,6 +49,9 @@ class GemmInst(IntEnum):
def is_mfma(self) -> bool: def is_mfma(self) -> bool:
return self == GemmInst.MFMA return self == GemmInst.MFMA
def is_mmac(self) -> bool:
return self == GemmInst.MMAC
def __repr__(self) -> str: def __repr__(self) -> str:
return self.name return self.name
...@@ -184,6 +189,8 @@ class GemmPy(Node, Scriptable): ...@@ -184,6 +189,8 @@ class GemmPy(Node, Scriptable):
return GemmWGMMA return GemmWGMMA
elif gemm_inst.is_tcgen5mma(): elif gemm_inst.is_tcgen5mma():
return GemmTCGEN5 return GemmTCGEN5
elif gemm_inst.is_mmac():
return GemmMMAC
elif gemm_inst.is_mfma(): elif gemm_inst.is_mfma():
return GemmMFMA return GemmMFMA
elif gemm_inst.is_tcgen5mma(): elif gemm_inst.is_tcgen5mma():
......
from .gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mmac_macro_generator import (
MatrixCoreIntrinEmitter,
)
from tilelang.utils.language import is_shared, is_fragment, is_full_region
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify
class GemmMMAC(GemmBase):
def infer_layout(self, target: Target, thread_nums: int):
    """
    Infer memory layouts for the A, B and C operands of a DCU MMAC GEMM.

    Shared-memory operands receive a swizzled layout; register (fragment)
    operands receive the emitter's MMAC load layout; the accumulator C
    always receives the MMAC store layout.
    """
    m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
    tiles_per_warp_m = int(self.M // m_warp)
    tiles_per_warp_n = int(self.N // n_warp)
    emitter = MatrixCoreIntrinEmitter(
        a_dtype=self.in_dtype,
        b_dtype=self.in_dtype,
        accum_dtype=self.accum_dtype,
        a_transposed=self.trans_A,
        b_transposed=self.trans_B,
        block_row_warps=m_warp,
        block_col_warps=n_warp,
        warp_row_tiles=tiles_per_warp_m,
        warp_col_tiles=tiles_per_warp_n,
        chunk=self.chunk,
        k_pack=self.k_pack,
    )
    # Resolve the per-operand layouts from the shared/register placement.
    if self.is_gemm_ss():
        a_layout = make_swizzled_layout(self.A)
        b_layout = make_swizzled_layout(self.B)
    elif self.is_gemm_sr():
        a_layout = make_swizzled_layout(self.A)
        b_layout = emitter.make_mmac_load_layout(self.B, matrix="B")
    elif self.is_gemm_rs():
        a_layout = emitter.make_mmac_load_layout(self.A, matrix="A")
        b_layout = make_swizzled_layout(self.B)
    elif self.is_gemm_rr():
        a_layout = emitter.make_mmac_load_layout(self.A, matrix="A")
        b_layout = emitter.make_mmac_load_layout(self.B, matrix="B")
    else:
        raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
    return {
        self.A: a_layout,
        self.B: b_layout,
        self.C: emitter.make_mmac_store_layout(self.C),
    }
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mmac_emitter = MatrixCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
thread_var=thread_var,
k_pack=self.k_pack,
)
in_dtype = self.in_dtype
warp_rows = mmac_emitter.warp_rows
warp_cols = mmac_emitter.warp_cols
local_size_a = mmac_emitter.local_size_a
local_size_b = mmac_emitter.local_size_b
block_K = mmac_emitter.chunk
micro_size_k = mmac_emitter.micro_size_k
# Use region for shared-memory operands if available
# We use region for memory input to support strided gemm
# T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion
A_buf = A_region.buffer
B_buf = B_region.buffer
C_buf = C_region.buffer
clear_accum = self.clear_accum
assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
assert is_full_region(C_region), "Fragment output C must be a full region"
if self.is_gemm_ss():
@T.prim_func
def _gemm_ssr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Matrix Core mmac ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
# Load A into fragment
mmac_emitter.ldmatrix_a(
A_local,
A_region,
ki,
)
# Load B into fragment
mmac_emitter.ldmatrix_b(
B_local,
B_region,
ki,
)
# Perform Matrix Multiplication
mmac_emitter.mmac(A_local, B_local, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_sr():
assert is_full_region(B_region), "Fragment input B must be a full region"
@T.prim_func
def _gemm_srr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Matrix Core mmac ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
# Load A into fragment
mmac_emitter.ldmatrix_a(
A_local,
A_region,
ki,
)
# Perform Matrix Multiplication
mmac_emitter.mmac(A_local, B_buf, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
# alloc_buffers body
# insert into parent block
return _Simplify(_gemm_srr, inline_let=True)
elif self.is_gemm_rs():
assert is_full_region(A_region), "Fragment input A must be a full region"
@T.prim_func
def _gemm_rsr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Matrix Core mmac ops,
accumulating into C_local.
"""
B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
# Load B into fragment
mmac_emitter.ldmatrix_b(
B_local,
B_region,
ki,
)
# Perform Matrix Multiplication
mmac_emitter.mmac(A_buf, B_local, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
elif self.is_gemm_rr():
assert is_full_region(A_region), "Fragment input A must be a full region"
assert is_full_region(B_region), "Fragment input B must be a full region"
@T.prim_func
def _gemm_rsr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Matrix Core mmac ops,
accumulating into C_local.
"""
for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
# Perform Matrix Multiplication
mmac_emitter.mmac(A_buf, B_buf, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
else:
raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
def is_gemm_sr(self) -> bool:
return is_shared(self.A) and is_fragment(self.B)
def is_gemm_rs(self) -> bool:
return is_fragment(self.A) and is_shared(self.B)
def is_gemm_rr(self) -> bool:
return is_fragment(self.A) and is_fragment(self.B)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment