Unverified Commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
......@@ -104,6 +104,13 @@ def unroll(start: PrimExpr,
res : frame.ForFrame
The ForFrame.
"""
    # Ensure annotations has {"pragma_unroll_explicit": False} by default
if annotations is None:
annotations = {"pragma_unroll_explicit": False}
else:
# Add "pragma_unroll_explicit": True if not already present
annotations = dict(annotations)
annotations.setdefault("pragma_unroll_explicit", False)
return _ir.unroll(start=start, stop=stop, annotations=annotations)
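
# Hedged usage sketch (added for illustration, not part of the original diff).
# Both calls below yield a ForFrame whose annotations contain "pragma_unroll_explicit";
# keys supplied by the caller are never overwritten by the setdefault above.
def _unroll_annotation_example():
    frame_default = unroll(0, 8)
    # -> annotations == {"pragma_unroll_explicit": False}
    frame_custom = unroll(0, 8, annotations={"pragma_unroll_explicit": True})
    # -> the user-provided True is kept
    return frame_default, frame_custom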
......@@ -294,6 +301,8 @@ ptx_mma = _dtype_forward(_tir_op.ptx_mma)
ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
ptx_wgmma_ss = _dtype_forward(_tir_op.ptx_wgmma_ss)
ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs)
ptx_tcgen05_mma_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_ss)
ptx_tcgen05_mma_ts = _dtype_forward(_tir_op.ptx_tcgen05_mma_ts)
ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk)
......
from typing import TypeVar, Literal
from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm
_T = TypeVar('_T')
def abs(x: _T, span: Span | None=None) -> _T: ...
def acos(x: _T) -> _T: ...
def acosh(x: _T) -> _T: ...
def address_of(buffer_load: BufferLoad, span: Span | None=None) -> PrimExpr: ...
def asin(x: _T) -> _T: ...
def asinh(x: _T) -> _T: ...
def atan(x: _T) -> _T: ...
def atan2(x1: _T, x2: _T) -> _T: ...
def atanh(x: _T) -> _T: ...
def bitwise_and(x: _T, y: _T, span: Span | None=None) -> _T: ...
def bitwise_not(x: _T, span: Span | None=None) -> _T: ...
def bitwise_or(x: _T, y: _T, span: Span | None=None) -> _T: ...
def bitwise_xor(x: _T, y: _T, span: Span | None=None) -> _T: ...
def ceil(x: _T, span: Span | None=None) -> _T: ...
def clz(x: _T) -> _T: ...
def copysign(x1: _T, x2: _T) -> _T: ...
def cos(x: _T) -> _T: ...
def cosh(x: _T) -> _T: ...
def erf(x: _T) -> _T: ...
def exp(x: _T) -> _T: ...
def exp2(x: _T) -> _T: ...
def exp10(x: _T) -> _T: ...
def floor(x: _T, span: Span | None=None) -> _T: ...
def ceildiv(lhs: _T, rhs: _T, span: Span | None=None) -> _T: ...
def floordiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
def floormod(a: _T, b: _T, span: Span | None=None) -> _T: ...
def fmod(x: _T, y: _T) -> _T: ...
def hypot(x1: _T, x2: _T) -> _T: ...
def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None=None) -> _T: ...
def infinity(dtype: _T, span: Span | None=None) -> _T: ...
def isfinite(x: _T, span: Span | None=None) -> _T: ...
def isinf(x: _T, span: Span | None=None) -> _T: ...
def isnan(x: _T, span: Span | None=None) -> _T: ...
def isnullptr(x: _T, span: Span | None=None) -> _T: ...
def ldexp(x1: _T, x2: _T) -> _T: ...
def likely(cond: _T, span: Span | None=None) -> _T: ...
def log(x: _T) -> _T: ...
def log1p(x: _T) -> _T: ...
def log2(x: _T) -> _T: ...
def log10(x: _T) -> _T: ...
def lookup_param(param_name: str, span: Span | None=None) -> PrimExpr: ...
def max_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
def min_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
def nearbyint(x: _T, span: Span | None=None) -> _T: ...
def nextafter(x1: _T, x2: _T) -> _T: ...
def popcount(x: _T) -> _T: ...
def pow(x: _T, y: _T, span: Span | None=None) -> _T: ...
def q_multiply_shift(x: _T, y: _T, q: _T, s: _T) -> _T: ...
def q_multiply_shift_per_axis(x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm) -> PrimExpr: ...
def ret(val: _T) -> _T: ...
def round(x: _T, span: Span | None=None) -> _T: ...
def rsqrt(x: _T) -> _T: ...
def shift_left(x: _T, y: _T, span=None) -> _T: ...
def shift_right(x: _T, y: _T, span=None) -> _T: ...
def sigmoid(x: _T) -> _T: ...
def sin(x: _T) -> _T: ...
def sinh(x: _T) -> _T: ...
def sqrt(x: _T) -> _T: ...
def tan(x: _T) -> _T: ...
def tanh(x: _T) -> _T: ...
def trunc(x: _T, span: Span | None=None) -> _T: ...
def truncdiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
def truncmod(a: _T, b: _T, span: Span | None=None) -> _T: ...
def tvm_access_ptr(ptype: PrimExpr, data, offset: int, extent: int, rw_mask: int) -> PrimExpr: ...
def tvm_throw_last_error() -> _T: ...
def tvm_stack_alloca(dtype_str: str, num: int) -> PrimExpr: ...
def tvm_stack_make_shape(*args) -> _T: ...
def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset) -> PrimExpr: ...
def tvm_check_return(expected: int, return_unexpected: int, nested_call: PrimExpr) -> PrimExpr: ...
def call_packed(*args, span=None) -> _T: ...
def call_cpacked(*args, span=None) -> _T: ...
def call_packed_lowered(*args, span=None) -> _T: ...
def call_cpacked_lowered(*args, span=None) -> _T: ...
def tvm_tuple(*value) -> _T: ...
def tvm_struct_set(arr, index: int, field: int, value: PrimExpr) -> PrimExpr: ...
def tvm_thread_invariant(cond: _T) -> _T: ...
def tvm_thread_allreduce(*freduce_args) -> _T: ...
def tvm_load_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
def tvm_mma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
def tvm_bmma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
def tvm_fill_fragment(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, value: PrimExpr) -> PrimExpr: ...
def tvm_store_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
def ptx_wait_group(num: int) -> PrimExpr: ...
def ptx_commit_group() -> _T: ...
def ptx_cp_async_barrier(barrier_id: int) -> PrimExpr: ...
def ptx_init_barrier_thread_count(barrier_id: int, thread_count: int) -> PrimExpr: ...
def ptx_arrive_barrier(barrier_id: int) -> PrimExpr: ...
def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ...
def ptx_wait_barrier(barrier_id: int) -> PrimExpr: ...
def create_barriers(barrier_count: int) -> PrimExpr: ...
def assume(cond: _T=None) -> _T: ...
def undef() -> _T: ...
def TVMBackendAllocWorkspace(device_type: int, device_id: int, nbytes: int, dtype_code_hint: int, dtype_bits_hint: int) -> PrimExpr: ...
def TVMBackendFreeWorkspace(device_type: int, device_id: int, ptr: Var) -> PrimExpr: ...
def start_profile_intrinsic(id: int) -> PrimExpr: ...
def end_profile_intrinsic(id: int) -> PrimExpr: ...
def anylist_getitem(list_handle, index) -> PrimExpr: ...
def anylist_resetitem(list_handle, index) -> PrimExpr: ...
def anylist_setitem_call_packed(list_handle, index, func_name, *args) -> PrimExpr: ...
def anylist_setitem_call_cpacked(list_handle, index, func_name, *args) -> PrimExpr: ...
def vscale() -> _T: ...
......@@ -1107,7 +1107,6 @@ def ptx_wgmma_ss(
def ptx_wgmma_rs(
dtype,
wgmma_prefix,
a_is_k_major,
b_is_k_major,
a_dtype_abbrv,
b_dtype_abbrv,
......@@ -1127,7 +1126,6 @@ def ptx_wgmma_rs(
dtype,
_tvm_op.Op.get("tl.ptx_wgmma_rs"),
wgmma_prefix,
a_is_k_major,
b_is_k_major,
a_dtype_abbrv,
b_dtype_abbrv,
......@@ -1144,6 +1142,115 @@ def ptx_wgmma_rs(
)
def ptx_tcgen05_mma_ss(
kind_dtype,
desc_a,
A_offset,
desc_b,
B_offset,
C_ptr,
C_offset,
desc_val,
scale_out,
mask0,
mask1,
mask2,
mask3,
enable_ws=False,
ws=None,
warp_specialized=None,
variant=None,
):
"""TVM intrinsic for tcgen05.mma shared-memory × shared-memory instructions.
Expects 13 or 14 positional arguments:
(kind_dtype, desc_a, A_offset, desc_b, B_offset, C_ptr, C_offset,
desc_val, scale_out, mask0, mask1, mask2, mask3[, enable_ws]).
Aliases: you can also pass `ws` or `warp_specialized` (booleans) instead of `enable_ws`.
Alternatively, use `variant="ws"` (or "default").
- kind_dtype: instruction kind selector (e.g., "float16" for kind::f16,
"tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4).
"""
    # Alias precedence: `ws`, then `warp_specialized`, then `variant` each override
    # `enable_ws`; when several are given, the later one in this order wins.
if ws is not None:
enable_ws = bool(ws)
if warp_specialized is not None:
enable_ws = bool(warp_specialized)
if variant is not None:
if isinstance(variant, str):
v = variant.lower()
if v in ("ws", "warp_specialized", "warp-specialized"):
enable_ws = True
elif v in ("default", "std", "ss"):
enable_ws = False
else:
raise ValueError(f"ptx_tcgen05_mma_ss: unknown variant: {variant}")
else:
# Treat non-string as truthy flag
enable_ws = bool(variant)
return call_intrin(
"handle",
_tvm_op.Op.get("tl.ptx_tcgen05_mma_ss"),
kind_dtype,
desc_a,
A_offset,
desc_b,
B_offset,
C_ptr,
C_offset,
desc_val,
scale_out,
mask0,
mask1,
mask2,
mask3,
enable_ws,
)
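
# Hedged sketch (added for illustration, not part of the original diff). The three
# spellings below all set the trailing `enable_ws` operand of tl.ptx_tcgen05_mma_ss;
# the argument values are placeholders, real calls pass shared-memory descriptors,
# a tensor-memory pointer and the instruction descriptor.
def _tcgen05_mma_ss_variant_example(desc_a, a_off, desc_b, b_off, c_ptr, c_off, idesc):
    common = ("float16", desc_a, a_off, desc_b, b_off, c_ptr, c_off, idesc,
              1, -1, -1, -1, -1)  # kind, operands, scale_out, mask0..mask3
    call_a = ptx_tcgen05_mma_ss(*common, enable_ws=True)
    call_b = ptx_tcgen05_mma_ss(*common, warp_specialized=True)  # same effect
    call_c = ptx_tcgen05_mma_ss(*common, variant="ws")           # same effect
    return call_a, call_b, call_c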
def ptx_tcgen05_mma_ts(
kind_dtype,
A_ptr,
A_offset,
desc_b,
B_offset,
C_ptr,
C_offset,
desc_val,
scale_out,
mask0,
mask1,
mask2,
mask3,
):
"""TVM intrinsic for tcgen05.mma tensor-memory × shared-memory instructions.
Expects 13 positional arguments:
(kind_dtype, A_ptr, A_offset, desc_b, B_offset, C_ptr, C_offset,
desc_val, scale_out, mask0, mask1, mask2, mask3).
- kind_dtype: instruction kind selector (e.g., "float16" for kind::f16,
"tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4).
"""
return call_intrin(
"handle",
_tvm_op.Op.get("tl.ptx_tcgen05_mma_ts"),
kind_dtype,
A_ptr,
A_offset,
desc_b,
B_offset,
C_ptr,
C_offset,
desc_val,
scale_out,
mask0,
mask1,
mask2,
mask3,
)
def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
"""TVM intrinsic for storing the result of PTX MMA into a destination pointer
......@@ -1893,7 +2000,7 @@ def infinity(dtype: str, span: Span | None = None) -> Any:
value : tvm.Expr
The infinity value of dtype.
"""
return _tvm_op.infinity(dtype, span)
return call_intrin(dtype, _tvm_op.Op.get("tl.infinity"), dtype, span=span)
def reinterpret(dtype, value, span: Span | None = None) -> Any:
......
......@@ -85,7 +85,14 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s
extents
), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
# Clamp extents element-wise so that the produced region respects the
# requested copy/fill extent, supporting dynamic PrimExpr via tir.min.
clamped_extents = [
tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i]
for i in range(len(region_extents))
]
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *clamped_extents)
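
# Hedged illustration (added, not part of the original diff): the clamp keeps the
# first len(extents) dimensions no larger than the requested copy extent and passes
# any remaining region dimensions through unchanged; tir.min keeps dynamic extents
# symbolic.
def _clamp_extents_example():
    region_extents = [tir.IntImm("int32", 128), tir.IntImm("int32", 64)]
    extents = [tir.IntImm("int32", 96)]
    return [
        tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i]
        for i in range(len(region_extents))
    ]  # -> [min(128, 96), 64]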
def index_to_coordinates(index, shape) -> list[PrimExpr]:
......
from .builder import prim_func, macro, PrimFunc # noqa: F401
from .dtypes import *
from __future__ import annotations
import ast
from dataclasses import dataclass
from typing import Callable, Generic, Any, Literal, TypeVar
from contextlib import AbstractContextManager
from collections.abc import Iterable
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
import inspect
# from .utils import get_ast, get_compiled_object
from . import utils
_span_attrs = ['lineno', 'col_offset', 'end_lineno', 'end_col_offset']
def ast_has_span(ast: ast.AST) -> bool:
return all(hasattr(ast, attr) for attr in _span_attrs)
def ast_get_span(ast: ast.AST) -> tuple[int, int, int, int] | None:
if not ast_has_span(ast):
return None
return tuple(getattr(ast, attr) for attr in _span_attrs)
def ast_set_span(ast: ast.AST, span: tuple[int, int, int, int]):
if not ast_has_span(ast):
return
for attr, value in zip(_span_attrs, span):
setattr(ast, attr, value)
class QuoteVisitor(ast.NodeTransformer):
def __init__(self, names: dict[str, ast.AST], passes: list[Any] | None = None, span=None):
self.names = names
self.passes = passes or []
self.span = span
def generic_visit(self, node: ast.AST):
if self.span is not None:
ast_set_span(node, self.span)
return super().generic_visit(node)
def visit_Name(self, node: ast.Name) -> Any:
if node.id in self.names:
return self.names[node.id]
else:
return node
def visit_Pass(self, node: ast.Pass) -> Any:
item = self.passes.pop(0)
return item if item else node
def quote(expr: str, *, passes: list[Any] | None = None, span=None, **kws) -> list[ast.AST]:
tree = ast.parse(expr)
if isinstance(span, ast.AST):
span = ast_get_span(span)
tree = QuoteVisitor(kws, passes, span).visit(tree)
return tree.body
def quote1(expr: str, *, passes: list[Any] | None = None, span=None, **kws) -> ast.AST:
res = quote(expr, passes=passes, span=span, **kws)
assert len(res) == 1
return res[0]
def quote_expr(expr: str, **kws) -> ast.expr:
res = quote1(expr, **kws)
assert isinstance(res, ast.Expr)
return res.value
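
# Hedged example (added for illustration): `quote` parses a template, substitutes the
# given AST nodes for placeholder names, and splices `passes` entries into `pass`
# statements, which is how DSLMutator below builds its rewrites.
def _quote_example():
    cond = ast.parse("x > 0", mode="eval").body
    stmts = quote("if cond:\n    pass\n", cond=cond, passes=[[ast.Pass()]])
    return stmts  # [ast.If] whose test is the substituted `x > 0` expression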
Operator = Literal['Add', 'Sub', 'Mult', 'MatMult', 'Div', 'Mod', 'Pow', 'LShift', 'RShift',
'BitOr', 'BitXor', 'BitAnd', 'FloorDiv']
BoolOp = Literal['And', 'Or']
def get_operator_name(operator: ast.operator) -> Operator:
return operator.__class__.__name__
def get_boolop_name(boolop: ast.boolop) -> BoolOp:
return boolop.__class__.__name__
_T = TypeVar('_T')
def eval_op(op: Operator, left: Any, right: Any) -> Any:
if op == 'Add':
return left + right
if op == 'Sub':
return left - right
if op == 'Mult':
return left * right
if op == 'MatMult':
return left @ right
if op == 'Div':
return left / right
if op == 'Mod':
return left % right
if op == 'Pow':
return left**right
if op == 'LShift':
return left << right
if op == 'RShift':
return left >> right
if op == 'BitOr':
return left | right
if op == 'BitXor':
return left ^ right
if op == 'BitAnd':
return left & right
if op == 'FloorDiv':
return left // right
raise ValueError(f'Unknown operator: {op}')
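
# Hedged illustration (added): eval_op re-applies the Python operator named after the
# AST node class, so it works for any operands implementing the matching dunder,
# including TVM PrimExprs.
def _eval_op_example():
    assert eval_op('Add', 2, 3) == 5
    assert eval_op('FloorDiv', 7, 2) == 3
    assert eval_op('BitXor', 0b1100, 0b1010) == 0b0110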
def eval_aug_assign(op: Operator, left: Any, sl: slice, right: Any) -> Any:
if op == 'Add':
left[sl] += right
return left
if op == 'Sub':
left[sl] -= right
return left
if op == 'Mult':
left[sl] *= right
return left
if op == 'MatMult':
left[sl] @= right
return left
if op == 'Div':
left[sl] /= right
return left
if op == 'Mod':
left[sl] %= right
return left
if op == 'Pow':
left[sl] **= right
return left
if op == 'LShift':
left[sl] <<= right
return left
if op == 'RShift':
left[sl] >>= right
return left
if op == 'BitOr':
left[sl] |= right
return left
if op == 'BitXor':
left[sl] ^= right
return left
if op == 'BitAnd':
left[sl] &= right
return left
if op == 'FloorDiv':
left[sl] //= right
return left
raise ValueError(f'Unknown operator: {op}')
class _empty:
...
class BaseBuilder:
empty = _empty
def get_parent_locals(self):
return inspect.currentframe().f_back.f_back.f_locals
def ctx_if(self, cond) -> Iterable[_T]:
yield cond
def ctx_then(self, val: _T) -> Iterable[None]:
if val:
yield
def ctx_else(self, val: _T) -> Iterable[None]:
if not val:
yield
def eval(self, val: Any): # noqa: B027
pass
def ctx_for(self, range: Iterable[Any]) -> Iterable[Any]:
return range
def ctx_continue(self) -> bool:
return True
def ctx_break(self) -> bool:
return True
def ctx_while(self, cond: Callable[[], Any]) -> Iterable[None]:
while cond():
yield
def bind(self, name: str, value: Any, annot: Any = empty) -> Any:
return value
def unwrap_value(self, value):
return value
def assign_slice(self, lval: Any, sl: slice, value: Any, annot: Any = empty):
lval[sl] = value
def aug_assign(self, op: Operator, target: Any, aug_value: Any) -> Any:
return eval_op(op, target, aug_value)
def aug_assign_slice(self, op: Operator, target: Any, sl: slice, aug_value: Any):
eval_aug_assign(op, target, sl, aug_value)
def boolop(self, op: BoolOp, left: Any, right: Callable[[], Any]) -> Any:
if op == 'And':
return left and right()
if op == 'Or':
return left or right()
raise ValueError(f'Unknown boolop: {op}')
def ifexp(self, cond: Any, then: Callable[[], Any], otherwise: Callable[[], Any]) -> Any:
return then() if cond else otherwise()
def ret(self, value: Any) -> Any:
return value
def ctx_with(self, ctx: AbstractContextManager[Any]) -> AbstractContextManager[Any]:
return ctx
def assert_expr(self, cond: Any, msg: Any):
assert cond, msg
def rval(self, name: str, value: Any):
return value
def arg(self, name: str, value: Any):
return value
def override(self, name: str):
return globals()[name]
class DSLMutator(ast.NodeTransformer):
def __init__(self):
self.tmp_counter = 0
def get_tmp(self) -> str:
name = f"__{self.tmp_counter}"
self.tmp_counter += 1
return name
def visit_If(self, node: ast.If):
node = self.generic_visit(node)
br = self.get_tmp()
if len(node.orelse) == 0:
return quote(
f"for {br} in __tb.ctx_if(cond):\n"
f" for _ in __tb.ctx_then({br}):\n"
" pass\n",
cond=node.test,
passes=[node.body],
span=node,
)
return quote(
f"for {br} in __tb.ctx_if(cond):\n"
f" for _ in __tb.ctx_then({br}):\n"
f" pass\n"
f" for _ in __tb.ctx_else({br}):\n"
f" pass\n",
cond=node.test,
passes=[node.body, node.orelse],
span=node,
)
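
    # Hedged note (added for illustration): for a source-level
    #     if cond:
    #         A()
    #     else:
    #         B()
    # the rewrite above emits, roughly,
    #     for __0 in __tb.ctx_if(cond):
    #         for _ in __tb.ctx_then(__0):
    #             A()
    #         for _ in __tb.ctx_else(__0):
    #             B()
    # so the builder decides at trace time whether the branch becomes a TIR If frame
    # (symbolic condition) or ordinary Python control flow (constant condition).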
def visit_Expr(self, node: ast.Expr):
node = self.generic_visit(node)
return quote("__tb.eval(value)", value=node.value, span=node)
def _parse_names(self, target: ast.expr):
if isinstance(target, ast.Name):
return f"'{target.id}'"
elif isinstance(target, ast.Tuple):
return ("(" + ",".join([self._parse_names(elt) for elt in target.elts]) + ",)")
else:
s = ast.unparse(target)
raise NotImplementedError(f"Unsupported for target `{s}`")
def visit_For(self, node: ast.For):
node = self.generic_visit(node)
tmp = self.get_tmp()
# names = self._parse_names(node.target)
var = ast.Name(tmp, ctx=ast.Load())
ast_set_span(var, ast_get_span(node.target))
stmts = self._emit_assign_target(node.target, var)
return quote(
f"for {tmp} in __tb.ctx_for(range):\n"
" pass\n",
target=node.target,
range=node.iter,
passes=[stmts + node.body],
span=node,
)
def visit_Continue(self, node: ast.Continue):
node = self.generic_visit(node)
return quote("if __tb.ctx_continue(): continue", span=node)
def visit_Break(self, node: ast.Break):
node = self.generic_visit(node)
return quote("if __tb.ctx_break(): break", span=node)
def _emit_assign_target(self,
target: ast.expr,
rval: ast.expr,
annot: ast.expr = None) -> list[ast.AST]:
if isinstance(target, ast.Name):
if annot is None:
return quote(
f"name = __tb.bind('{target.id}', value)", name=target, value=rval, span=target)
else:
return quote(
f'name = __tb.bind("{target.id}", value, annot)',
name=target,
value=rval,
annot=annot,
span=target)
elif isinstance(target, ast.Attribute):
s = ast.unparse(target)
raise NotImplementedError(f'Attribute assignment not supported yet, `{s}`')
elif isinstance(target, ast.Subscript):
if annot is None:
return quote(
"__tb.assign_slice(lval, slice, value)",
lval=target.value,
slice=target.slice,
value=rval,
span=target,
)
else:
return quote(
"__tb.assign_slice(lval, slice, value, annot)",
lval=target.value,
slice=target.slice,
value=rval,
annot=annot,
span=target,
)
else:
# flatten nested tuple into a list of (tmp_name, target)
unpacked = []
def _visit_target(target: ast.expr) -> str:
if isinstance(target, (ast.Name, ast.Subscript)):
tmp = self.get_tmp()
unpacked.append((tmp, target))
res = ast.Name(id=tmp, ctx=target.ctx)
ast_set_span(res, ast_get_span(target))
return res
elif isinstance(target, ast.Tuple):
elts = [_visit_target(elt) for elt in target.elts]
res = ast.Tuple(elts=elts, ctx=target.ctx)
ast_set_span(res, ast_get_span(target))
return res
else:
s = ast.unparse(target)
                    raise NotImplementedError(f'Unsupported assignment target `{s}`')
unpack_stmt = ast.Assign(
targets=[_visit_target(target)],
value=quote_expr('__tb.unwrap_value(rval)', rval=rval, span=rval))
ast_set_span(unpack_stmt, ast_get_span(target))
stmts = [unpack_stmt]
bind_lvals = []
bind_rvals = []
def flush_binds():
if bind_lvals:
stmts.append(
quote1(f'{", ".join(bind_lvals)}, = {", ".join(bind_rvals)},', span=target))
bind_lvals.clear()
bind_rvals.clear()
            # the following code generates two-phase binding to support swap-like semantics
# for example:
# a, b = b, a
# 1 phase:
# _tmp_0, _tmp_1 = b, a
# => _tmp_0: T.int32 = b
# => _tmp_1: T.int32 = a
# 2 phase:
# a, b = _tmp_0, _tmp_1
# => a = _tmp_0 => a[0] = _tmp_0
# => b = _tmp_1 => b[0] = _tmp_1
# 1 phase: _tmp_0, _tmp_1 = __tb.bind('_', a), __tb.bind('_', b)
for tmp, _target in unpacked:
bind_lvals.append(tmp)
bind_rvals.append(f'__tb.bind("_", {tmp})')
flush_binds()
# 2 phase: a, b = __tb.bind('a', _tmp_0), __tb.bind('b', _tmp_1)
for tmp, target in unpacked:
if isinstance(target, ast.Name):
bind_lvals.append(target.id)
bind_rvals.append(f'__tb.bind("{target.id}", {tmp})')
elif isinstance(target, ast.Subscript):
flush_binds()
stmts.append(
quote1(
f'__tb.assign_slice(lval, slice, {tmp})',
lval=target.value,
slice=target.slice,
span=target))
else:
s = ast.unparse(target)
raise NotImplementedError(f'Unsupported target: {s}')
flush_binds()
return stmts
def visit_Assign(self, node: ast.Assign) -> list[ast.AST]:
node = self.generic_visit(node)
rval = node.value
if len(node.targets) == 1:
return self._emit_assign_target(node.targets[0], rval)
else:
tmp_name = self.get_tmp()
tmp_store = ast.Name(tmp_name, ctx=ast.Store())
tmp_load = ast.Name(tmp_name, ctx=ast.Load())
ast_set_span(tmp_store, node.targets[0])
ast_set_span(tmp_load, node.targets[0])
stmt = self._emit_assign_target(tmp_store, rval)
for target in node.targets:
stmt.extend(self._emit_assign_target(target, tmp_load))
return stmt
def visit_AugAssign(self, node: ast.AugAssign) -> list[ast.AST]:
node = self.generic_visit(node)
target, rval = node.target, node.value
op = get_operator_name(node.op)
if isinstance(target, ast.Name):
return quote(
f"name = __tb.aug_assign('{op}', {target.id}, value)",
name=target,
value=rval,
span=node)
elif isinstance(target, ast.Subscript):
return quote(
f"__tb.aug_assign_slice('{op}', lval, slice, value)",
lval=target.value,
slice=target.slice,
value=rval,
span=node,
)
else:
return node
def visit_AnnAssign(self, node: ast.AnnAssign):
node = self.generic_visit(node)
rval = node.value or quote_expr('__tb.empty', span=node, annot=node)
return self._emit_assign_target(node.target, rval, annot=node.annotation)
def visit_While(self, node):
node = self.generic_visit(node)
return quote1(
"for _ in __tb.ctx_while(lambda: cond):\n pass",
cond=node.test,
passes=[node.body],
span=node)
def visit_FunctionDef(self, node: ast.FunctionDef):
node = self.generic_visit(node)
all_args = node.args.posonlyargs + node.args.args
if node.args.vararg is not None:
            all_args += [node.args.vararg]
all_args += node.args.kwonlyargs
stmts = []
for arg in all_args:
            name = arg.arg
            # Annotated and unannotated arguments are rebound through __tb.arg alike
            arg_stmt = quote1(f'{name} = __tb.arg("{name}", {name})', span=arg)
            arg.annotation = None
stmts.append(arg_stmt)
node.body = stmts + node.body
node.decorator_list.clear()
return quote1(
f"def {node.name}(__tb):\n"
" range = __tb.override('range')\n"
" pass\n"
f" return {node.name}",
passes=[node],
)
def visit_BoolOp(self, node: ast.BoolOp):
node = self.generic_visit(node)
op_name = get_boolop_name(node.op)
last = node.values[-1]
for i in reversed(range(len(node.values) - 1)):
last = quote_expr(
expr=f"__tb.boolop('{op_name}', left, lambda: right)",
left=node.values[i],
right=last,
span=node,
)
return last
def visit_Compare(self, node: ast.Compare) -> ast.expr:
node = self.generic_visit(node)
left = node.left
split = []
for op, comp in zip(node.ops, node.comparators):
cmp = ast.Compare(left=left, ops=[op], comparators=[comp])
ast_set_span(cmp, ast_get_span(node))
split.append(cmp)
left = comp
last = split[-1]
for i in reversed(range(len(split) - 1)):
last = quote_expr(
"__tb.boolop('And', left, lambda: right)", left=split[i], right=last, span=node)
return last
    def visit_IfExp(self, node: ast.IfExp) -> ast.expr:
node = self.generic_visit(node)
return quote_expr(
'__tb.ifexp(cond, lambda: then, lambda: otherwise)',
cond=node.test,
then=node.body,
otherwise=node.orelse,
span=node)
def visit_Return(self, node: ast.Return):
node = self.generic_visit(node)
return quote("return __tb.ret(value)", value=node.value, span=node)
def visit_With(self, node: ast.With):
node = self.generic_visit(node)
for expr in node.items:
expr.context_expr = quote_expr("__tb.ctx_with(e)", e=expr.context_expr, span=expr)
return node
def visit_Assert(self, node: ast.Assert):
node = self.generic_visit(node)
return quote("__tb.assert_expr(cond, msg)", cond=node.test, msg=node.msg, span=node)
def visit_Name(self, node: ast.Name):
if isinstance(node.ctx, ast.Load):
return quote_expr(f"__tb.rval('{node.id}', node)", node=node, span=node)
return node
_P = ParamSpec('_P')
@dataclass
class IRGenerator(Generic[_P, _T]):
gen: Callable[[BaseBuilder], Callable[_P, _T]]
source: str
def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]:
"""
Transform a Python function into an IR (Intermediate Representation) generator.
This function takes a regular Python function and performs AST (Abstract Syntax Tree)
transformation to create an IRGenerator that can be used for code generation purposes.
Args:
func (Callable[_P, _T]): The Python function to be transformed. This should be a
callable that will be analyzed and mutated at the AST level. The function's
signature is preserved through generic type parameters _P (parameters) and
_T (return type).
Returns:
IRGenerator[_P, _T]: An IRGenerator instance wrapping the transformed function.
The generator contains:
- gen: The compiled and mutated version of the original function
- source: The unparsed source code of the transformed AST as a string
Example:
>>> @mutate
... def my_function(x: int) -> int:
... return x * 2
>>> # my_function is now an IRGenerator that can be used for code generation
Note:
- The original function's closure variables and captured context are preserved
- The transformation is performed at compile-time through AST manipulation
- The returned IRGenerator maintains type information from the original function
"""
tree = utils.get_ast(func)
filename = inspect.getsourcefile(func) or inspect.getfile(func)
tree = DSLMutator().visit(tree)
fn = utils.get_compiled_object(tree, func.__name__, filename,
utils.inspect_function_capture(func))
return IRGenerator(gen=fn, source=ast.unparse(tree))
from __future__ import annotations
from contextlib import contextmanager, AbstractContextManager
from dataclasses import dataclass
import inspect
from tilelang.language.kernel import KernelLaunchFrame
from tvm_ffi.container import Map
from tvm.ir.base import Span
from .ast import BaseBuilder, IRGenerator, eval_op, mutate
import tvm
from tvm.tir import Buffer
from tvm.script.ir_builder import tir, IRBuilder
from tvm.tir.expr import EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var
from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, Union
# Python 3.9 compatibility for ParamSpec and Self
try:
from typing import ParamSpec, Self
except ImportError: # Python < 3.11 for Self, < 3.10 for ParamSpec
from typing_extensions import ParamSpec, Self
from . import dtypes as dt
import threading
import logging
logger = logging.getLogger(__name__)
def unwrap_expr(expr) -> PrimExpr | int | float:
    '''
    Unwrap `expr` and convert it into a PrimExpr-like value.
    '''
if isinstance(expr, tir.meta_var):
expr = expr.value
elif isinstance(expr, Buffer) and expr.scope() == 'local.var':
expr = tir.BufferLoad(expr, indices=[0])
elif isinstance(expr, (EqualOp, NotEqualOp)):
expr = expr.asobject()
return expr
def unwrap_cond(expr):
    '''
    Unwrap `expr` and convert it into a boolean condition.
    '''
expr = unwrap_expr(expr)
if isinstance(expr, (IntImm, FloatImm, StringImm)):
return bool(expr.value)
elif isinstance(expr, PrimExpr):
return expr
elif isinstance(expr, Buffer):
raise TypeError(f"Buffer `{expr}` cannot be used as condition directly.")
elif isinstance(expr, (int, bool)) or expr is None:
return bool(expr)
else:
logger.warning(
f"Python expression `{expr}` is used as condition in TileLang, \n"
"this is treated as a constant expression. ",
stack_info=True,
stacklevel=3)
return bool(expr)
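
# Hedged illustration (added, not part of the original source): constant TIR
# immediates collapse to Python bools, while symbolic expressions stay PrimExprs so
# they can drive a tir.If frame.
def _unwrap_cond_example():
    assert unwrap_cond(tvm.tir.const(0, "int32")) is False
    x = tvm.tir.Var("x", "int32")
    cond = unwrap_cond(x > 0)
    assert isinstance(cond, PrimExpr)
    return cond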
thread_local_storage = threading.local()
class Frame:
    '''
    Frames are virtual context managers used only in the frontend.
    They have no runtime representation in the generated TIR.
    '''
def __enter__(self):
...
def __exit__(self, exc_type, exc_value, traceback):
...
class MacroFrame(Frame):
...
class BoolOpFrame(Frame):
...
class ConstIfFrame(Frame):
...
class BlockFrame(Frame):
...
class ContinueFrame(Frame):
...
class BreakFrame(Frame):
...
@dataclass
class SerialForWithStep:
start: PrimExpr
stop: PrimExpr
step: PrimExpr
annotations: dict[str, Any] | None = None
# Python 3.9 compatibility: avoid PEP 604 unions at runtime
# Use tuple for isinstance checks and typing.Union for annotations/aliases
ContinueOrBreak = (ContinueFrame, BreakFrame)
AnyFrame = Union[tir.frame.IRBuilderFrame, Frame]
TIR_CONTROL_FRAME = (
tir.frame.WhileFrame,
tir.frame.ForFrame,
tir.frame.IfFrame,
tir.frame.PrimFuncFrame,
)
TIR_VAR_SCOPE_FRAME = (
tir.frame.WhileFrame,
tir.frame.ForFrame,
tir.frame.IfFrame,
tir.frame.PrimFuncFrame,
MacroFrame,
KernelLaunchFrame,
)
def is_var(v: Any) -> bool:
return isinstance(v, Buffer) and v.scope() == 'local.var'
class Builder(BaseBuilder):
def __init__(self):
self.frames: list[AnyFrame] = []
self.ir_builder = IRBuilder()
self.name_inside_frame: dict[str, AnyFrame] = {}
@classmethod
def current(cls) -> Self:
builder = thread_local_storage.builder
assert builder is not None, "No active Builder found in the current thread."
return builder
@contextmanager
def prim_func(self, name):
thread_local_storage.builder = self
with self.ir_builder, self.with_frame(tir.prim_func()):
tir.func_name(name)
yield
@contextmanager
def macro(self, name=None):
if self.find_frame_idx(BoolOpFrame) is not None:
raise RuntimeError(
f"Macro `{name}` is used inside boolean expressions, "
"please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs")
save = self.name_inside_frame
self.name_inside_frame = {}
with self.with_frame(MacroFrame()):
yield
self.name_inside_frame = save
def get(self):
return self.ir_builder.get()
def find_frame_idx(self, frame: type | tuple[type, ...], start=0) -> int | None:
for idx in reversed(range(start, len(self.frames))):
f = self.frames[idx]
if isinstance(f, frame):
return idx
def enter_frame(self, frame: AbstractContextManager[Any]):
self.frames.append(frame)
return frame.__enter__()
def check_continue_break(self):
idx = self.find_frame_idx(ContinueOrBreak)
if idx is not None:
logger.warning(
'Writing code after continue/break may cause undefined behavior in tilelang.',
stack_info=True,
stacklevel=3)
@contextmanager
def with_frame(self, frame: AbstractContextManager[Any] | None):
pop_idx = len(self.frames)
yield self.enter_frame(frame)
while len(self.frames) > pop_idx:
self.frames.pop().__exit__(None, None, None)
class _has_if_frame:
...
def ctx_if(self, cond):
self.check_continue_break()
cond = unwrap_cond(cond)
if isinstance(cond, PrimExpr):
with self.with_frame(tir.If(cond)):
yield self._has_if_frame
else:
with self.with_frame(ConstIfFrame()):
yield cond
def ctx_then(self, val):
if val is self._has_if_frame:
with self.with_frame(tir.Then()):
yield
else:
with self.with_frame(BlockFrame()):
if val:
yield
def ctx_else(self, val):
if val is self._has_if_frame:
with self.with_frame(tir.Else()):
yield
else:
with self.with_frame(BlockFrame()):
if not val:
yield
def eval(self, val: Any):
val = unwrap_expr(val)
if val is None:
pass
elif isinstance(val, tir.frame.IRBuilderFrame):
if isinstance(val, tir.frame.ForFrame):
logger.warning(
'Evaluating a for frame may cause undefined behavior in tilelang.',
stack_info=True,
stacklevel=1,
)
self.enter_frame(val)
elif isinstance(val, PrimExpr):
tir.evaluate(val)
elif isinstance(val, (int, bool)):
tir.evaluate(tvm.tir.const(val))
elif isinstance(val, str):
pass
elif isinstance(val, tvm.tir.stmt.BufferStore):
tir.buffer_store(val.buffer, val.value, val.indices, val.predicate)
elif not isinstance(val, tvm.tir.Buffer):
raise TypeError(f"Unsupported eval value: {val} of type {type(val)}")
def ctx_for(self, it):
self.check_continue_break()
it = unwrap_expr(it)
if isinstance(it, SerialForWithStep):
# Validate and compute the trip count before constructing the frame
if isinstance(it.step, (int, IntImm)):
step_value = it.step if isinstance(it.step, int) else it.step.value
if step_value == 0:
raise ValueError('Invalid stepped serial: step must be non-zero')
if step_value > 0:
real_stop = tir.ceildiv(it.stop - it.start, step_value)
else:
real_stop = tir.ceildiv(it.start - it.stop, -step_value)
else:
logger.warning(
f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang'
)
real_stop = tir.ceildiv(it.stop - it.start, it.step)
real_frame = tir.serial(real_stop, annotations=it.annotations)
with self.with_frame(real_frame) as v:
IRBuilder.name('_tmp', v)
yield it.start + v * it.step
else:
if not isinstance(it, tir.frame.ForFrame):
raise TypeError(
f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding")
with self.with_frame(it) as v:
yield v
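
    # Hedged note (added for illustration): given SerialForWithStep(start=2, stop=11,
    # step=3), the branch above builds a unit-step frame with ceildiv(11 - 2, 3) == 3
    # iterations and yields 2 + v * 3 for v in 0..2, so the loop variable takes the
    # values 2, 5 and 8, matching Python's range(2, 11, 3).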
def ctx_continue(self):
self.check_continue_break()
# add a dummy frame for checking code after continue/break
self.enter_frame(ContinueFrame())
tir.evaluate(tir.continue_loop())
def ctx_break(self):
self.check_continue_break()
# add a dummy frame for checking code after continue/break
self.enter_frame(BreakFrame())
tir.evaluate(tir.break_loop())
def ctx_while(self, cond):
self.check_continue_break()
cond_v = cond()
cond_v_unwrap = unwrap_cond(cond_v)
if not isinstance(cond_v_unwrap, PrimExpr):
if cond_v_unwrap:
raise RuntimeError(
f'Infinite while loop detected in TileLang\n'
f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n'
)
else:
                logger.warning(
                    'While loop with constant false condition detected in TileLang, the loop body will never be executed.\n'
                    f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n',
                    stack_info=True,
                    stacklevel=2)
with self.with_frame(tir.While(cond_v_unwrap)):
yield None
def bind(self, name, value, annot=BaseBuilder.empty):
self.check_continue_break()
locals = self.get_parent_locals()
orig_value = locals.get(name, None)
# annotation like tl.float32
# temporarily disable annotation based var declaration, for better pull request separation
# if callable(annot):
# annot_val = annot()
# if isinstance(annot_val, tir.Var):
# orig_value = tir.alloc_buffer((1,), dtype=annot_val.dtype, scope='local.var')
# IRBuilder.name(name, orig_value)
# if isinstance(value, EllipsisType) or value is self.empty:
# return orig_value
# elif isinstance(value, (int, float, IntImm, FloatImm)):
# tir.block_attr(
# {'tl.local_var_init': {
# orig_value.data: tvm.runtime.convert(value)
# }})
# return orig_value
        # if orig_value is a local.var, we use buffer_store to modify it in place
        # however, if the rvalue is also a local.var, this is a new binding,
        # so we should not use buffer_store, and bind it instead
        # ```py
        # a = tl.alloc_var('float32') # bind var `a`
        # a = tl.alloc_var('float32') # bind a new var `a_1`
        # b = a # get value of var `b = a_1[0]`
        # c = tl.alloc_var('float32') # bind var `c`
        # c = a # get and assign `c[0] = a_1[0]`
        # ```
if is_var(orig_value) and not is_var(value):
tir.buffer_store(orig_value, value, 0)
return orig_value
res = self.bind_immutable(name, value)
if name != '_':
frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME)
assert frame is not None, f"Variable `{name}` is not defined inside any control flow."
if name in self.name_inside_frame and self.name_inside_frame[name] in self.frames:
logger.warning(
                    f'Variable `{name}` shadows another declared value. Did you forget to allocate it as a var?',
stack_info=True,
stacklevel=2,
)
self.name_inside_frame[name] = self.frames[frame]
return res
def unwrap_value(self, value):
value = unwrap_expr(value)
# handle bx, by = tl.Kernel(128, 128), rval is frame
if isinstance(value, tir.frame.IRBuilderFrame):
return self.enter_frame(value)
else:
return value
def bind_immutable(self, name, value):
if name == '_':
# use _tmp to make the generated tir more readable
name = "_tmp"
if isinstance(value, tir.meta_var):
return value.value
elif isinstance(value, tir.frame.IRBuilderFrame):
if isinstance(value, tir.frame.ForFrame):
logger.warning(
'Binding a for frame to variable may cause undefined behavior in tilelang.',
stack_info=True,
stacklevel=2,
)
return self.enter_frame(value)
elif isinstance(value, (Buffer, tir.IterVar, tir.Var)):
IRBuilder.name(name, value)
return value
elif isinstance(value, (tuple, list, tvm.ffi.Array)):
return value
else:
try:
value = tvm.runtime.convert(value)
except TypeError:
return value
frame = tir.LetStmt(value)
var = frame.var
IRBuilder.name(name, var)
return self.enter_frame(frame)
def assign_slice(self, lval: Any, sl: slice, value: Any, annot=BaseBuilder.empty):
self.check_continue_break()
if annot is not self.empty:
logger.warning(
"Type annotation in slice assignment has no effect", stack_info=True, stacklevel=2)
if isinstance(lval, Buffer):
tir.buffer_store(lval, value, sl)
else:
return super().assign_slice(lval, sl, value)
def aug_assign(self, op, target, aug_value):
self.check_continue_break()
if is_var(target):
tir.buffer_store(target, eval_op(op, target[0], aug_value), 0)
return target
elif isinstance(target, Buffer):
raise RuntimeError("Augmented assignment is not supported for Buffer")
else:
return super().aug_assign(op, target, aug_value)
def aug_assign_slice(self, op, target, sl, aug_value):
self.check_continue_break()
if isinstance(target, Buffer):
tir.buffer_store(target, eval_op(op, target[sl], aug_value), sl)
else:
return super().aug_assign_slice(op, target, sl, aug_value)
def boolop(self, op, left, right):
left = unwrap_cond(left)
if isinstance(left, PrimExpr):
with self.with_frame(BoolOpFrame()):
if op == 'And':
return tir.And(left, right())
if op == 'Or':
return tir.Or(left, right())
raise RuntimeError(f"Unsupported boolean operator: {op}")
else:
return super().boolop(op, left, right)
def ifexp(self, cond, then, otherwise):
cond = unwrap_cond(cond)
if isinstance(cond, PrimExpr):
with self.with_frame(BoolOpFrame()):
return tir.if_then_else(cond, then(), otherwise())
else:
return super().ifexp(cond, then, otherwise)
def ret(self, value):
self.check_continue_break()
# handle return T.alloc_var()
value = self.unwrap_value(value)
last_macro = self.find_frame_idx(MacroFrame)
if last_macro is not None:
frame = self.find_frame_idx(TIR_CONTROL_FRAME, start=last_macro)
if frame is not None:
raise NotImplementedError(
"Return from control flow is not supported yet. \n"
"You should allocate a var before the control flow, assign value inside the blocks, \n"
"and return the var after the control flow. i.e.\n"
"```\n"
"@T.macro\n" \
"def my_macro(cond):\n"
" a = T.alloc_var(T.float16)\n"
" if cond:\n"
" a = 1.0\n"
" return a\n"
"```"
)
return value
def ctx_with(self, ctx):
self.check_continue_break()
if isinstance(ctx, tir.frame.IRBuilderFrame):
return self.with_frame(ctx)
else:
return super().ctx_with(ctx)
def assert_expr(self, cond, msg):
self.check_continue_break()
cond = unwrap_cond(cond)
if isinstance(cond, PrimExpr):
self.enter_frame(tir.Assert(cond, msg))
elif not cond:
raise AssertionError(msg)
def rval(self, name: str, value: Any) -> Any:
if name in self.name_inside_frame:
frame = self.name_inside_frame[name]
if frame not in self.frames:
raise RuntimeError(
f"Use immutable variable `{name}` outside its defining region, did you forget **alloc_var**?\n"
f"variable `{name}` is defined in frame: {frame}, current frames: {self.frames}."
)
return self.unwrap_value(value)
def arg(self, name, value):
if self.find_frame_idx(MacroFrame) is not None:
if isinstance(value, (PrimExpr, int, float)):
return self.bind(name, value)
else:
return value
if isinstance(value, (Buffer, Var)):
return tir.arg(name, value)
elif value is self.empty:
raise ValueError(f'Argument `{name}` is not annotated')
# elif isinstance(value, Hashable):
# return value
else:
raise TypeError(
f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.")
def override(self, name: str):
from tilelang.language import serial
if name == 'range':
return serial
raise ValueError(f'Unknown override: {name}')
_P = ParamSpec('_P')
_T = TypeVar('_T')
if TYPE_CHECKING:
class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc):
params: list[tvm.tir.Var | tvm.tir.Buffer]
body: tvm.tir.Stmt
ret_type: tvm.ir.Type
buffer_map: Map[tvm.tir.Var, tvm.tir.Buffer]
attrs: tvm.Attrs | None
span: Span | None
ir_gen: IRGenerator[_P, _T] | None
source: str | None
orig_func: Callable[_P, _T] | None
else:
PrimFunc = tvm.tir.PrimFunc
@dataclass
class Macro(Generic[_P, _T]):
name: str
orig_func: Callable[_P, _T]
ir_gen: IRGenerator[_P, _T]
@property
def source(self) -> str:
return self.ir_gen.source
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
builder = Builder.current()
with builder.macro(self.name):
res = self.ir_gen.gen(builder)(*args, **kwargs)
return res
def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]:
"""
Decorator that converts a Python function into a TileLang macro.
TileLang macro is very similar to PrimFunc, it can be used in prim_func or another macro.
Parameters
----------
func : Callable[_P, _T]
The Python function to be converted into a macro. This function will be analyzed
and transformed into an IR generation function. The function can take any parameters
(_P) and return any type (_T).
Returns
-------
Macro[_P, _T]
A Macro object that wraps the original function with IR generation capabilities.
The returned Macro preserves the original function's signature (parameters _P and
return type _T) while adding metaprogramming capabilities.
Example:
--------
>>> @macro
... def my_macro(x: T.int32) -> T.int32:
... return x ** 2
>>> @prim_func
... def my_func(A: T.Tensor((10,), T.int32), B: T.Tensor((10,), T.int32)):
... with T.Kernel(1) as _:
... for i in T.serial(10):
... B[i] = my_macro(A[i])
See Also
--------
Macro : The class that wraps macro functions
mutate : The function that transforms Python code into IR generators
"""
def impl(func: Callable[_P, _T]) -> Macro[_P, _T]:
return Macro(name=func.__name__, orig_func=func, ir_gen=mutate(func))
return impl(func) if func is not None else impl
from typing import _eval_type
def get_type_hints(func):
annot = getattr(func, '__annotations__', None)
if annot is None:
raise TypeError(f'Failed to get function type hints, {func} is not a function')
hints = {}
# Build eval namespaces from function globals plus captured closure variables
# This lets annotations reference symbols like `n`, `h`, or dtype vars
# defined in the outer scope of a nested function.
globalns = dict(getattr(func, '__globals__', {}))
localns = dict(globalns)
try:
freevars = getattr(func.__code__, 'co_freevars', ())
cells = getattr(func, '__closure__', ()) or ()
closure_bindings = {
name: cell.cell_contents for name, cell in zip(freevars, cells) if name not in localns
}
if closure_bindings:
localns.update(closure_bindings)
# Also update globals so ForwardRef eval sees them uniformly
globalns.update(closure_bindings)
except Exception:
# Be permissive: absence or access issues with closure shouldn't crash
pass
for name, value in annot.items():
if name == 'return':
continue
if isinstance(value, tvm.DataType):
hints[name] = value
continue
if value is None:
value = type(None)
if isinstance(value, str):
# Handle simple dtype aliases like T.float32 appearing as strings
# Evaluate directly only when it matches known dtypes
try:
_, v = value.split('.', maxsplit=1)
except ValueError:
v = value
if v in dt._all_dtypes:
try:
hints[name] = eval(value, globalns, localns)
continue
except Exception:
pass
value = ForwardRef(value, is_argument=True, is_class=False)
hints[name] = _eval_type(value, globalns=globalns, localns=localns)
return hints
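
# Hedged sketch (added for illustration, not part of the original source): string
# annotations may reference names captured from an enclosing scope, because the
# closure bindings are merged into the evaluation namespaces above.
def _get_type_hints_example():
    elem = dt.float32

    def kernel(x: 'elem'):
        return elem  # reference `elem` so it becomes a closure variable of `kernel`

    return get_type_hints(kernel)  # -> {'x': dt.float32}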
def _is_static_annot(annot: Any) -> bool:
return isinstance(annot, (dt.dtype, Buffer, Var))
def prim_func(func: Callable[_P, _T] = None,
*,
generator: bool = False) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]]:
"""
Decorator to create a primitive function (PrimFunc) for TileLang IR generation.
This decorator transforms a Python function into a TileLang primitive function by analyzing
its type annotations and generating intermediate representation (IR) code. It supports both
immediate construction (when all parameters are statically annotated) and generator mode
(for dynamic construction).
Parameters
----------
func : Callable[_P, _T], optional
The function to be decorated. Can be None when using decorator with arguments.
generator : bool, default=False
If True, returns a generator function that creates PrimFunc instances on demand.
If False, attempts to create a PrimFunc immediately using type annotations.
Returns
-------
PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]]
- If `generator=False` and all parameters are statically annotated: returns a PrimFunc instance
- If `generator=True`: returns a callable that generates PrimFunc instances when invoked
- If used without parentheses: returns the decorator implementation function
Examples
--------
Static annotation mode (immediate construction):
>>> @prim_func
... def add_kernel(A: T.Buffer((128,), T.float32),
... B: T.Buffer((128,), T.float32)):
... for i in T.grid(128):
... B[i] = A[i] + 1.0
Generator mode (dynamic construction):
>>> @prim_func(generator=True)
... def dynamic_kernel(A=T.Tensor((128,), T.float32)):
... # function body
... pass
>>> kernel_instance = dynamic_kernel()
With custom parameters:
>>> @prim_func(generator=True)
... def parameterized_kernel(size: int = 128):
... # function body using size parameter
... pass
>>> kernel = parameterized_kernel(size=256)
See Also
--------
Builder : The IR builder class used for constructing primitive functions
mutate : Function used to generate IR from the decorated function
"""
def impl(func: Callable[_P, _T]) -> PrimFunc[_P, _T] | Callable[_P, PrimFunc[_P, _T]]:
sig = inspect.signature(func)
annot = get_type_hints(func)
for k in annot:
if callable(annot[k]):
annot[k] = annot[k]()
# check whether all arguments are annotated
all_arg_annotated = all([x in annot for x in sig.parameters])
# check whether all annotations are Buffer/Var/dtype
all_annot_are_static = all([_is_static_annot(x) for x in annot.values()])
ir_gen = mutate(func)
def prim_func_generator(*args, **kwargs):
builder = Builder()
with builder.prim_func(func.__name__):
ir_gen.gen(builder)(*args, **kwargs)
res = builder.get()
res.ir_gen = ir_gen
res.source = ir_gen.source
res.orig_func = func
return res
prim_func_generator.ir_gen = ir_gen
prim_func_generator.source = ir_gen.source
prim_func_generator.orig_func = func
if generator:
return prim_func_generator
if all_arg_annotated and all_annot_are_static:
return prim_func_generator(**annot)
else:
raise ValueError(
"Some arguments are not supported or statically annotated, \n"
"please check the annotations or set generator=True to get a prim_func generator.\n"
f"Argument Annotations: {annot}\n"
"Example usage of generator:\n"
"```py\n"
"@prim_func(generator=True)\n"
"def my_func(a=T.Tensor((128,), T.float32)): ...\n"
"return my_func()\n"
"```")
return impl(func) if func is not None else impl
from tilelang import tvm
from tvm import ir
import torch
import ctypes
from typing import TYPE_CHECKING, Union
from tvm import tir
import tvm.script.ir_builder.tir._ffi_api as tb_ffi
dtype = tvm.DataType
# Python 3.9 compatibility: avoid PEP 604 unions at runtime
AnyDType = Union[ir.Type, str, type, torch.dtype, dtype]
# Base dtype conversion list; each entry is
# (py type, tvm dtype str, ctypes type, C type str, tvm ffi constructor name)
_dtype_cvt_base = [
    (None, 'handle', ctypes.c_long, 'long', None),  # use long to represent void*
(bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'),
(int, 'int32', ctypes.c_int32, 'int', 'Int32'),
(float, 'float32', ctypes.c_float, 'float', 'Float32'),
(torch.short, 'int16', ctypes.c_int16, 'short', 'Int16'),
(torch.int, 'int32', ctypes.c_int32, 'int', 'Int32'),
(torch.long, 'int64', ctypes.c_int64, 'long long', 'Int64'),
(torch.half, 'float16', None, None, 'Float16'),
(torch.float, 'float32', ctypes.c_float, 'float', 'Float32'),
(torch.double, 'float64', ctypes.c_double, 'double', 'Float64'),
    (torch.bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'),
(torch.int8, 'int8', ctypes.c_int8, 'char', 'Int8'),
(torch.int16, 'int16', ctypes.c_int16, 'short', 'Int16'),
(torch.int32, 'int32', ctypes.c_int32, 'int', 'Int32'),
(torch.int64, 'int64', ctypes.c_int64, 'long long', 'Int64'),
(torch.uint8, 'uint8', ctypes.c_uint8, 'unsigned char', 'UInt8'),
(torch.uint16, 'uint16', ctypes.c_uint16, 'unsigned short', 'UInt16'),
(torch.uint32, 'uint32', ctypes.c_uint32, 'unsigned int', 'UInt32'),
(torch.uint64, 'uint64', ctypes.c_uint64, 'unsigned long long', 'UInt64'),
(torch.float16, 'float16', None, None, 'Float16'),
(torch.float32, 'float32', ctypes.c_float, 'float', 'Float32'),
(torch.float64, 'float64', ctypes.c_double, 'double', 'Float64'),
(None, 'float8_e4m3', None, None, 'Float8E4M3'),
(torch.bfloat16, 'bfloat16', None, None, 'BFloat16'),
]
# Dynamically add fp8-related types if they exist in torch
_fp8_dtype_mappings = [
('float8_e4m3fn', 'Float8E4M3FN'),
('float8_e4m3fnuz', 'Float8E4M3FNUZ'),
('float8_e5m2', 'Float8E5M2'),
('float8_e5m2fnuz', 'Float8E5M2FNUZ'),
('float8_e8m0fnu', 'Float8E8M0FNU'),
]
_dtype_cvt = list(_dtype_cvt_base)
for torch_attr_name, tvm_name in _fp8_dtype_mappings:
if hasattr(torch, torch_attr_name):
torch_dtype = getattr(torch, torch_attr_name)
_dtype_cvt.append((torch_dtype, torch_attr_name, None, None, tvm_name))
def _create_type_mapper(sidx, didx, smapper=lambda x: x, dmapper=lambda x: x):
return {
smapper(item[sidx]): dmapper(item[didx])
for item in _dtype_cvt
if item[didx] is not None and item[sidx] is not None
}
_dtype_py2tvmstr = _create_type_mapper(0, 1)
_dtype_tvmstr2fficall = _create_type_mapper(1, 4, dmapper=lambda x: getattr(tb_ffi, x))
_dtype_tvm2py = _create_type_mapper(1, 0, lambda x: dtype(x))
_dtype_tvm2ctype = _create_type_mapper(1, 2, lambda x: dtype(x))
_dtype_tvm2cffi = _create_type_mapper(1, 3, lambda x: dtype(x))
def __dtype_eq__(self: dtype, other: AnyDType):
if isinstance(other, str):
return str.__eq__(self, other)
if other in _dtype_py2tvmstr:
return str.__eq__(self, _dtype_py2tvmstr[other])
return NotImplemented
def __dtype_ne__(self: dtype, other: AnyDType):
if isinstance(other, str):
return str.__ne__(self, other)
if other in _dtype_py2tvmstr:
return str.__ne__(self, _dtype_py2tvmstr[other])
return NotImplemented
def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var:
if self in _dtype_tvmstr2fficall:
return _dtype_tvmstr2fficall[self](expr, is_size_var)
# try to construct the ffi call
if self.startswith('uint'):
val = 'UInt' + self[4:]
elif self.startswith('int'):
val = 'Int' + self[3:]
elif self.startswith('float'):
val = 'Float' + self[5:]
elif self.startswith('bfloat'):
val = 'BFloat' + self[6:]
else:
raise TypeError(f'Invalid type {self}')
if '_' in val:
first, second = val.split('_', maxsplit=1)
val = first + second.upper()
call = getattr(tb_ffi, val, None)
if call is None:
raise TypeError(f"Convert to datatype `{self}` is not supported by tvm\n"
f"calling failed on `tvm.script.ir_builder.tir._ffi_api.{val}`")
return call(expr, is_size_var)
__orig_dtype_new = dtype.__new__
def __dtype_new__(cls, value: AnyDType) -> dtype:
if isinstance(value, str):
return __orig_dtype_new(cls, value)
elif value in _dtype_py2tvmstr:
return __orig_dtype_new(cls, _dtype_py2tvmstr[value])
else:
        expected = set(list(_dtype_py2tvmstr.keys()) + list(_dtype_tvmstr2fficall.keys()))
raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}")
dtype.__eq__ = __dtype_eq__
dtype.__req__ = __dtype_eq__
dtype.__ne__ = __dtype_ne__
dtype.__rne__ = __dtype_ne__
dtype.__call__ = __dtype_call__
dtype.__new__ = __dtype_new__
def get_tvm_dtype(value: AnyDType) -> dtype:
if isinstance(value, (dtype, ir.Type)):
return value
return dtype(value)
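
# Hedged examples (added for illustration): with the patched dunders above, tvm
# DataType values convert from and compare against torch dtypes and plain strings.
def _dtype_interop_example():
    assert get_tvm_dtype(torch.float16) == 'float16'
    assert dtype(torch.int64) == torch.long      # torch.long is an alias of int64
    assert dtype(torch.bfloat16) == 'bfloat16'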
if TYPE_CHECKING:
# yapf: disable
class bool(dtype): ...
class short(dtype): ...
class int(dtype): ...
class long(dtype): ...
class half(dtype): ...
class float(dtype): ...
class double(dtype): ...
class int8(dtype): ...
class int16(dtype): ...
class int32(dtype): ...
class int64(dtype): ...
class int8x4(dtype): ...
class int16x4(dtype): ...
class int32x4(dtype): ...
class int64x4(dtype): ...
class int8x8(dtype): ...
class int16x8(dtype): ...
class int32x8(dtype): ...
class int64x8(dtype): ...
class int8x16(dtype): ...
class int16x16(dtype): ...
class int32x16(dtype): ...
class int64x16(dtype): ...
class int8x32(dtype): ...
class int16x32(dtype): ...
class int32x32(dtype): ...
class int64x32(dtype): ...
class int8x64(dtype): ...
class int16x64(dtype): ...
class int32x64(dtype): ...
class int64x64(dtype): ...
class uint8(dtype): ...
class uint16(dtype): ...
class uint32(dtype): ...
class uint64(dtype): ...
class uint8x4(dtype): ...
class uint16x4(dtype): ...
class uint32x4(dtype): ...
class uint64x4(dtype): ...
class uint8x8(dtype): ...
class uint16x8(dtype): ...
class uint32x8(dtype): ...
class uint64x8(dtype): ...
class uint8x16(dtype): ...
class uint16x16(dtype): ...
class uint32x16(dtype): ...
class uint64x16(dtype): ...
class uint8x32(dtype): ...
class uint16x32(dtype): ...
class uint32x32(dtype): ...
class uint64x32(dtype): ...
class uint8x64(dtype): ...
class uint16x64(dtype): ...
class uint32x64(dtype): ...
class uint64x64(dtype): ...
class float16(dtype): ...
class float32(dtype): ...
class float64(dtype): ...
class float16x2(dtype): ...
class float32x2(dtype): ...
class float64x2(dtype): ...
class float16x4(dtype): ...
class float32x4(dtype): ...
class float64x4(dtype): ...
class float16x8(dtype): ...
class float32x8(dtype): ...
class float64x8(dtype): ...
class float16x16(dtype): ...
class float32x16(dtype): ...
class float64x16(dtype): ...
class float16x32(dtype): ...
class float32x32(dtype): ...
class float64x32(dtype): ...
class float16x64(dtype): ...
class float32x64(dtype): ...
class float64x64(dtype): ...
class float8_e3m4(dtype): ...
class float8_e3m4x2(dtype): ...
class float8_e3m4x4(dtype): ...
class float8_e3m4x8(dtype): ...
class float8_e3m4x16(dtype): ...
class float8_e3m4x32(dtype): ...
class float8_e3m4x64(dtype): ...
class float8_e4m3(dtype): ...
class float8_e4m3x2(dtype): ...
class float8_e4m3x4(dtype): ...
class float8_e4m3x8(dtype): ...
class float8_e4m3x16(dtype): ...
class float8_e4m3x32(dtype): ...
class float8_e4m3x64(dtype): ...
class float8_e4m3b11fnuz(dtype): ...
class float8_e4m3b11fnuzx2(dtype): ...
class float8_e4m3b11fnuzx4(dtype): ...
class float8_e4m3b11fnuzx8(dtype): ...
class float8_e4m3b11fnuzx16(dtype): ...
class float8_e4m3b11fnuzx32(dtype): ...
class float8_e4m3b11fnuzx64(dtype): ...
class float8_e4m3fn(dtype): ...
class float8_e4m3fnx2(dtype): ...
class float8_e4m3fnx4(dtype): ...
class float8_e4m3fnx8(dtype): ...
class float8_e4m3fnx16(dtype): ...
class float8_e4m3fnx32(dtype): ...
class float8_e4m3fnx64(dtype): ...
class float8_e4m3fnuz(dtype): ...
class float8_e4m3fnuzx2(dtype): ...
class float8_e4m3fnuzx4(dtype): ...
class float8_e4m3fnuzx8(dtype): ...
class float8_e4m3fnuzx16(dtype): ...
class float8_e4m3fnuzx32(dtype): ...
class float8_e4m3fnuzx64(dtype): ...
class float8_e5m2(dtype): ...
class float8_e5m2x2(dtype): ...
class float8_e5m2x4(dtype): ...
class float8_e5m2x8(dtype): ...
class float8_e5m2x16(dtype): ...
class float8_e5m2x32(dtype): ...
class float8_e5m2x64(dtype): ...
class float8_e5m2fnuz(dtype): ...
class float8_e5m2fnuzx2(dtype): ...
class float8_e5m2fnuzx4(dtype): ...
class float8_e5m2fnuzx8(dtype): ...
class float8_e5m2fnuzx16(dtype): ...
class float8_e5m2fnuzx32(dtype): ...
class float8_e5m2fnuzx64(dtype): ...
class float8_e8m0fnu(dtype): ...
class float8_e8m0fnux2(dtype): ...
class float8_e8m0fnux4(dtype): ...
class float8_e8m0fnux8(dtype): ...
class float8_e8m0fnux16(dtype): ...
class float8_e8m0fnux32(dtype): ...
class float8_e8m0fnux64(dtype): ...
class float6_e2m3fn(dtype): ...
class float6_e2m3fnx2(dtype): ...
class float6_e2m3fnx4(dtype): ...
class float6_e2m3fnx8(dtype): ...
class float6_e2m3fnx16(dtype): ...
class float6_e2m3fnx32(dtype): ...
class float6_e2m3fnx64(dtype): ...
class float6_e3m2fn(dtype): ...
class float6_e3m2fnx2(dtype): ...
class float6_e3m2fnx4(dtype): ...
class float6_e3m2fnx8(dtype): ...
class float6_e3m2fnx16(dtype): ...
class float6_e3m2fnx32(dtype): ...
class float6_e3m2fnx64(dtype): ...
class float4_e2m1fn(dtype): ...
class float4_e2m1fnx2(dtype): ...
class float4_e2m1fnx4(dtype): ...
class float4_e2m1fnx8(dtype): ...
class float4_e2m1fnx16(dtype): ...
class float4_e2m1fnx32(dtype): ...
class float4_e2m1fnx64(dtype): ...
class bfloat16(dtype): ...
# yapf: enable
else:
bool = dtype('bool')
short = dtype('int16')
int = dtype('int32')
long = dtype('int64')
half = dtype('float16')
float = dtype('float32')
double = dtype('float64')
int8 = dtype('int8')
int16 = dtype('int16')
int32 = dtype('int32')
int64 = dtype('int64')
int8x4 = dtype('int8x4')
int16x4 = dtype('int16x4')
int32x4 = dtype('int32x4')
int64x4 = dtype('int64x4')
int8x8 = dtype('int8x8')
int16x8 = dtype('int16x8')
int32x8 = dtype('int32x8')
int64x8 = dtype('int64x8')
int8x16 = dtype('int8x16')
int16x16 = dtype('int16x16')
int32x16 = dtype('int32x16')
int64x16 = dtype('int64x16')
int8x32 = dtype('int8x32')
int16x32 = dtype('int16x32')
int32x32 = dtype('int32x32')
int64x32 = dtype('int64x32')
int8x64 = dtype('int8x64')
int16x64 = dtype('int16x64')
int32x64 = dtype('int32x64')
int64x64 = dtype('int64x64')
uint8 = dtype('uint8')
uint16 = dtype('uint16')
uint32 = dtype('uint32')
uint64 = dtype('uint64')
uint8x4 = dtype('uint8x4')
uint16x4 = dtype('uint16x4')
uint32x4 = dtype('uint32x4')
uint64x4 = dtype('uint64x4')
uint8x8 = dtype('uint8x8')
uint16x8 = dtype('uint16x8')
uint32x8 = dtype('uint32x8')
uint64x8 = dtype('uint64x8')
uint8x16 = dtype('uint8x16')
uint16x16 = dtype('uint16x16')
uint32x16 = dtype('uint32x16')
uint64x16 = dtype('uint64x16')
uint8x32 = dtype('uint8x32')
uint16x32 = dtype('uint16x32')
uint32x32 = dtype('uint32x32')
uint64x32 = dtype('uint64x32')
uint8x64 = dtype('uint8x64')
uint16x64 = dtype('uint16x64')
uint32x64 = dtype('uint32x64')
uint64x64 = dtype('uint64x64')
float16 = dtype('float16')
float32 = dtype('float32')
float64 = dtype('float64')
float16x2 = dtype('float16x2')
float32x2 = dtype('float32x2')
float64x2 = dtype('float64x2')
float16x4 = dtype('float16x4')
float32x4 = dtype('float32x4')
float64x4 = dtype('float64x4')
float16x8 = dtype('float16x8')
float32x8 = dtype('float32x8')
float64x8 = dtype('float64x8')
float16x16 = dtype('float16x16')
float32x16 = dtype('float32x16')
float64x16 = dtype('float64x16')
float16x32 = dtype('float16x32')
float32x32 = dtype('float32x32')
float64x32 = dtype('float64x32')
float16x64 = dtype('float16x64')
float32x64 = dtype('float32x64')
float64x64 = dtype('float64x64')
float8_e3m4 = dtype('float8_e3m4')
float8_e3m4x2 = dtype('float8_e3m4x2')
float8_e3m4x4 = dtype('float8_e3m4x4')
float8_e3m4x8 = dtype('float8_e3m4x8')
float8_e3m4x16 = dtype('float8_e3m4x16')
float8_e3m4x32 = dtype('float8_e3m4x32')
float8_e3m4x64 = dtype('float8_e3m4x64')
float8_e4m3 = dtype('float8_e4m3')
float8_e4m3x2 = dtype('float8_e4m3x2')
float8_e4m3x4 = dtype('float8_e4m3x4')
float8_e4m3x8 = dtype('float8_e4m3x8')
float8_e4m3x16 = dtype('float8_e4m3x16')
float8_e4m3x32 = dtype('float8_e4m3x32')
float8_e4m3x64 = dtype('float8_e4m3x64')
float8_e4m3b11fnuz = dtype('float8_e4m3b11fnuz')
float8_e4m3b11fnuzx2 = dtype('float8_e4m3b11fnuzx2')
float8_e4m3b11fnuzx4 = dtype('float8_e4m3b11fnuzx4')
float8_e4m3b11fnuzx8 = dtype('float8_e4m3b11fnuzx8')
float8_e4m3b11fnuzx16 = dtype('float8_e4m3b11fnuzx16')
float8_e4m3b11fnuzx32 = dtype('float8_e4m3b11fnuzx32')
float8_e4m3b11fnuzx64 = dtype('float8_e4m3b11fnuzx64')
float8_e4m3fn = dtype('float8_e4m3fn')
float8_e4m3fnx2 = dtype('float8_e4m3fnx2')
float8_e4m3fnx4 = dtype('float8_e4m3fnx4')
float8_e4m3fnx8 = dtype('float8_e4m3fnx8')
float8_e4m3fnx16 = dtype('float8_e4m3fnx16')
float8_e4m3fnx32 = dtype('float8_e4m3fnx32')
float8_e4m3fnx64 = dtype('float8_e4m3fnx64')
float8_e4m3fnuz = dtype('float8_e4m3fnuz')
float8_e4m3fnuzx2 = dtype('float8_e4m3fnuzx2')
float8_e4m3fnuzx4 = dtype('float8_e4m3fnuzx4')
float8_e4m3fnuzx8 = dtype('float8_e4m3fnuzx8')
float8_e4m3fnuzx16 = dtype('float8_e4m3fnuzx16')
float8_e4m3fnuzx32 = dtype('float8_e4m3fnuzx32')
float8_e4m3fnuzx64 = dtype('float8_e4m3fnuzx64')
float8_e5m2 = dtype('float8_e5m2')
float8_e5m2x2 = dtype('float8_e5m2x2')
float8_e5m2x4 = dtype('float8_e5m2x4')
float8_e5m2x8 = dtype('float8_e5m2x8')
float8_e5m2x16 = dtype('float8_e5m2x16')
float8_e5m2x32 = dtype('float8_e5m2x32')
float8_e5m2x64 = dtype('float8_e5m2x64')
float8_e5m2fnuz = dtype('float8_e5m2fnuz')
float8_e5m2fnuzx2 = dtype('float8_e5m2fnuzx2')
float8_e5m2fnuzx4 = dtype('float8_e5m2fnuzx4')
float8_e5m2fnuzx8 = dtype('float8_e5m2fnuzx8')
float8_e5m2fnuzx16 = dtype('float8_e5m2fnuzx16')
float8_e5m2fnuzx32 = dtype('float8_e5m2fnuzx32')
float8_e5m2fnuzx64 = dtype('float8_e5m2fnuzx64')
float8_e8m0fnu = dtype('float8_e8m0fnu')
float8_e8m0fnux2 = dtype('float8_e8m0fnux2')
float8_e8m0fnux4 = dtype('float8_e8m0fnux4')
float8_e8m0fnux8 = dtype('float8_e8m0fnux8')
float8_e8m0fnux16 = dtype('float8_e8m0fnux16')
float8_e8m0fnux32 = dtype('float8_e8m0fnux32')
float8_e8m0fnux64 = dtype('float8_e8m0fnux64')
float6_e2m3fn = dtype('float6_e2m3fn')
float6_e2m3fnx2 = dtype('float6_e2m3fnx2')
float6_e2m3fnx4 = dtype('float6_e2m3fnx4')
float6_e2m3fnx8 = dtype('float6_e2m3fnx8')
float6_e2m3fnx16 = dtype('float6_e2m3fnx16')
float6_e2m3fnx32 = dtype('float6_e2m3fnx32')
float6_e2m3fnx64 = dtype('float6_e2m3fnx64')
float6_e3m2fn = dtype('float6_e3m2fn')
float6_e3m2fnx2 = dtype('float6_e3m2fnx2')
float6_e3m2fnx4 = dtype('float6_e3m2fnx4')
float6_e3m2fnx8 = dtype('float6_e3m2fnx8')
float6_e3m2fnx16 = dtype('float6_e3m2fnx16')
float6_e3m2fnx32 = dtype('float6_e3m2fnx32')
float6_e3m2fnx64 = dtype('float6_e3m2fnx64')
float4_e2m1fn = dtype('float4_e2m1fn')
float4_e2m1fnx2 = dtype('float4_e2m1fnx2')
float4_e2m1fnx4 = dtype('float4_e2m1fnx4')
float4_e2m1fnx8 = dtype('float4_e2m1fnx8')
float4_e2m1fnx16 = dtype('float4_e2m1fnx16')
float4_e2m1fnx32 = dtype('float4_e2m1fnx32')
float4_e2m1fnx64 = dtype('float4_e2m1fnx64')
bfloat16 = dtype('bfloat16')
_all_dtypes = {
'bool',
'short',
'int',
'long',
'half',
'float',
'double',
'int8',
'int16',
'int32',
'int64',
'int8x4',
'int16x4',
'int32x4',
'int64x4',
'int8x8',
'int16x8',
'int32x8',
'int64x8',
'int8x16',
'int16x16',
'int32x16',
'int64x16',
'int8x32',
'int16x32',
'int32x32',
'int64x32',
'int8x64',
'int16x64',
'int32x64',
'int64x64',
'uint8',
'uint16',
'uint32',
'uint64',
'uint8x4',
'uint16x4',
'uint32x4',
'uint64x4',
'uint8x8',
'uint16x8',
'uint32x8',
'uint64x8',
'uint8x16',
'uint16x16',
'uint32x16',
'uint64x16',
'uint8x32',
'uint16x32',
'uint32x32',
'uint64x32',
'uint8x64',
'uint16x64',
'uint32x64',
'uint64x64',
'float16',
'float32',
'float64',
'float16x2',
'float32x2',
'float64x2',
'float16x4',
'float32x4',
'float64x4',
'float16x8',
'float32x8',
'float64x8',
'float16x16',
'float32x16',
'float64x16',
'float16x32',
'float32x32',
'float64x32',
'float16x64',
'float32x64',
'float64x64',
'float8_e3m4',
'float8_e3m4x2',
'float8_e3m4x4',
'float8_e3m4x8',
'float8_e3m4x16',
'float8_e3m4x32',
'float8_e3m4x64',
'float8_e4m3',
'float8_e4m3x2',
'float8_e4m3x4',
'float8_e4m3x8',
'float8_e4m3x16',
'float8_e4m3x32',
'float8_e4m3x64',
'float8_e4m3b11fnuz',
'float8_e4m3b11fnuzx2',
'float8_e4m3b11fnuzx4',
'float8_e4m3b11fnuzx8',
'float8_e4m3b11fnuzx16',
'float8_e4m3b11fnuzx32',
'float8_e4m3b11fnuzx64',
'float8_e4m3fn',
'float8_e4m3fnx2',
'float8_e4m3fnx4',
'float8_e4m3fnx8',
'float8_e4m3fnx16',
'float8_e4m3fnx32',
'float8_e4m3fnx64',
'float8_e4m3fnuz',
'float8_e4m3fnuzx2',
'float8_e4m3fnuzx4',
'float8_e4m3fnuzx8',
'float8_e4m3fnuzx16',
'float8_e4m3fnuzx32',
'float8_e4m3fnuzx64',
'float8_e5m2',
'float8_e5m2x2',
'float8_e5m2x4',
'float8_e5m2x8',
'float8_e5m2x16',
'float8_e5m2x32',
'float8_e5m2x64',
'float8_e5m2fnuz',
'float8_e5m2fnuzx2',
'float8_e5m2fnuzx4',
'float8_e5m2fnuzx8',
'float8_e5m2fnuzx16',
'float8_e5m2fnuzx32',
'float8_e5m2fnuzx64',
'float8_e8m0fnu',
'float8_e8m0fnux2',
'float8_e8m0fnux4',
'float8_e8m0fnux8',
'float8_e8m0fnux16',
'float8_e8m0fnux32',
'float8_e8m0fnux64',
'float6_e2m3fn',
'float6_e2m3fnx2',
'float6_e2m3fnx4',
'float6_e2m3fnx8',
'float6_e2m3fnx16',
'float6_e2m3fnx32',
'float6_e2m3fnx64',
'float6_e3m2fn',
'float6_e3m2fnx2',
'float6_e3m2fnx4',
'float6_e3m2fnx8',
'float6_e3m2fnx16',
'float6_e3m2fnx32',
'float6_e3m2fnx64',
'float4_e2m1fn',
'float4_e2m1fnx2',
'float4_e2m1fnx4',
'float4_e2m1fnx8',
'float4_e2m1fnx16',
'float4_e2m1fnx32',
'float4_e2m1fnx64',
'bfloat16',
}
__all__ = list(_all_dtypes) + [
'dtype',
'AnyDType',
'get_tvm_dtype',
]
from __future__ import annotations
import ast
import inspect
from typing import Any, Callable, Literal
from tilelang import env
from hashlib import sha256
import linecache
def disk_compile(source, name):
    cache_dir = env.TILELANG_CACHE_DIR
    if cache_dir is not None:
        import os
        save_dir = os.path.join(cache_dir, "py-cache")
        os.makedirs(save_dir, exist_ok=True)
        hash_sfx = sha256(source.encode('utf-8')).hexdigest()[:8]
        path = os.path.join(save_dir, f"{name}.{hash_sfx}.py")
        with open(path, 'w') as f:
            f.write(source)
        # Register the generated source with linecache so tracebacks can display it.
        linecache.cache[path] = (len(source), None, source.splitlines(), path)
        return compile(source, path, "exec")
    # Caching disabled: fall back to an in-memory compile with a synthetic filename.
    return compile(source, f"<{name}>", "exec")
def _remove_leading_ident(source: str):
lines = source.splitlines()
if not lines:
return source
ident_size = len(lines[0]) - len(lines[0].lstrip())
return "\n".join([line[ident_size:] if len(line) >= ident_size else line for line in lines])
def get_func_nonlocals(func):
"""A modified version of `inspect.getclosurevars`"""
if inspect.ismethod(func):
func = func.__func__
if not inspect.isfunction(func):
raise TypeError(f"{func!r} is not a Python function")
code = func.__code__
# Nonlocal references are named in co_freevars and resolved
# by looking them up in __closure__ by positional index
nonlocal_vars = {}
if func.__closure__ is not None:
for var, cell in zip(code.co_freevars, func.__closure__):
try:
nonlocal_vars[var] = cell.cell_contents
except ValueError as err:
# cell_contents may raise ValueError if the cell is empty.
if "empty" not in str(err):
raise
return nonlocal_vars
def inspect_function_capture(func: Callable) -> dict[str, Any]:
"""Capture function non-locals and global variables.
Parameters
----------
func : Callable
The function to inspect.
Returns
-------
res : Dict[str, Any]
The function variables map with non-local or global variables.
"""
captured = {
**func.__globals__, # type: ignore
**get_func_nonlocals(func),
}
return captured
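# Minimal usage sketch (illustrative names only, not part of the API): inspect_function_capture
# merges a function's globals with its closure cells, so symbols referenced by a kernel defined
# inside another function can still be resolved later.
def _example_capture():
    scale = 3  # becomes a nonlocal (closure cell) of the inner function

    def inner(x):
        return x * scale

    captured = inspect_function_capture(inner)
    # The closure variable is visible alongside module-level globals.
    assert captured["scale"] == 3
    return captured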
def get_ast(func: Callable):
_, start = inspect.getsourcelines(func)
filename = inspect.getsourcefile(func) or inspect.getfile(func)
source = inspect.getsource(func)
source = _remove_leading_ident(source)
source = '\n' * (start - 1) + source
tree = ast.parse(source, filename=filename)
return tree
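# Hedged sketch (helper name is hypothetical): get_ast pads the dedented source with blank
# lines so that node line numbers in the returned AST match the function's position in its
# original file, keeping error messages and linecache lookups pointed at real lines.
def _example_get_ast():
    def sample(a, b):
        return a + b

    tree = get_ast(sample)
    func_def = tree.body[-1]
    # The FunctionDef's lineno corresponds to `def sample(...)` in this file,
    # not to line 1 of a freshly parsed string.
    return func_def.lineno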
CompileMethod = Literal['direct', 'disk']
def get_compiled_object(source: str | ast.AST,
name: str,
                        filename: str | None = None,
                        globals: dict[str, Any] | None = None):
if isinstance(source, ast.AST):
assert filename is not None, "filename must be provided when source is an AST"
try:
if isinstance(source, ast.AST):
ast.fix_missing_locations(source)
compiled = compile(source, filename, 'exec')
else:
compiled = disk_compile(source, name)
except Exception as e:
source_str = source if isinstance(source, str) else ast.unparse(source)
        raise RuntimeError(f'Failed to compile source for {name}: {e}\n{source_str}') from e
locs = {}
exec(compiled, globals, locs)
return locs[name]
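# Usage sketch (names are illustrative, not part of the public API): compile a generated
# function from plain source text and retrieve the resulting object. With a string source the
# code goes through disk_compile, so the text is also cached on disk (when TILELANG_CACHE_DIR
# is set) and visible in tracebacks.
def _example_get_compiled_object():
    generated_src = (
        "def add_one(x):\n"
        "    return x + 1\n"
    )
    add_one = get_compiled_object(generated_src, "add_one", globals=globals())
    assert add_one(41) == 42
    return add_one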
......@@ -5,7 +5,9 @@ from .layout import Layout # noqa: F401
from .fragment import Fragment # noqa: F401
from .swizzle import (
make_swizzled_layout, # noqa: F401
make_volta_swizzled_layout, # noqa: F401
make_wgmma_swizzled_layout, # noqa: F401
make_tcgen05mma_swizzled_layout, # noqa: F401
make_full_bank_swizzled_layout, # noqa: F401
make_half_bank_swizzled_layout, # noqa: F401
make_quarter_bank_swizzled_layout, # noqa: F401
......
......@@ -3,13 +3,14 @@
from __future__ import annotations
import tvm
import tvm_ffi
from tvm.ir import Range
from tvm.tir import IterVar, Var, PrimExpr, IndexMap
from tilelang import _ffi_api
from tilelang.layout import Layout
@tvm.ffi.register_object("tl.Fragment")
@tvm_ffi.register_object("tl.Fragment")
class Fragment(Layout):
"""
A Fragment layout object that encapsulates iteration variables (forward_vars),
......
......@@ -2,14 +2,14 @@
# pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations
import tvm
import tvm_ffi
from tvm.ir import Node, Range
from tvm.tir import IterVar, Var, PrimExpr, IndexMap
from tilelang import _ffi_api
# Register the Layout class as a TVM object under the name "tl.Layout"
@tvm.ffi.register_object("tl.Layout")
@tvm_ffi.register_object("tl.Layout")
class Layout(Node):
def __init__(self, shape, forward_fn):
......
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations
import tvm
from tvm.tir import Buffer, BufferLoad, BufferRegion
from tilelang import _ffi_api
def _get_buffer_info(
buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion
) -> tuple[Buffer, list[int], str]:
"""
Extract buffer, shape, and dtype from Buffer, BufferLoad, or BufferRegion.
Args:
buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion
Returns:
tuple: (buffer, shape, dtype)
"""
if isinstance(buffer_or_load_or_region, Buffer):
return buffer_or_load_or_region, buffer_or_load_or_region.shape, buffer_or_load_or_region.dtype
elif isinstance(buffer_or_load_or_region, (BufferLoad, BufferRegion)):
buf = buffer_or_load_or_region.buffer
return buf, buf.shape, buf.dtype
else:
raise TypeError(
f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}")
def _get_stride_continuous(
buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[int, int]:
"""
    Get stride (second-to-last dimension) and continuous (innermost dimension) from Buffer, BufferLoad, or BufferRegion.
Args:
buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion
Returns:
tuple: (stride, continuous) as integers
"""
_, shape, _ = _get_buffer_info(buffer_or_load_or_region)
stride = int(shape[-2])
continuous = int(shape[-1])
return stride, continuous
def _get_element_size(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> int:
"""
Get element size in bits from Buffer, BufferLoad, or BufferRegion.
Args:
buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion
Returns:
int: Element size in bits
"""
_, _, dtype = _get_buffer_info(buffer_or_load_or_region)
return int(tvm.DataType(dtype).bits)
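# Illustrative sketch (the buffer name is hypothetical): for a 2D shared-memory tile, the
# helpers above reduce a Buffer / BufferLoad / BufferRegion to the three integers the swizzle
# constructors need: rows (stride), innermost extent (continuous), and element width in bits.
def _example_buffer_info():
    buf = tvm.tir.decl_buffer((128, 64), "float16", name="A_shared")
    stride, continuous = _get_stride_continuous(buf)  # (128, 64)
    element_size = _get_element_size(buf)             # 16 bits for float16
    return stride, continuous, element_size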
# Use a stable swizzled layout to ensure consistent memory access patterns.
# Swizzling should be enabled or disabled depending on whether TMA (Tensor Memory Accelerator) copies are used.
def make_swizzled_layout(buffer: tvm.tir.Buffer, k_major: bool = True, allow_pad: bool = True):
assert len(buffer.shape) == 2
def make_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
k_major: bool = True,
allow_pad: bool = True):
stride, continuous = _get_stride_continuous(buffer)
element_size = _get_element_size(buffer)
return _ffi_api.make_swizzled_layout(
int(buffer.shape[0]),
int(buffer.shape[1]),
int(tvm.DataType(buffer.dtype).bits),
stride,
continuous,
element_size,
k_major,
allow_pad,
)
# for Volta Intrinsics
def make_volta_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
is_a: bool = True,
k_inner: bool = True):
stride, continuous = _get_stride_continuous(buffer)
return _ffi_api.make_volta_swizzled_layout(
stride,
continuous,
is_a,
k_inner,
)
# for WGMMA Intrinsics
def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
continuity: int = None,
k_major: bool = True):
assert len(buffer.shape) == 2
stride, continuous = _get_stride_continuous(buffer)
element_size = _get_element_size(buffer)
if continuity is None:
continuity = int(buffer.shape[1])
continuity = continuous
return _ffi_api.make_wgmma_swizzled_layout(
int(buffer.shape[0]),
int(buffer.shape[1]),
stride,
continuous,
continuity,
int(tvm.DataType(buffer.dtype).bits),
element_size,
k_major,
)
# for TCGEN05MMA Intrinsics
def make_tcgen05mma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
continuity: int = None,
k_major: bool = True):
stride, continuous = _get_stride_continuous(buffer)
element_size = _get_element_size(buffer)
if continuity is None:
continuity = continuous
return _ffi_api.make_tcgen05mma_swizzled_layout(
stride,
continuous,
continuity,
element_size,
k_major,
)
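# Hedged usage sketch (the buffer is hypothetical): each constructor above accepts a
# Buffer / BufferLoad / BufferRegion and internally reduces it to
# (stride, continuous, element_size) via the helpers defined earlier.
def _example_swizzled_layouts():
    a_shared = tvm.tir.decl_buffer((128, 64), "float16", name="A_shared")
    generic = make_swizzled_layout(a_shared)                            # general swizzle, padding allowed
    wgmma = make_wgmma_swizzled_layout(a_shared, k_major=True)          # Hopper WGMMA operand layout
    tcgen05 = make_tcgen05mma_swizzled_layout(a_shared, k_major=True)   # Blackwell TCGEN05 MMA layout
    return generic, wgmma, tcgen05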
......@@ -39,15 +128,14 @@ def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
def make_full_bank_swizzled_layout(*args):
"""
Args:
args: buffer or (stride, continuous, element_size)
args: buffer/BufferLoad/BufferRegion or (stride, continuous, element_size)
Examples:
make_full_bank_swizzled_layout(buffer)
make_full_bank_swizzled_layout(stride, continuous, element_size)
"""
if len(args) == 1:
buffer = args[0]
stride, continuous = int(buffer.shape[0]), int(buffer.shape[1])
element_size = int(tvm.DataType(buffer.dtype).bits)
stride, continuous = _get_stride_continuous(args[0])
element_size = _get_element_size(args[0])
elif len(args) == 3:
stride, continuous, element_size = args
else:
......@@ -64,15 +152,14 @@ def make_full_bank_swizzled_layout(*args):
def make_half_bank_swizzled_layout(*args):
"""
Args:
args: buffer or (stride, continuous, element_size)
args: buffer/BufferLoad/BufferRegion or (stride, continuous, element_size)
Examples:
make_half_bank_swizzled_layout(buffer)
make_half_bank_swizzled_layout(stride, continuous, element_size)
"""
if len(args) == 1:
buffer = args[0]
stride, continuous = int(buffer.shape[0]), int(buffer.shape[1])
element_size = int(tvm.DataType(buffer.dtype).bits)
stride, continuous = _get_stride_continuous(args[0])
element_size = _get_element_size(args[0])
elif len(args) == 3:
stride, continuous, element_size = args
else:
......@@ -89,15 +176,14 @@ def make_half_bank_swizzled_layout(*args):
def make_quarter_bank_swizzled_layout(*args):
"""
Args:
args: buffer or (stride, continuous, element_size)
args: buffer/BufferLoad/BufferRegion or (stride, continuous, element_size)
Examples:
make_quarter_bank_swizzled_layout(buffer)
make_quarter_bank_swizzled_layout(stride, continuous, element_size)
"""
if len(args) == 1:
buffer = args[0]
stride, continuous = int(buffer.shape[0]), int(buffer.shape[1])
element_size = int(tvm.DataType(buffer.dtype).bits)
stride, continuous = _get_stride_continuous(args[0])
element_size = _get_element_size(args[0])
elif len(args) == 3:
stride, continuous, element_size = args
else:
......@@ -112,14 +198,13 @@ def make_quarter_bank_swizzled_layout(*args):
def make_linear_layout(*args):
"""
Args:
args: buffer or (stride, continuous)
args: buffer/BufferLoad/BufferRegion or (stride, continuous)
Examples:
make_linear_layout(buffer)
make_linear_layout(stride, continuous)
"""
if len(args) == 1:
buffer = args[0]
stride, continuous = int(buffer.shape[0]), int(buffer.shape[1])
stride, continuous = _get_stride_continuous(args[0])
elif len(args) == 2:
stride, continuous = args
else:
......
......@@ -4,20 +4,24 @@ from tvm import tir
from tvm.target import Target
from tvm.ir.base import Node
from tvm.runtime import Scriptable
import tvm.ffi
from tilelang.ir import GemmWarpPolicy
import tvm_ffi
from tilelang.ir import GemmWarpPolicy as GemmWarpPolicy
from .gemm_mma import GemmMMA
from .gemm_mma_sm70 import GemmMMASm70
from .gemm_wgmma import GemmWGMMA
from .gemm_tcgen05 import GemmTCGEN5
from .gemm_mfma import GemmMFMA
from tilelang import _ffi_api
from tilelang.utils.target import target_is_volta
@tvm.ffi.register_func("tl.gemm_py.infer_layout")
@tvm_ffi.register_global_func("tl.gemm_py.infer_layout")
def gemm_py_infer_layout(gemm_py, target, thread_bounds):
thread_nums = thread_bounds.extent
return gemm_py.infer_layout(target, thread_nums)
@tvm.ffi.register_func("tl.gemm_py.lower")
@tvm_ffi.register_global_func("tl.gemm_py.lower")
def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var):
thread_nums = thread_bounds.extent
stmt = gemm_py.lower(layout_map, target, thread_nums, thread_var)
......@@ -28,55 +32,117 @@ def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var):
# same definition as in src/op/gemm_py.h
class GemmInst(IntEnum):
MMA = 0
WGMMMA = 1
MFMA = 2
WGMMA = 1
TCGEN5MMA = 2
MFMA = 3
def is_mma(self) -> bool:
return self == GemmInst.MMA
def is_wgmma(self) -> bool:
return self == GemmInst.WGMMMA
return self == GemmInst.WGMMA
def is_tcgen5mma(self) -> bool:
return self == GemmInst.TCGEN5MMA
def is_mfma(self) -> bool:
return self == GemmInst.MFMA
def __repr__(self) -> str:
return self.name
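# Small sketch of the updated numbering (MMA=0, WGMMA=1, TCGEN5MMA=2, MFMA=3), which must stay
# in sync with src/op/gemm_py.h; in practice the instruction is the integer returned by the
# C++ side through _ffi_api.GemmPyGemmInst.
def _example_gemm_inst():
    inst = GemmInst(2)
    assert inst.is_tcgen5mma() and repr(inst) == "TCGEN5MMA"
    return inst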
@tvm.ffi.register_object("tl.GemmPy")
@tvm_ffi.register_object("tl.GemmPy")
class GemmPy(Node, Scriptable):
A: tir.Buffer
B: tir.Buffer
C: tir.Buffer
APtr: tir.PrimExpr
BPtr: tir.PrimExpr
CPtr: tir.PrimExpr
M: int
N: int
K: int
trans_A: bool
trans_B: bool
stride_A: int
stride_B: int
offset_A: int
offset_B: int
clear_accum: bool
k_pack: int
wg_wait: int
policy: GemmWarpPolicy
# FFI fields (LLVM/MLIR-style lowerCamel via reflection):
# a, b, c, aPtr, bPtr, cPtr, m, n, k, transA, transB,
# strideA, strideB, offsetA, offsetB, clearAccum, kPack, wgWait, policy
#
# Backward-compat alias properties are provided below to support old names.
# Backward-compat alias properties (old API → new FFI fields)
@property
def A(self):
return self.a
@property
def B(self):
return self.b
@property
def C(self):
return self.c
@property
def APtr(self):
return self.aPtr
@property
def BPtr(self):
return self.bPtr
@property
def CPtr(self):
return self.cPtr
@property
def M(self):
return self.m
@property
def N(self):
return self.n
@property
def K(self):
return self.k
@property
def trans_A(self):
return self.transA
@property
def trans_B(self):
return self.transB
@property
def stride_A(self):
return self.strideA
@property
def stride_B(self):
return self.strideB
@property
def offset_A(self):
return self.offsetA
@property
def offset_B(self):
return self.offsetB
@property
def clear_accum(self):
return self.clearAccum
@property
def k_pack(self):
return self.kPack
@property
def wg_wait(self):
return self.wgWait
def infer_layout(self, target: Target, thread_nums: int):
"""Infer the layout for the GEMM operation based on target architecture."""
gemm_inst = self._select_gemm_instruction(thread_nums, target)
impl_class = self._get_implementation_class(gemm_inst)
impl_class = self._get_implementation_class(gemm_inst, target)
return impl_class(self).infer_layout(target, thread_nums)
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
"""Lower the GEMM operation to TIR statements based on target architecture."""
gemm_inst = self._select_gemm_instruction(thread_nums, target)
impl_class = self._get_implementation_class(gemm_inst)
impl_class = self._get_implementation_class(gemm_inst, target)
return impl_class(self).lower(layout_map, target, thread_nums, thread_var)
def _select_gemm_instruction(self, thread_nums: int, target: Target) -> GemmInst:
......@@ -97,7 +163,7 @@ class GemmPy(Node, Scriptable):
"""
return GemmInst(_ffi_api.GemmPyGemmInst(self, int(thread_nums), target))
def _get_implementation_class(self, gemm_inst: GemmInst):
def _get_implementation_class(self, gemm_inst: GemmInst, target: Target):
"""Get the appropriate implementation class for the given GEMM instruction.
Args:
......@@ -111,10 +177,16 @@ class GemmPy(Node, Scriptable):
ValueError: If the instruction type is unknown
"""
if gemm_inst.is_mma():
if target_is_volta(target):
return GemmMMASm70
return GemmMMA
elif gemm_inst.is_wgmma():
return GemmWGMMA
elif gemm_inst.is_tcgen5mma():
return GemmTCGEN5
elif gemm_inst.is_mfma():
raise NotImplementedError("MFMA is not implemented")
return GemmMFMA
elif gemm_inst.is_tcgen5mma():
raise NotImplementedError("TCGEN5MMA is not implemented")
else:
raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}")
......@@ -32,23 +32,23 @@ class GemmBase:
@property
def M(self) -> int:
return self.gemm_node.M
return getattr(self.gemm_node, "m", None)
@property
def N(self) -> int:
return self.gemm_node.N
return getattr(self.gemm_node, "n", None)
@property
def K(self) -> int:
return self.gemm_node.K
return getattr(self.gemm_node, "k", None)
@property
def trans_A(self) -> bool:
return self.gemm_node.trans_A
return getattr(self.gemm_node, "transA", None)
@property
def trans_B(self) -> bool:
return self.gemm_node.trans_B
return getattr(self.gemm_node, "transB", None)
@property
def in_dtype(self) -> str:
......@@ -65,56 +65,100 @@ class GemmBase:
@property
def A(self) -> tir.Buffer:
return self.gemm_node.A
return getattr(self.gemm_node, "a", None)
@property
def B(self) -> tir.Buffer:
return self.gemm_node.B
return getattr(self.gemm_node, "b", None)
@property
def C(self) -> tir.Buffer:
return self.gemm_node.C
return getattr(self.gemm_node, "c", None)
@property
def APtr(self) -> tir.PrimExpr:
return self.gemm_node.APtr
def ARegion(self):
return getattr(self.gemm_node, "aRegion", None)
@property
def BPtr(self) -> tir.PrimExpr:
return self.gemm_node.BPtr
def BRegion(self):
return getattr(self.gemm_node, "bRegion", None)
@property
def CPtr(self) -> tir.PrimExpr:
return self.gemm_node.CPtr
def CRegion(self):
return getattr(self.gemm_node, "cRegion", None)
@property
def stride_A(self) -> int:
return self.gemm_node.stride_A
return getattr(self.gemm_node, "strideA", None)
@property
def stride_B(self) -> int:
return self.gemm_node.stride_B
return getattr(self.gemm_node, "strideB", None)
@property
def offset_A(self) -> int:
return self.gemm_node.offset_A
return getattr(self.gemm_node, "offsetA", None)
@property
def offset_B(self) -> int:
return self.gemm_node.offset_B
return getattr(self.gemm_node, "offsetB", None)
@property
def clear_accum(self) -> PrimExpr:
return self.gemm_node.clear_accum
return getattr(self.gemm_node, "clearAccum", None)
@property
def k_pack(self) -> int:
return self.gemm_node.k_pack
return getattr(self.gemm_node, "kPack", None)
@property
def wg_wait(self) -> int:
return self.gemm_node.wg_wait
return getattr(self.gemm_node, "wgWait", 0)
@property
def policy(self) -> GemmWarpPolicy:
return self.gemm_node.policy
return getattr(self.gemm_node, "policy", None)
@property
def mbarptr(self) -> PrimExpr:
return getattr(self.gemm_node, "mbarPtr", tvm.tir.const(0, "uint32"))
@property
def C_coords(self):
coords = getattr(self.gemm_node, "cCoords", None)
if coords is None or len(coords) == 0:
zero = tvm.tir.const(0, "int32")
return [zero, zero]
return [coords[i] for i in range(len(coords))]
def get_region_base_offsets(self, region):
"""
Get the base offset (start index) for each dimension from a BufferRegion.
For example, if region is A_shared[ko % 2, 0:128, 0:64],
this returns [ko % 2, 0, 0]
Args:
region: BufferRegion object
Returns:
List of PrimExpr representing the base offset for each dimension
"""
if region is None:
return []
return [r.min for r in region.region]
@property
def A_base_offsets(self):
"""Get base offsets for each dimension of A region"""
return self.get_region_base_offsets(self.ARegion)
@property
def B_base_offsets(self):
"""Get base offsets for each dimension of B region"""
return self.get_region_base_offsets(self.BRegion)
@property
def C_base_offsets(self):
"""Get base offsets for each dimension of C region"""
return self.get_region_base_offsets(self.CRegion)
from .gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mfma_macro_generator import (
MatrixCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_fragment, is_full_region
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify
class GemmMFMA(GemmBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mfma_emitter = MatrixCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
)
if self.is_gemm_ss():
return {
self.A: make_swizzled_layout(self.A),
self.B: make_swizzled_layout(self.B),
self.C: mfma_emitter.make_mfma_store_layout(self.C),
}
elif self.is_gemm_sr():
return {
self.A: make_swizzled_layout(self.A),
self.B: mfma_emitter.make_mfma_load_layout(self.B, matrix="B"),
self.C: mfma_emitter.make_mfma_store_layout(self.C),
}
elif self.is_gemm_rs():
return {
self.A: mfma_emitter.make_mfma_load_layout(self.A, matrix="A"),
self.B: make_swizzled_layout(self.B),
self.C: mfma_emitter.make_mfma_store_layout(self.C),
}
elif self.is_gemm_rr():
return {
self.A: mfma_emitter.make_mfma_load_layout(self.A, matrix="A"),
self.B: mfma_emitter.make_mfma_load_layout(self.B, matrix="B"),
self.C: mfma_emitter.make_mfma_store_layout(self.C),
}
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mfma_emitter = MatrixCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
thread_var=thread_var,
)
in_dtype = self.in_dtype
warp_rows = mfma_emitter.warp_rows
warp_cols = mfma_emitter.warp_cols
local_size_a = mfma_emitter.local_size_a
local_size_b = mfma_emitter.local_size_b
block_K = mfma_emitter.chunk
micro_size_k = mfma_emitter.micro_size_k
        # Operands are taken as BufferRegions so that strided GEMM is supported,
        # e.g. T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion
A_buf = A_region.buffer
B_buf = B_region.buffer
C_buf = C_region.buffer
clear_accum = self.clear_accum
assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
assert is_full_region(C_region), "Fragment output C must be a full region"
if self.is_gemm_ss():
@T.prim_func
def _gemm_ssr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Matrix Core mfma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mfma_emitter.ldmatrix_a(
A_local,
A_region,
ki,
)
# Load B into fragment
mfma_emitter.ldmatrix_b(
B_local,
B_region,
ki,
)
# Perform Matrix Multiplication
mfma_emitter.mfma(A_local, B_local, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_sr():
assert is_full_region(B_region), "Fragment input B must be a full region"
@T.prim_func
def _gemm_srr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Matrix Core mfma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mfma_emitter.ldmatrix_a(
A_local,
A_region,
ki,
)
# Perform Matrix Multiplication
mfma_emitter.mfma(A_local, B_buf, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
# alloc_buffers body
# insert into parent block
return _Simplify(_gemm_srr, inline_let=True)
elif self.is_gemm_rs():
assert is_full_region(A_region), "Fragment input A must be a full region"
@T.prim_func
def _gemm_rsr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Matrix Core mfma ops,
accumulating into C_local.
"""
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mfma_emitter.ldmatrix_b(
B_local,
B_region,
ki,
)
# Perform Matrix Multiplication
mfma_emitter.mfma(A_buf, B_local, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
elif self.is_gemm_rr():
assert is_full_region(A_region), "Fragment input A must be a full region"
assert is_full_region(B_region), "Fragment input B must be a full region"
@T.prim_func
def _gemm_rsr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Matrix Core mfma ops,
accumulating into C_local.
"""
for ki in T.serial(0, (block_K // micro_size_k)):
# Perform Matrix Multiplication
mfma_emitter.mfma(A_buf, B_buf, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
def is_gemm_sr(self) -> bool:
return is_shared(self.A) and is_fragment(self.B)
def is_gemm_rs(self) -> bool:
return is_fragment(self.A) and is_shared(self.B)
def is_gemm_rr(self) -> bool:
return is_fragment(self.A) and is_fragment(self.B)
......@@ -2,7 +2,7 @@ from .gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_fragment
from tilelang.utils.language import is_shared, is_fragment, is_full_region
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
......@@ -83,12 +83,22 @@ class GemmMMA(GemmBase):
local_size_b = mma_emitter.local_size_b
block_K = mma_emitter.chunk
micro_size_k = mma_emitter.micro_size_k
A_shared = self.A
B_shared = self.B
C_local = self.C
        # Operands are taken as BufferRegions so that strided GEMM is supported,
        # e.g. T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion
A_buf = A_region.buffer
B_buf = B_region.buffer
C_buf = C_region.buffer
clear_accum = self.clear_accum
assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
assert is_full_region(C_region), "Fragment output C must be a full region"
if self.is_gemm_ss():
@T.prim_func
......@@ -100,30 +110,31 @@ class GemmMMA(GemmBase):
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
A_region,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
B_region,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
mma_emitter.mma(A_local, B_local, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_sr():
B_local = self.B
assert is_full_region(B_region), "Fragment input B must be a full region"
@T.prim_func
def _gemm_srr() -> None:
......@@ -135,16 +146,17 @@ class GemmMMA(GemmBase):
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
                if clear_accum:
                    T.clear(C_buf)
                for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
A_region,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
mma_emitter.mma(A_local, B_buf, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
......@@ -152,7 +164,7 @@ class GemmMMA(GemmBase):
# insert into parent block
return _Simplify(_gemm_srr, inline_let=True)
elif self.is_gemm_rs():
A_local = self.A
assert is_full_region(A_region), "Fragment input A must be a full region"
@T.prim_func
def _gemm_rsr() -> None:
......@@ -162,28 +174,29 @@ class GemmMMA(GemmBase):
accumulating into C_local.
"""
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
B_region,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
mma_emitter.mma(A_buf, B_local, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
elif self.is_gemm_rr():
A_local = self.A
B_local = self.B
assert is_full_region(A_region), "Fragment input A must be a full region"
assert is_full_region(B_region), "Fragment input B must be a full region"
@T.prim_func
def _gemm_rsr() -> None:
def _gemm_rrr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
......@@ -192,11 +205,11 @@ class GemmMMA(GemmBase):
for ki in T.serial(0, (block_K // micro_size_k)):
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
mma_emitter.mma(A_buf, B_buf, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
return _Simplify(_gemm_rrr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
......
# for Volta GPUs, which use legacy MMA instructions
from .gemm_base import GemmBase
from tilelang.layout import make_volta_swizzled_layout
from tilelang.intrinsics.mma_sm70_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_fragment, is_full_region
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify
class GemmMMASm70(GemmBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
)
a_is_k_major = not self.trans_A
b_is_k_major = self.trans_B
if self.is_gemm_ss():
return {
self.A: make_volta_swizzled_layout(self.A, is_a=True, k_inner=a_is_k_major),
self.B: make_volta_swizzled_layout(self.B, is_a=False, k_inner=b_is_k_major),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_rs():
return {
self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
self.B: make_volta_swizzled_layout(self.B, is_a=False, k_inner=b_is_k_major),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
thread_var=thread_var,
)
in_dtype = self.in_dtype
warp_rows = mma_emitter.warp_rows
warp_cols = mma_emitter.warp_cols
local_size_a = mma_emitter.local_size_a
local_size_b = mma_emitter.local_size_b
block_K = mma_emitter.chunk
micro_size_k = mma_emitter.micro_size_k
# Use region for shared-memory operands when applicable
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion
A_buf = A_region.buffer
C_buf = C_region.buffer
clear_accum = self.clear_accum
assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
assert is_full_region(C_region), "Fragment output C must be a full region"
if self.is_gemm_ss():
@T.prim_func
def _gemm_ssr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_region,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_region,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_rs():
assert is_full_region(B_region), "Fragment input B must be a full region"
@T.prim_func
def _gemm_rsr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_region,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_buf, B_local, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
def is_gemm_sr(self) -> bool:
return is_shared(self.A) and is_fragment(self.B)
def is_gemm_rs(self) -> bool:
return is_fragment(self.A) and is_shared(self.B)
def is_gemm_rr(self) -> bool:
return is_fragment(self.A) and is_fragment(self.B)
from .gemm_base import GemmBase
from tilelang.layout import make_tcgen05mma_swizzled_layout
from tilelang.intrinsics.tcgen05_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang import language as T
from tilelang.transform.simplify import _Simplify
from tvm import tir
from tvm.target import Target
_FLOAT8_DTYPES = {
"float8_e4m3",
"float8_e4m3fn",
"float8_e4m3fnuz",
"float8_e5m2",
"float8_e5m2fn",
"float8_e5m2fnuz",
}
class GemmTCGEN5(GemmBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
True)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
)
a_is_k_major = not self.trans_A
b_is_k_major = self.trans_B
if self.is_gemm_ss():
a_continuity = self.M if a_is_k_major else 4 * self.K // m_warp
b_continuity = self.K if b_is_k_major else self.N // n_warp
return {
                # Like WGMMA, TCGEN05 MMA swizzled layouts do not support padding
self.A:
make_tcgen05mma_swizzled_layout(
self.A, continuity=a_continuity, k_major=a_is_k_major),
self.B:
make_tcgen05mma_swizzled_layout(
self.B, continuity=b_continuity, k_major=b_is_k_major),
self.C:
mma_emitter.make_mma_store_layout(self.C),
}
# No special swizzle requirement; rely on existing layout.
return {}
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
True)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=self.in_dtype,
b_dtype=self.in_dtype,
accum_dtype=self.accum_dtype,
a_transposed=self.trans_A,
b_transposed=self.trans_B,
block_row_warps=m_warp,
block_col_warps=n_warp,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
)
if self.A in layout_map:
mma_emitter._assign_a_shared_layout(layout_map[self.A])
if self.B in layout_map:
mma_emitter._assign_b_shared_layout(layout_map[self.B])
if not self.is_gemm_ss():
raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got "
f"A scope {self.A.scope()}, B scope {self.B.scope()}")
if self.A.scope() not in {"shared", "shared.dyn", "shared.tmem"}:
raise ValueError(f"Unsupported A scope for TCGEN5MMA: {self.A.scope()}")
if self.B.scope() not in {"shared", "shared.dyn"}:
raise ValueError(f"Unsupported B scope for TCGEN5MMA: {self.B.scope()}")
if self.C.scope() != "shared.tmem":
raise ValueError(f"TCGEN5MMA expects C in shared.tmem, got {self.C.scope()}")
if self.wg_wait != -1:
raise ValueError("TCGEN5MMA currently requires wg_wait == -1")
mbarptr = self.mbarptr
if mbarptr == 0:
raise ValueError("TCGEN5MMA requires a valid mbarrier pointer")
C_coords = self.C_coords
if len(C_coords) != 2:
raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access")
accum_dtype = str(self.C.dtype)
if accum_dtype != "float32":
raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}")
A_shared = self.ARegion
B_shared = self.BRegion
C_local = self.C
clear_accum = self.clear_accum
mbar = self.mbarptr
@T.prim_func
def _gemm_ss() -> None:
if thread_var // 32 == 0:
mma_emitter.tcgen05mma(A_shared, B_shared, C_local, mbar, clear_accum)
return _Simplify(_gemm_ss, inline_let=True)
......@@ -87,12 +87,24 @@ class GemmWGMMA(GemmBase):
if self.B in layout_map:
mma_emitter._assign_b_shared_layout(layout_map[self.B])
A_shared = self.A
B_shared = self.B
C_local = self.C
        # Operands are passed as BufferRegions so that strided GEMM is supported,
        # e.g. T.gemm(A_shared[0:128, :], B_shared, C_local).
        # Any base offsets carried by the regions (including offsets along the matrix
        # dimensions) are resolved inside the WGMMA emitter via buffer indexing / access_ptr.
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion
clear_accum = self.clear_accum
wg_wait = self.wg_wait
if self.is_gemm_ss():
# For WGMMA, we need to handle buffer region offsets
# If there are offsets, we create a BufferLoad inside the prim_func
# to properly generate offset access
@T.prim_func
def _gemm_ssr() -> None:
......@@ -101,14 +113,13 @@ class GemmWGMMA(GemmBase):
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
# Perform Matrix Multiplication
mma_emitter.wgmma(A_shared, B_shared, C_local, clear_accum)
# Perform Matrix Multiplication with offset consideration
mma_emitter.wgmma(A_region, B_region, C_region, clear_accum, wg_wait)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_rs():
A_local = self.A
@T.prim_func
def _gemm_rsr() -> None:
......@@ -117,7 +128,7 @@ class GemmWGMMA(GemmBase):
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum)
mma_emitter.wgmma(A_region, B_region, C_region, clear_accum, wg_wait)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
......