Unverified Commit aef0a6bb authored by Lei Wang, committed by GitHub

[Language] Expose `T.warpgroup_fence_operand` for nvcc code motion (#986)



* remove debug print

* pipeline fix

* use the correct buffer access scope

* rs support

* warp warpgroup_fence_operand

* fix

* fp8 dtype ptx enhance

* mma fix

* TCGEN05 Interface

* tcgen05 support

* rebase

* update

* Enhance TCGEN05 support by adding new intrinsic operations and descriptors. Introduced `ptx_tcgen05_mma_ts` for tensor-memory × shared-memory MMA instructions and `tcgen05_mma_arrive` for signaling barrier completion. Updated existing descriptors and code generation logic to accommodate these changes, ensuring compatibility with new instruction sets. Refactored related allocation functions and improved handling of shared memory descriptors.

* lint fix

* Refactor buffer reference handling in CUDA code generation and update test execution in tilelang. Ensure default annotations for unrolling are set correctly in TIR IR module.

* wgmma fix

---------
Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>
parent c85bb3ac
......@@ -45,7 +45,7 @@ public:
Stmt VisitStmt_(const AllocateNode *op) final {
auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var));
if (!scope.tag.empty() && scope.tag != ".dyn" && scope.tag != ".var" &&
scope.tag != ".barrier" && scope.tag != ".descriptor") {
scope.tag != ".barrier" && scope.tag.find(".descriptor") != 0) {
auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var));
ICHECK(info.defined())
<< "Cannot find memory info of " << scope.to_string();
......
......@@ -88,6 +88,8 @@ private:
Array<Var> new_data_vars;
for (auto buffer : tmem_buffers) {
auto data = buffer->data;
if (var_remap_.count(data))
continue;
auto new_data =
Var(data->name_hint, PointerType(PrimType(tmem_dtype_), "shared"));
var_remap_.Set(data, new_data);
......@@ -107,6 +109,7 @@ private:
buffer->buffer_type);
new_buffers.push_back(new_buffer);
buffer_remap_.Set(buffer, new_buffer);
buffer_data_to_buffer_.Set(new_data, new_buffer);
}
// remove the tmem buffers
......@@ -255,7 +258,15 @@ private:
op->dtype, op->op,
{op->args[0], new_data, op->args[2], op->args[3], op->args[4]});
}
return StmtExprMutator::VisitExpr_(op);
auto expr = StmtExprMutator::VisitExpr_(op);
return expr;
}
PrimExpr VisitExpr_(const VarNode *op) final {
Var var = tvm::ffi::GetRef<Var>(op);
if (var_remap_.count(var)) {
return var_remap_[var];
}
return var;
}
Stmt VisitStmt_(const AttrStmtNode *op) final {
......
......@@ -679,7 +679,7 @@ private:
return !scope.tag.empty() && scope.tag != ".dyn" &&
scope.tag != ".barrier" && scope.tag != ".workspace" &&
scope.tag != ".vtcm" && scope.tag != ".var" &&
scope.tag != ".descriptor";
scope.tag.find(".descriptor") != 0;
}
// Allocate entry of node.
......@@ -865,7 +865,7 @@ private:
ICHECK_NE(e->const_nbits, 0U);
MemoryInfo info;
if (e->scope.tag != ".barrier" && e->scope.tag != ".var" &&
e->scope.tag != ".descriptor") {
e->scope.tag.find(".descriptor") != 0) {
info = GetMemoryInfo(e->scope.to_string());
}
uint64_t total_bits = e->const_nbits;
......
......@@ -209,4 +209,3 @@ def test_shuffle_elect_block_leader():
if __name__ == "__main__":
tilelang.testing.main()
# run_get_lane_id()
......@@ -159,8 +159,8 @@ def test_wgmma_marked_async():
def before():
with T.Kernel(1):
A_shared = T.decl_buffer((1,), "float16", scope="shared")
desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor")
desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor")
desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma")
desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma")
C_local = T.decl_buffer((32,), "float16", scope="local")
A_shared[0] = T.float16(0)
T.warpgroup_arrive()
......@@ -186,5 +186,43 @@ def test_wgmma_marked_async():
assert order.index("tl.fence_proxy_async") < order.index("tl.ptx_wgmma_ss")
def test_wgmma_after_descriptor():
@T.prim_func
def before():
with T.Kernel(1):
desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma")
desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma")
C_local = T.decl_buffer((32,), "float16", scope="local")
T.initialize_wgmma_descriptor(desc_a, T.uint64(0), 2, 1, 32)
T.initialize_wgmma_descriptor(desc_b, T.uint64(0), 2, 1, 32)
T.warpgroup_arrive()
T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16",
"fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data,
T.int32(0), T.bool(True), 1, 1)
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
mod = tl.transform.InjectFenceProxy()(mod)
fence_count = 0
order = []
def visit(node):
nonlocal fence_count
if isinstance(node, tir.Evaluate):
call = node.value
if isinstance(call, tir.Call):
name = getattr(call.op, "name", "")
order.append(name)
if name == "tl.fence_proxy_async":
fence_count += 1
tir.stmt_functor.post_order_visit(mod["main"].body, visit)
assert fence_count >= 1
assert "tl.warpgroup_arrive" in order
assert order.index("tl.fence_proxy_async") < order.index("tl.warpgroup_arrive")
if __name__ == "__main__":
tilelang.testing.main()
......@@ -105,9 +105,15 @@ class TensorCoreIntrinEmitter:
self.local_size_out = (m_dim * n_dim) // warp_size
def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
self.a_dtype_abbrv = self.dtype_abbrv[a_dtype]
self.b_dtype_abbrv = self.dtype_abbrv[b_dtype]
self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]
self.a_dtype_abbrv = self._get_dtype_abbrv(a_dtype)
self.b_dtype_abbrv = self._get_dtype_abbrv(b_dtype)
self.accum_dtype_abbrv = self._get_dtype_abbrv(accum_dtype)
def _get_dtype_abbrv(self, dtype: str) -> str:
try:
return self.dtype_abbrv[dtype]
except KeyError as err:
raise ValueError(f"Unsupported dtype: {dtype}") from err
def _initialize_mma_prefix(self, k_dim: int = 16):
if k_dim == 8:
......
from __future__ import annotations
from enum import IntEnum
import tilelang.language as T
from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter
from tvm import DataType
from tvm.tir import PrimExpr, Buffer, Var
from tilelang import _ffi_api
from tilelang.utils import is_tensor_memory
from tilelang.layout import (
Layout,
make_full_bank_swizzled_layout,
make_half_bank_swizzled_layout,
make_quarter_bank_swizzled_layout,
make_linear_layout,
)
from tvm.runtime import convert
lift = convert
class SwizzleMode(IntEnum):
# SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1
NONE = 0
SWIZZLE_128B = 2
SWIZZLE_64B = 4
SWIZZLE_32B = 6
def is_none(self) -> bool:
return self == SwizzleMode.NONE
def is_swizzle_32b(self) -> bool:
return self == SwizzleMode.SWIZZLE_32B
def is_swizzle_64b(self) -> bool:
return self == SwizzleMode.SWIZZLE_64B
def is_swizzle_128b(self) -> bool:
return self == SwizzleMode.SWIZZLE_128B
def swizzle_byte_size(self) -> int:
if self.is_swizzle_32b():
return 32
elif self.is_swizzle_64b():
return 64
elif self.is_swizzle_128b():
return 128
else:
return 1
def swizzle_atom_size(self) -> int:
if self.is_swizzle_32b():
return 32 // 16
elif self.is_swizzle_64b():
return 64 // 16
elif self.is_swizzle_128b():
return 128 // 16
else:
return 1
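# Illustrative sketch (not part of this diff): the SwizzleMode encoding and the sizes
# derived from it, assuming only the class defined above.
for mode in (SwizzleMode.NONE, SwizzleMode.SWIZZLE_32B, SwizzleMode.SWIZZLE_64B,
             SwizzleMode.SWIZZLE_128B):
    # prints: NONE 0 1 1 / SWIZZLE_32B 6 32 2 / SWIZZLE_64B 4 64 4 / SWIZZLE_128B 2 128 8
    print(mode.name, int(mode), mode.swizzle_byte_size(), mode.swizzle_atom_size())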
# derive from MMAIntrinEmitter as some layouts are the same
class TensorCoreIntrinEmitter(MMAIntrinEmitter):
"""
To eliminate Python syntax within TIR Macro.
"""
# should be rewritten to support dynamic k_dim
tcgen05_prefix: str
a_shared_layout: Layout = None
b_shared_layout: Layout = None
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
block_col_warps: int = 2,
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
is_m_first: bool = False,
thread_var: Var | None = None,
):
super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps,
block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k,
num_elems_per_byte, is_m_first, thread_var)
def _assign_a_shared_layout(self, layout: Layout):
self.a_shared_layout = layout
return self
def _assign_b_shared_layout(self, layout: Layout):
self.b_shared_layout = layout
return self
def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
warp_row_tiles = self.warp_row_tiles
warp_col_tiles = self.warp_col_tiles
assert warp_row_tiles >= 16, f"warp_row_tiles must be at least 16, got {warp_row_tiles}"
assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}"
assert warp_col_tiles >= 8, f"warp_col_tiles must be at least 8, got {warp_col_tiles}"
assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}"
# four warps per block
self.warp_rows = warp_row_tiles // m_dim
if warp_col_tiles % 16 == 0:
self.n_dim = 16
self.micro_size_y = 16
self.warp_cols = warp_col_tiles // 16
else:
# must be divisible by 8
self.n_dim = 8
self.micro_size_y = 8
self.warp_cols = warp_col_tiles // 8
self.micro_size_x = m_dim
self.micro_size_k = k_dim
def _determinate_swizzle_mode(self, buffer: Buffer, layout: Layout) -> SwizzleMode:
# same behavior as src/layout/gemm_layouts.cc::makeGemmABLayoutHopper
if layout is None or layout.is_equal(make_linear_layout(buffer)):
return SwizzleMode.NONE
elif layout.is_equal(make_quarter_bank_swizzled_layout(buffer)):
return SwizzleMode.SWIZZLE_32B
elif layout.is_equal(make_half_bank_swizzled_layout(buffer)):
return SwizzleMode.SWIZZLE_64B
elif layout.is_equal(make_full_bank_swizzled_layout(buffer)):
return SwizzleMode.SWIZZLE_128B
else:
raise ValueError(f"Unsupported swizzle mode: {layout}")
def tcgen05mma(self,
A_buf: Buffer,
B_buf: Buffer,
C_local_buf: Buffer,
mbar,
clear_accum: PrimExpr = False):
if is_tensor_memory(A_buf):
return self.tcgen05mma_rs(A_buf, B_buf, C_local_buf, clear_accum)
accum_dtype = self.accum_dtype
m_dim = self.block_row_warps * self.warp_row_tiles
micro_size_k = self.micro_size_k
k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
scale_in_a = 1
scale_in_b = 1
assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
a_is_k_major = not self.a_transposed
b_is_k_major = self.b_transposed
a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout)
b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
elems_in_bits = DataType(self.a_dtype).bits
elems_in_bytes = elems_in_bits // 8
a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
accum_dtype_in_bits = DataType(accum_dtype).bits
meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim)
if len(meta) != 3:
raise ValueError(
f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, "
f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
atom_m, atom_n, atom_k = (int(x) for x in meta)
enable_ws = atom_m != 128
# by default, we utilize non-swizzle layout offset
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
elems_in_bytes)
a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 *
elems_in_bytes)
if not a_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
if a_is_k_major:
a_leading_byte_offset = 16
a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size()
else:
# MN Major
# LBO represents the distance between two atoms along the M dimension
# SBO represents the distance between two atoms along the K dimension
a_m_axis_atoms = m_dim // a_swizzle_atom_elems
if a_m_axis_atoms <= 1:
a_leading_byte_offset = 0
else:
a_leading_byte_offset = k_dim * a_swizzle_mode.swizzle_byte_size()
if a_m_axis_atoms <= 1:
a_stride_byte_offset = 8 * elems_in_bytes * m_dim
else:
a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
elems_in_bytes)
b_stride_byte_offset = (8 * k_dim *
elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else
(8 * 8 * elems_in_bytes))
if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
if b_is_k_major:
b_leading_byte_offset = 16
b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
else:
# MN Major, K * N
# LBO represents the distance between two atoms along the N dimension
# SBO represents the distance between two atoms along the K dimension
b_n_axis_atoms = n_dim // b_swizzle_atom_elems
if b_n_axis_atoms <= 1:
b_leading_byte_offset = 0
else:
b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
if b_n_axis_atoms <= 1:
b_stride_byte_offset = 8 * elems_in_bytes * n_dim
else:
b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
# for example, if [n, k] where k is 128, we should split it into 2 atoms
# where max specially handles the case when n_dim is 8.
ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1)
bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
instr_desc = self.get_tcgen5_instr_desc(
atom_m,
atom_n,
atom_k,
a_is_k_major,
b_is_k_major,
scale_in_a,
scale_in_b,
)
# Allocate an instruction descriptor wrapper and initialize it
a_dtype_abbrv = self.a_dtype_abbrv
mask_zero = T.Cast("int32", 0)
mask0 = mask1 = mask2 = mask3 = mask_zero
@T.macro
def _warp_mma(A_buf, B_buf, C_local_buf, mbar):
# Allocate SMEM descriptors for A and B
desc_a = T.alloc_tcgen05_smem_desc()
desc_b = T.alloc_tcgen05_smem_desc()
A_ptr = A_buf.access_ptr("r")
B_ptr = B_buf.access_ptr("r")
T.initialize_tcgen05_descriptor(
desc_a,
A_ptr,
int(a_leading_byte_offset >> 4),
int(a_stride_byte_offset >> 4),
0,
False,
int(a_swizzle_mode),
)
T.initialize_tcgen05_descriptor(
desc_b,
B_ptr,
int(b_leading_byte_offset >> 4),
int(b_stride_byte_offset >> 4),
0,
False,
int(b_swizzle_mode),
)
for ki in T.serial(0, (k_dim // micro_size_k)):
scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1))
for i in T.serial(m_dim // atom_m):
A_elem_offset = (
ki % ak_atom_size
) * micro_size_k + i * atom_m * a_swizzle_atom_elems + (
ki // ak_atom_size
) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k
B_elem_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
ki % bk_atom_size
) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k
A_byte_offset = A_elem_offset * elems_in_bytes
B_byte_offset = B_elem_offset * elems_in_bytes
C_offset = i * atom_n * accum_dtype_in_bits // 32 # 32 bits per tmem bank
T.ptx_tcgen05_mma_ss(
a_dtype_abbrv,
desc_a.data,
A_byte_offset,
desc_b.data,
B_byte_offset,
C_local_buf.data,
C_offset,
instr_desc,
scale_out,
mask0,
mask1,
mask2,
mask3,
enable_ws,
)
T.tcgen05_mma_arrive(mbar)
return _warp_mma(A_buf, B_buf, C_local_buf, mbar)
def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment:
raise NotImplementedError
def make_mma_store_layout(self, tmem_buf: Buffer) -> Layout:
"""
Create the TCGEN5 tensor-memory layout used to store MMA accumulators.
Parameters
----------
tmem_buf : tir.Buffer
The local buffer representing tensormemory of a mma's output
Returns
-------
Layout
Layout object describing how logical (i, j) coordinates map to the
swizzled tensor-memory offsets required by TCGEN5MMA.
Raises
------
AssertionError
If `tmem_buf` is not detected to be a tensor-memory buffer.
"""
assert is_tensor_memory(tmem_buf), "tmem_buf must reside in tensor memory (shared.tmem)"
if len(tmem_buf.shape) != 2:
raise ValueError(
f"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape {tmem_buf.shape}")
m = int(tmem_buf.shape[0])
n = int(tmem_buf.shape[1])
k = int(self.chunk)
meta = self.get_tcgen5_mma_meta(m, n, k)
if len(meta) != 3:
raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, "
f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}")
atom_m, atom_n, _ = (int(x) for x in meta)
if m % atom_m != 0 or n % atom_n != 0:
raise ValueError(
f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})"
)
def forward(i: PrimExpr, j: PrimExpr):
atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m)
ai = i % atom_m
aj = j % atom_n
if atom_m == 128:
# Layout D
return [
ai,
aj + atom_idx * atom_n,
]
if atom_m == 64:
# Layout E (.ws variant)
half_atom_n = atom_n // 2
return [
(ai // 32) * 32 + ai % 32 + (aj // half_atom_n) * 64,
(aj % half_atom_n) + atom_idx * half_atom_n,
]
if atom_m == 32:
# Layout G
quarter_atom_n = atom_n // 4
return [
ai % 32 + (aj // quarter_atom_n) * 32,
(aj % quarter_atom_n) + atom_idx * quarter_atom_n,
]
raise ValueError(f"Unsupported TCGEN5 atom_m={atom_m}")
return Layout([m, n], forward)
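# Minimal self-contained check (not part of this diff) of the atom_m == 128 "Layout D"
# branch above: (i, j) maps to row i % atom_m and column (j % atom_n) + atom_idx * atom_n.
def _layout_d(i, j, m=128, n=256, atom_m=128, atom_n=128):
    atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m)
    return i % atom_m, (j % atom_n) + atom_idx * atom_n

assert _layout_d(0, 0) == (0, 0)
assert _layout_d(5, 130) == (5, 130)  # second atom along n keeps the logical column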
def get_tcgen5_mma_meta(self, m: int, n: int, k: int):
return _ffi_api.get_tcgen5_mma_meta(
int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype))
def get_tcgen5_instr_desc(self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool,
b_is_k_major: bool, scale_in_a: int, scale_in_b: int) -> PrimExpr:
desc = _ffi_api.get_tcgen5_instr_desc(
atom_m,
atom_n,
atom_k,
DataType(self.a_dtype),
DataType(self.accum_dtype),
a_is_k_major,
b_is_k_major,
scale_in_a,
scale_in_b,
)
return lift(desc)
......@@ -164,7 +164,6 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
micro_size_k = self.micro_size_k
k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
wgmma_prefix = self.wgmma_prefix
scale_out = not clear_accum
scale_in_a = 1
scale_in_b = 1
......@@ -182,6 +181,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
accum_bits = DataType(accum_dtype).bits
accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32
# by default, we utilize non-swizzle layout offset
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
......@@ -243,15 +244,18 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
@T.macro
def _warp_mma(A_buf, B_buf, C_local_buf):
# TODO(lei): inject warpgroup_fence_operand for C_local_buf
desc_a = T.alloc_descriptor()
desc_b = T.alloc_descriptor()
T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4))
T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
desc_a = T.alloc_wgmma_desc()
desc_b = T.alloc_wgmma_desc()
T.initialize_wgmma_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
int(a_leading_byte_offset >> 4),
int(a_stride_byte_offset >> 4))
T.initialize_wgmma_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
int(b_leading_byte_offset >> 4),
int(b_stride_byte_offset >> 4))
T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs)
T.warpgroup_arrive()
for ki in T.serial(0, (k_dim // micro_size_k)):
scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1))
for i in T.serial(m_dim // 64):
A_offset = (ki % ak_atom_size) * micro_size_k + i * 64 * a_swizzle_atom_elems + (
ki // ak_atom_size
......@@ -267,6 +271,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
scale_out, scale_in_a, scale_in_b)
T.warpgroup_commit_batch()
T.warpgroup_wait(0)
T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs)
return _warp_mma(A_buf, B_buf, C_local_buf)
......@@ -286,60 +291,70 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
micro_size_k = self.micro_size_k
k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
wgmma_prefix = self.wgmma_prefix
scale_out = not clear_accum
scale_in_a = 1
scale_in_b = 1
assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}"
elems_in_bytes = DataType(self.a_dtype).bits // 8
a_bits = DataType(self.a_dtype).bits
accum_bits = DataType(accum_dtype).bits
a_regs = ((warp_rows * local_size_a * (k_dim // micro_size_k)) * a_bits + 31) // 32
accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32
b_is_k_major = self.b_transposed
b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
elems_in_bytes)
b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
elems_in_bytes)
b_stride_byte_offset = (8 * k_dim *
elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else
(8 * 8 * elems_in_bytes))
if not b_swizzle_mode.is_none():
# swizzle mode doesn't require LBO/SBO to be 1
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
if b_is_k_major:
b_leading_byte_offset = 16
b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
else:
# MN Major
# LBO represents the distance between two atoms along the N dimension
# SBO represents the distance between two atoms along the K dimension
b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
b_n_axis_atoms = n_dim // b_swizzle_atom_elems
if b_n_axis_atoms <= 1:
b_leading_byte_offset = 0
else:
b_leading_byte_offset = 8 * b_swizzle_mode.swizzle_atom_size() * (
b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
if b_n_axis_atoms <= 1:
b_stride_byte_offset = 8 * elems_in_bytes * n_dim
else:
b_stride_byte_offset = 8 * elems_in_bytes * (
b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
@T.macro
def _warp_mma(A_buf, B_buf, C_local_buf):
desc_b = T.alloc_descriptor()
T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode,
int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
desc_b = T.alloc_wgmma_desc()
T.initialize_wgmma_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
int(b_leading_byte_offset >> 4),
int(b_stride_byte_offset >> 4))
T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs)
T.warpgroup_arrive()
for ki in T.serial(0, (k_dim // micro_size_k)):
scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1))
for i in T.serial(m_dim // 64):
k_dim_offset = ki * micro_size_k
A_offset = ki * warp_rows * local_size_a + i * local_size_a
B_offset = k_dim_offset if b_is_k_major else k_dim_offset * B_buf.shape[-1]
B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
ki % bk_atom_size
) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k
C_offset = i * warp_cols * local_size_out # 4 warps as an unit
T.ptx_wgmma_rs(
accum_dtype,
wgmma_prefix,
self.a_transposed,
not self.b_transposed,
self.b_transposed,
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
......@@ -353,6 +368,10 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
scale_in_a,
scale_in_b,
)
T.warpgroup_commit_batch()
T.warpgroup_wait(0)
T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs)
T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
return _warp_mma(A_buf, B_buf, C_local_buf)
......
......@@ -257,6 +257,12 @@ class TLCUDASourceWrapper:
def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str:
return pythonic_expr(expr, self._TYPE_MAP)
def _lookup_type(self, dtype: str | Any) -> str:
key = dtype if isinstance(dtype, str) else str(dtype)
result = self._TYPE_MAP.get(key)
assert result is not None, f"Unsupported dtype {dtype}"
return result
def is_tma_descriptor_arg(self, arg_name: str) -> bool:
return arg_name in self.prim_func.buffer_map
......@@ -274,10 +280,10 @@ class TLCUDASourceWrapper:
buffer = self.prim_func.buffer_map[param]
function_args.append({
"name": buffer.data.name,
"type": self._TYPE_MAP[buffer.dtype] + "* __restrict__",
"type": self._lookup_type(buffer.dtype) + "* __restrict__",
})
elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]})
function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
else:
raise ValueError(
f"Parameter {param} is not in the buffer map of the primary function.")
......@@ -717,6 +723,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
"float16": "ctypes.c_uint16",
"bfloat16": "ctypes.c_uint16",
"float8_e4m3": "ctypes.c_uint8",
"float8_e4m3fn": "ctypes.c_uint8",
"float8_e5m2": "ctypes.c_uint8",
"float64": "ctypes.c_double",
"int64": "ctypes.c_int64",
......@@ -753,7 +760,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
"type": "ctypes.c_void_p",
})
elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]})
function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
else:
raise ValueError(
f"Parameter {param} is not in the buffer map of the primary function.")
......@@ -923,6 +930,7 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
"float16": "half_t",
"bfloat16": "bfloat16_t",
"float8_e4m3": "fp8_e4_t",
"float8_e4m3fn": "fp8_e4_t",
"float8_e5m2": "fp8_e5_t",
"float8_e4m3fnuz": "fp8_e4_t",
"e4m3fnuz_float8": "fp8_e4_t",
......@@ -1014,6 +1022,12 @@ class TLCPUSourceWrapper:
self.libpath: str | None = None
self.lib_code: str | None = self.update_lib_code(source)
def _lookup_type(self, dtype: str | Any) -> str:
key = dtype if isinstance(dtype, str) else str(dtype)
result = self._TYPE_MAP.get(key)
assert result is not None, f"Unsupported dtype {dtype}"
return result
def create_call_func(self, code, function_informations):
# Extract the set of dynamic symbolic names used in the primary function
dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
......@@ -1025,10 +1039,10 @@ class TLCPUSourceWrapper:
buffer = self.prim_func.buffer_map[param]
function_args.append({
"name": buffer.name,
"type": self._TYPE_MAP[buffer.dtype] + "*",
"type": self._lookup_type(buffer.dtype) + "*",
})
elif isinstance(param, tvm.tir.Var):
function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]})
function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
else:
raise ValueError(
f"Parameter {param} is not in the buffer map of the primary function.")
......
......@@ -46,6 +46,9 @@ from .allocate import (
alloc_tmem, # noqa: F401
alloc_reducer, # noqa: F401
alloc_descriptor, # noqa: F401
alloc_wgmma_desc, # noqa: F401
alloc_tcgen05_smem_desc, # noqa: F401
alloc_tcgen05_instr_desc, # noqa: F401
)
from .copy import copy, c2d_im2col # noqa: F401
from .gemm import GemmWarpPolicy, gemm, gemm_v2 # noqa: F401
......
......@@ -15,7 +15,7 @@ with the appropriate memory scope.
"""
from __future__ import annotations
from typing import overload
from typing import overload, Literal
from tilelang import tvm as tvm
from tvm.script import tir as T
from tvm.tir import PrimExpr
......@@ -218,10 +218,40 @@ def alloc_reducer(shape, dtype, op="sum", replication=None):
return reducer
def alloc_descriptor(dtype="uint64", scope="local.descriptor"):
"""Allocate a descriptor buffer for wgmma and utcmma.
DescKind = Literal["wgmma", "tcgen05_smem", "tcgen05_instr"]
def alloc_descriptor(
kind: DescKind = "wgmma",
dtype: str = "uint64",
):
"""Allocate a descriptor buffer for WGMMA and TCGEN5.MMA.
Args:
kind: The descriptor kind, one of "wgmma", "tcgen05_smem", or "tcgen05_instr".
dtype: Storage dtype of the descriptor buffer ("uint64" for shared-memory descriptors, "uint32" for instruction descriptors).
Returns:
T.Buffer: A TVM buffer object allocated as a descriptor
"""
scope = "local.descriptor." + kind
# Note: this TVM builder signature does not accept a buffer name, so the descriptor
# buffer keeps its default name.
return T.alloc_buffer([1], dtype, scope=scope)
def alloc_wgmma_desc(dtype: str = "uint64"):
return alloc_descriptor("wgmma", dtype=dtype)
def alloc_tcgen05_smem_desc(dtype: str = "uint64"):
return alloc_descriptor("tcgen05_smem", dtype=dtype)
def alloc_tcgen05_instruction_desc(dtype: str = "uint32"):
return alloc_descriptor("tcgen05_instr", dtype=dtype)
# Alias: short name consistent with imports
def alloc_tcgen05_instr_desc(dtype: str = "uint32"):
return alloc_tcgen05_instruction_desc(dtype)
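# Hedged usage sketch (not part of this diff): the new allocators inside a kernel or
# macro body, exposed as T.alloc_* via the imports added in language/__init__.py.
desc_a = T.alloc_wgmma_desc()              # uint64 buffer, scope "local.descriptor.wgmma"
smem_desc = T.alloc_tcgen05_smem_desc()    # uint64 buffer, scope "local.descriptor.tcgen05_smem"
instr_desc = T.alloc_tcgen05_instr_desc()  # uint32 buffer, scope "local.descriptor.tcgen05_instr"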
......@@ -1894,6 +1894,8 @@ ptx_mma = _dtype_forward(_tir_op.ptx_mma)
ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
ptx_wgmma_ss = _dtype_forward(_tir_op.ptx_wgmma_ss)
ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs)
ptx_tcgen05_mma_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_ss)
ptx_tcgen05_mma_ts = _dtype_forward(_tir_op.ptx_tcgen05_mma_ts)
ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk)
......@@ -2145,6 +2147,7 @@ __all__ = [
"ptx_mma_sp",
"ptx_wgmma_ss",
"ptx_wgmma_rs",
"ptx_tcgen05_mma_ss",
"ptx_ldmatrix",
"ptx_cp_async",
"ptx_cp_async_bulk",
......
......@@ -5,7 +5,8 @@ from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier, evaluate
from tilelang.language.kernel import get_thread_bindings, get_block_extents
from tilelang.utils.target import check_hip_availability
from tvm import tir
from tvm import DataType, tir
from tvm.runtime import convert
from typing import Any
from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad
......@@ -429,6 +430,66 @@ def shuffle_elect(thread_extent: int) -> PrimExpr:
return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent)
def warpgroup_fence_operand(buffer_or_ptr: Buffer | PrimExpr,
offset: int | PrimExpr = 0,
num_regs: int | PrimExpr | None = None,
dtype: str | None = None):
"""Insert a warpgroup fence for the destination accumulator registers.
This prevents NVCC from sinking uses of accumulator fragments past the corresponding
WGMMA operations by issuing an empty inline assembly barrier on every register.
Args:
buffer_or_ptr: Buffer | PrimExpr
Either a buffer representing the accumulator fragment or a pointer expression.
offset: int | PrimExpr
Element offset from the start of the accumulator fragment.
num_regs: int | PrimExpr | None
Number of 32-bit registers to fence. If None and a Buffer is provided, it will be
derived from the buffer shape and dtype.
dtype: str | None
Data type string of the accumulator elements. Required when passing a pointer.
Returns:
tir.Call: A handle to the warpgroup fence operation.
"""
if isinstance(buffer_or_ptr, BufferLoad):
raise TypeError("Expected a buffer handle or pointer expression, got BufferLoad.")
if isinstance(buffer_or_ptr, Buffer):
data_ptr = buffer_or_ptr.data
inferred_dtype = buffer_or_ptr.dtype
if dtype is not None and dtype != inferred_dtype:
raise ValueError(f"dtype mismatch: provided {dtype}, buffer uses {inferred_dtype}.")
dtype = inferred_dtype
if num_regs is None:
total_elems = 1
for dim in buffer_or_ptr.shape:
if isinstance(dim, tir.IntImm):
total_elems *= int(dim)
else:
raise ValueError(
"warpgroup_fence_operand requires num_regs when buffer shape is symbolic.")
bits_per_elem = DataType(dtype).bits
num_regs = (total_elems * bits_per_elem + 31) // 32
else:
data_ptr = buffer_or_ptr
if dtype is None:
raise ValueError("dtype must be provided when passing a pointer expression.")
if num_regs is None:
raise ValueError("num_regs must be provided when passing a pointer expression.")
return evaluate(
tir.call_intrin(
"handle",
tir.op.Op.get("tl.warpgroup_fence_operand"),
dtype,
data_ptr,
convert(offset),
convert(num_regs),
))
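# Hedged usage sketch (not part of this diff), assuming a float32 accumulator fragment
# C_local with 32 elements declared earlier in the kernel; num_regs follows the
# derivation above: (32 elements * 32 bits + 31) // 32 = 32 registers.
T.warpgroup_fence_operand(C_local)                                      # buffer form, num_regs derived
T.warpgroup_fence_operand(C_local.data, dtype="float32", num_regs=32)   # pointer form, both required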
def wait_wgmma(id: int):
"""Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
......@@ -537,38 +598,68 @@ def sync_grid():
return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid"))
def initialize_descriptor(descriptor: Buffer,
start_address: PrimExpr,
layout_type_: int = 0,
leading_byte_offset: int = 0,
stride_byte_offset: int = 0) -> PrimExpr:
"""
Initialize a memory descriptor with the given parameters.
Parameters:
descriptor (Buffer): The memory descriptor to initialize.
start_address (PrimExpr): The starting address of the memory region.
layout_type_ (int, optional): Layout type identifier. Defaults to 0.
leading_byte_offset (int, optional): Leading byte offset. Defaults to 0.
stride_byte_offset (int, optional): Stride byte offset. Defaults to 0.
Returns:
PrimExpr: A handle representing the initialized descriptor.
"""
def initialize_wgmma_descriptor(
descriptor: Buffer,
start_address: PrimExpr,
layout_type_: int = 0,
leading_byte_offset: int = 0,
stride_byte_offset: int = 0,
) -> PrimExpr:
"""Initialize a WGMMA/UTCMMA shared-memory descriptor."""
if not isinstance(descriptor, (BufferLoad, Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
raise ValueError("Descriptor must be a 1D buffer of size 1.")
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(
descriptor, [0])
return evaluate(
tir.call_intrin(
"handle",
tir.op.Op.get("tl.initialize_wgmma_descriptor"),
descriptor,
start_address,
layout_type_,
int(leading_byte_offset),
int(stride_byte_offset),
))
def initialize_tcgen05_descriptor(
descriptor: Buffer,
start_address: PrimExpr,
leading_byte_offset: int,
stride_byte_offset: int,
base_offset: int = 0,
leading_is_absolute: bool = False,
swizzle_mode: int = 0,
) -> PrimExpr:
"""Initialize a TCGEN05 shared-memory descriptor."""
if not isinstance(descriptor, (BufferLoad, Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
raise ValueError("Descriptor must be a 1D buffer of size 1.")
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(
descriptor, [0])
return evaluate(
tir.call_intrin("handle", tir.op.Op.get("tl.initialize_descriptor"), descriptor,
start_address, layout_type_, int(leading_byte_offset),
int(stride_byte_offset)))
tir.call_intrin(
"handle",
tir.op.Op.get("tl.initialize_tcgen05_descriptor"),
descriptor,
start_address,
int(leading_byte_offset),
int(stride_byte_offset),
int(base_offset),
tir.IntImm("int32", 1 if leading_is_absolute else 0),
int(swizzle_mode),
))
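# Hedged sketch (not part of this diff), mirroring the emitter usage earlier in this
# diff: a K-major, non-swizzled fp16 tile with k_dim = 32, so LBO = 8 * 8 * 2 = 128 B and
# SBO = 8 * 32 * 2 = 512 B, both encoded in 16-byte units; A_shared is an assumed buffer.
desc_a = T.alloc_tcgen05_smem_desc()
T.initialize_tcgen05_descriptor(
    desc_a,
    A_shared.access_ptr("r"),
    128 >> 4,  # leading_byte_offset in 16-byte units
    512 >> 4,  # stride_byte_offset in 16-byte units
    0,         # base_offset
    False,     # leading_is_absolute
    0,         # swizzle_mode: SWIZZLE_NONE
)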
def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimExpr:
......@@ -606,3 +697,14 @@ def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call):
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
def tcgen05_mma_arrive(mbar_ptr):
"""Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.
Parameters
----------
mbar_ptr : PrimExpr
Pointer to the mbarrier object in shared memory (e.g., Barrier*).
"""
return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr)
......@@ -222,6 +222,7 @@ def gemm_v2(
clear_accum: bool = False,
k_pack: int = 1,
wg_wait: int = 0,
mbar: tir.Buffer | None = None,
):
"""Perform a General Matrix Multiplication (GEMM) operation.
......@@ -238,6 +239,7 @@ def gemm_v2(
clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
wg_wait (int, optional): Warp group wait count. Defaults to 0.
mbar (tir.Buffer, optional): mbarrier for TCGEN5MMA synchronization
Returns:
tir.Call: A handle to the GEMM operation
......@@ -262,6 +264,7 @@ def gemm_v2(
A = legalize_arguments(A)
B = legalize_arguments(B)
C = legalize_arguments(C)
mbar = legalize_arguments(mbar) if mbar is not None else None
def retrieve_shape(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer):
......@@ -404,6 +407,8 @@ def gemm_v2(
Aptr = retrieve_ptr(A, "r")
Bptr = retrieve_ptr(B, "r")
Cptr = retrieve_ptr(C, "rw")
mbarptr = retrieve_ptr(mbar, "rw") if mbar is not None else tir.const(0, "uint32")
C_coords = [r.min for r in C.region] if isinstance(C, tir.BufferRegion) else [0, 0]
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.gemm_py"),
......@@ -423,4 +428,7 @@ def gemm_v2(
offset_b,
k_pack,
wg_wait,
mbarptr,
C_coords[0],
C_coords[1],
)
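# Hedged usage sketch (not part of this diff): the new mbar argument threads an mbarrier
# through to the TCGEN5MMA lowering; A_shared, B_shared, C_tmem and mbar are assumed
# buffers declared earlier in the kernel.
T.gemm_v2(A_shared, B_shared, C_tmem, clear_accum=True, mbar=mbar)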
......@@ -104,6 +104,13 @@ def unroll(start: PrimExpr,
res : frame.ForFrame
The ForFrame.
"""
# Ensure annotations default to {"pragma_unroll_explicit": False}
if annotations is None:
annotations = {"pragma_unroll_explicit": False}
else:
# Add "pragma_unroll_explicit": False if not already present
annotations = dict(annotations)
annotations.setdefault("pragma_unroll_explicit", False)
return _ir.unroll(start=start, stop=stop, annotations=annotations)
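# Hedged sketch (not part of this diff) of the new default, inside a kernel body:
for i in T.unroll(0, 8):
    T.evaluate(0)  # annotations default to {"pragma_unroll_explicit": False}
for i in T.unroll(0, 8, annotations={"pragma_unroll_explicit": True}):
    T.evaluate(0)  # an explicit annotation is kept, since setdefault does not overwrite it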
......@@ -294,6 +301,8 @@ ptx_mma = _dtype_forward(_tir_op.ptx_mma)
ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
ptx_wgmma_ss = _dtype_forward(_tir_op.ptx_wgmma_ss)
ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs)
ptx_tcgen05_mma_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_ss)
ptx_tcgen05_mma_ts = _dtype_forward(_tir_op.ptx_tcgen05_mma_ts)
ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk)
......
......@@ -1107,7 +1107,6 @@ def ptx_wgmma_ss(
def ptx_wgmma_rs(
dtype,
wgmma_prefix,
a_is_k_major,
b_is_k_major,
a_dtype_abbrv,
b_dtype_abbrv,
......@@ -1127,7 +1126,6 @@ def ptx_wgmma_rs(
dtype,
_tvm_op.Op.get("tl.ptx_wgmma_rs"),
wgmma_prefix,
a_is_k_major,
b_is_k_major,
a_dtype_abbrv,
b_dtype_abbrv,
......@@ -1144,6 +1142,115 @@ def ptx_wgmma_rs(
)
def ptx_tcgen05_mma_ss(
kind_dtype,
desc_a,
A_offset,
desc_b,
B_offset,
C_ptr,
C_offset,
desc_val,
scale_out,
mask0,
mask1,
mask2,
mask3,
enable_ws=False,
ws=None,
warp_specialized=None,
variant=None,
):
"""TVM intrinsic for tcgen05.mma shared-memory × shared-memory instructions.
Expects 13 or 14 positional arguments:
(kind_dtype, desc_a, A_offset, desc_b, B_offset, C_ptr, C_offset,
desc_val, scale_out, mask0, mask1, mask2, mask3[, enable_ws]).
Aliases: you can also pass `ws` or `warp_specialized` (booleans) instead of `enable_ws`.
Alternatively, use `variant="ws"` (or "default").
- kind_dtype: instruction kind selector (e.g., "float16" for kind::f16,
"tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4).
"""
# Aliases precedence: if either `ws` or `warp_specialized` is provided, they override enable_ws
if ws is not None:
enable_ws = bool(ws)
if warp_specialized is not None:
enable_ws = bool(warp_specialized)
if variant is not None:
if isinstance(variant, str):
v = variant.lower()
if v in ("ws", "warp_specialized", "warp-specialized"):
enable_ws = True
elif v in ("default", "std", "ss"):
enable_ws = False
else:
raise ValueError(f"ptx_tcgen05_mma_ss: unknown variant: {variant}")
else:
# Treat non-string as truthy flag
enable_ws = bool(variant)
return call_intrin(
"handle",
_tvm_op.Op.get("tl.ptx_tcgen05_mma_ss"),
kind_dtype,
desc_a,
A_offset,
desc_b,
B_offset,
C_ptr,
C_offset,
desc_val,
scale_out,
mask0,
mask1,
mask2,
mask3,
enable_ws,
)
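# Illustrative only (not part of this diff): the alias handling above makes these three
# calls equivalent; args stands for the 13 positional operands listed in the docstring.
ptx_tcgen05_mma_ss(*args, enable_ws=True)
ptx_tcgen05_mma_ss(*args, ws=True)
ptx_tcgen05_mma_ss(*args, variant="ws")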
def ptx_tcgen05_mma_ts(
kind_dtype,
A_ptr,
A_offset,
desc_b,
B_offset,
C_ptr,
C_offset,
desc_val,
scale_out,
mask0,
mask1,
mask2,
mask3,
):
"""TVM intrinsic for tcgen05.mma tensor-memory × shared-memory instructions.
Expects 13 positional arguments:
(kind_dtype, A_ptr, A_offset, desc_b, B_offset, C_ptr, C_offset,
desc_val, scale_out, mask0, mask1, mask2, mask3).
- kind_dtype: instruction kind selector (e.g., "float16" for kind::f16,
"tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4).
"""
return call_intrin(
"handle",
_tvm_op.Op.get("tl.ptx_tcgen05_mma_ts"),
kind_dtype,
A_ptr,
A_offset,
desc_b,
B_offset,
C_ptr,
C_offset,
desc_val,
scale_out,
mask0,
mask1,
mask2,
mask3,
)
def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
"""TVM intrinsic for storing the result of PTX MMA into a destination pointer
......
......@@ -6,6 +6,7 @@ from .fragment import Fragment # noqa: F401
from .swizzle import (
make_swizzled_layout, # noqa: F401
make_wgmma_swizzled_layout, # noqa: F401
make_tcgen05mma_swizzled_layout, # noqa: F401
make_full_bank_swizzled_layout, # noqa: F401
make_half_bank_swizzled_layout, # noqa: F401
make_quarter_bank_swizzled_layout, # noqa: F401
......
......@@ -34,6 +34,22 @@ def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
)
# for TCGEN05MMA Intrinsics
def make_tcgen05mma_swizzled_layout(buffer: tvm.tir.Buffer,
continuity: int = None,
k_major: bool = True):
assert len(buffer.shape) == 2
if continuity is None:
continuity = int(buffer.shape[1])
return _ffi_api.make_tcgen05mma_swizzled_layout(
int(buffer.shape[0]),
int(buffer.shape[1]),
continuity,
int(tvm.DataType(buffer.dtype).bits),
k_major,
)
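# Hedged sketch (not part of this diff): A_shared is an assumed 2-D fp16 shared-memory
# buffer; continuity defaults to its innermost extent when not given.
layout = make_tcgen05mma_swizzled_layout(A_shared, k_major=True)
layout_n_major = make_tcgen05mma_swizzled_layout(A_shared, continuity=64, k_major=False)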
# swizzle 128B
# args: buffer or (stride, continuous, element_size)
def make_full_bank_swizzled_layout(*args):
......
......@@ -8,6 +8,7 @@ import tvm_ffi
from tilelang.ir import GemmWarpPolicy
from .gemm_mma import GemmMMA
from .gemm_wgmma import GemmWGMMA
from .gemm_tcgen05 import GemmTCGEN5
from .gemm_mfma import GemmMFMA
from tilelang import _ffi_api
......@@ -45,6 +46,9 @@ class GemmInst(IntEnum):
def is_mfma(self) -> bool:
return self == GemmInst.MFMA
def __repr__(self) -> str:
return self.name
@tvm_ffi.register_object("tl.GemmPy")
class GemmPy(Node, Scriptable):
......@@ -119,6 +123,8 @@ class GemmPy(Node, Scriptable):
return GemmMMA
elif gemm_inst.is_wgmma():
return GemmWGMMA
elif gemm_inst.is_tcgen5mma():
return GemmTCGEN5
elif gemm_inst.is_mfma():
return GemmMFMA
elif gemm_inst.is_tcgen5mma():
......
......@@ -118,3 +118,15 @@ class GemmBase:
@property
def policy(self) -> GemmWarpPolicy:
return self.gemm_node.policy
@property
def mbarptr(self) -> PrimExpr:
return getattr(self.gemm_node, "mbarptr", tvm.tir.const(0, "uint32"))
@property
def C_coords(self):
coords = getattr(self.gemm_node, "C_coords", None)
if coords is None or len(coords) == 0:
zero = tvm.tir.const(0, "int32")
return [zero, zero]
return [coords[i] for i in range(len(coords))]