Unverified commit 29051439, authored by Lei Wang and committed by GitHub
Browse files

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# This file is modified from the original version, # This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/). # which is part of the TVM project (https://tvm.apache.org/).
"""FFI APIs""" """FFI APIs"""
import tvm.ffi import tvm.ffi
tvm.ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access tvm.ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access
...@@ -558,7 +558,8 @@ class axis: # pylint: disable=invalid-name ...@@ -558,7 +558,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable. The iteration variable.
""" """
return _ffi_api.AxisSpatial( # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.AxisSpatial( # type: ignore[attr-defined] # pylint: disable=no-member
_as_range(dom), binding, dtype) _as_range(dom), binding, dtype
)
@staticmethod @staticmethod
def reduce( def reduce(
...@@ -585,7 +586,8 @@ class axis: # pylint: disable=invalid-name ...@@ -585,7 +586,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable. The iteration variable.
""" """
return _ffi_api.AxisReduce( # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.AxisReduce( # type: ignore[attr-defined] # pylint: disable=no-member
_as_range(dom), binding, dtype) _as_range(dom), binding, dtype
)
@staticmethod @staticmethod
def scan( def scan(
...@@ -612,7 +614,8 @@ class axis: # pylint: disable=invalid-name ...@@ -612,7 +614,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable. The iteration variable.
""" """
return _ffi_api.AxisScan( # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.AxisScan( # type: ignore[attr-defined] # pylint: disable=no-member
_as_range(dom), binding, dtype) _as_range(dom), binding, dtype
)
@staticmethod @staticmethod
def opaque( def opaque(
...@@ -639,7 +642,8 @@ class axis: # pylint: disable=invalid-name ...@@ -639,7 +642,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable. The iteration variable.
""" """
return _ffi_api.AxisOpaque( # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.AxisOpaque( # type: ignore[attr-defined] # pylint: disable=no-member
_as_range(dom), binding, dtype) _as_range(dom), binding, dtype
)
@staticmethod @staticmethod
def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]: def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]:
...@@ -662,17 +666,15 @@ class axis: # pylint: disable=invalid-name ...@@ -662,17 +666,15 @@ class axis: # pylint: disable=invalid-name
The iteration variables. The iteration variables.
""" """
iter_vars = _ffi_api.AxisRemap( # type: ignore[attr-defined] # pylint: disable=no-member iter_vars = _ffi_api.AxisRemap( # type: ignore[attr-defined] # pylint: disable=no-member
kinds, bindings, dtype) kinds, bindings, dtype
)
return iter_vars[0] if len(iter_vars) == 1 else iter_vars return iter_vars[0] if len(iter_vars) == 1 else iter_vars
S = spatial # pylint: disable=invalid-name S = spatial # pylint: disable=invalid-name
R = reduce # pylint: disable=invalid-name R = reduce # pylint: disable=invalid-name
def serial(start: PrimExpr, def serial(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame:
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The serial For statement. """The serial For statement.
Parameters Parameters
...@@ -700,10 +702,7 @@ def serial(start: PrimExpr, ...@@ -700,10 +702,7 @@ def serial(start: PrimExpr,
return _ffi_api.Serial(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.Serial(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
def parallel(start: PrimExpr, def parallel(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame:
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The parallel For statement. """The parallel For statement.
Parameters Parameters
...@@ -731,10 +730,7 @@ def parallel(start: PrimExpr, ...@@ -731,10 +730,7 @@ def parallel(start: PrimExpr,
return _ffi_api.Parallel(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.Parallel(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
def vectorized(start: PrimExpr, def vectorized(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame:
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The vectorized For statement. """The vectorized For statement.
Parameters Parameters
...@@ -762,10 +758,7 @@ def vectorized(start: PrimExpr, ...@@ -762,10 +758,7 @@ def vectorized(start: PrimExpr,
return _ffi_api.Vectorized(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.Vectorized(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
def unroll(start: PrimExpr, def unroll(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame:
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The unrolled For statement. """The unrolled For statement.
Parameters Parameters
...@@ -837,7 +830,8 @@ def thread_binding( ...@@ -837,7 +830,8 @@ def thread_binding(
else: else:
start = 0 start = 0
return _ffi_api.ThreadBinding( # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.ThreadBinding( # type: ignore[attr-defined] # pylint: disable=no-member
start, stop, thread, annotations) start, stop, thread, annotations
)
def grid(*extents: PrimExpr) -> frame.ForFrame: def grid(*extents: PrimExpr) -> frame.ForFrame:
...@@ -878,10 +872,10 @@ def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: d ...@@ -878,10 +872,10 @@ def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: d
def LetStmt( # pylint: disable=invalid-name def LetStmt( # pylint: disable=invalid-name
value: PrimExpr, value: PrimExpr,
type_annotation: Optional[Type] = None, # pylint: disable=redefined-outer-name type_annotation: Optional[Type] = None, # pylint: disable=redefined-outer-name
*, *,
var: Optional[Var] = None, # pylint: disable=redefined-outer-name var: Optional[Var] = None, # pylint: disable=redefined-outer-name
) -> frame.LetFrame: ) -> frame.LetFrame:
"""Create a LetStmt binding """Create a LetStmt binding
...@@ -909,8 +903,8 @@ def LetStmt( # pylint: disable=invalid-name ...@@ -909,8 +903,8 @@ def LetStmt( # pylint: disable=invalid-name
def Let( # pylint: disable=invalid-name def Let( # pylint: disable=invalid-name
expr: PrimExpr, expr: PrimExpr,
where: Dict[Var, PrimExpr], # pylint: disable=redefined-outer-name where: Dict[Var, PrimExpr], # pylint: disable=redefined-outer-name
) -> PrimExpr: ) -> PrimExpr:
"""Create a Let expression binding""" """Create a Let expression binding"""
assert len(where) == 1, "T.Let only allows `where` to have exactly one element" assert len(where) == 1, "T.Let only allows `where` to have exactly one element"
...@@ -980,7 +974,8 @@ def realize( ...@@ -980,7 +974,8 @@ def realize(
The result RealizeFrame. The result RealizeFrame.
""" """
return _ffi_api.Realize( # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.Realize( # type: ignore[attr-defined] # pylint: disable=no-member
buffer_slice, storage_scope, condition) buffer_slice, storage_scope, condition
)
def allocate( def allocate(
...@@ -1012,7 +1007,8 @@ def allocate( ...@@ -1012,7 +1007,8 @@ def allocate(
if isinstance(condition, bool): if isinstance(condition, bool):
condition = IntImm("bool", condition) condition = IntImm("bool", condition)
return _ffi_api.Allocate( # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.Allocate( # type: ignore[attr-defined] # pylint: disable=no-member
extents, dtype, scope, condition, annotations) extents, dtype, scope, condition, annotations
)
def allocate_const( def allocate_const(
...@@ -1048,7 +1044,8 @@ def allocate_const( ...@@ -1048,7 +1044,8 @@ def allocate_const(
np_data = np_data.reshape(extents) np_data = np_data.reshape(extents)
return _ffi_api.AllocateConst( # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.AllocateConst( # type: ignore[attr-defined] # pylint: disable=no-member
ndarray.array(np_data), dtype, extents, annotations) ndarray.array(np_data), dtype, extents, annotations
)
def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame: def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame:
...@@ -1297,7 +1294,8 @@ def buffer_store( ...@@ -1297,7 +1294,8 @@ def buffer_store(
if isinstance(value, bool) and buffer.dtype == "bool": if isinstance(value, bool) and buffer.dtype == "bool":
value = IntImm("bool", value) value = IntImm("bool", value)
return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member
buffer, value, expr_indices) buffer, value, expr_indices
)
def prefetch( def prefetch(
...@@ -1464,10 +1462,7 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE ...@@ -1464,10 +1462,7 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE
return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member
def handle(dtype: Optional[str] = None, def handle(dtype: Optional[str] = None, storage_scope: str = "global", *, is_size_var: bool = False) -> Var:
storage_scope: str = "global",
*,
is_size_var: bool = False) -> Var:
"""Create a TIR var that represents a pointer. """Create a TIR var that represents a pointer.
Parameters Parameters
...@@ -1667,7 +1662,7 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer: ...@@ -1667,7 +1662,7 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer:
res = combiner(*args) res = combiner(*args)
if not isinstance(res, tuple): if not isinstance(res, tuple):
res = (res,) res = (res,)
return CommReducer(args[:num_args // 2], args[num_args // 2:], res, identity) return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity)
def index_map( def index_map(
...@@ -1700,16 +1695,15 @@ def target( ...@@ -1700,16 +1695,15 @@ def target(
The target. The target.
""" """
if not isinstance(target_config, (str, dict)): if not isinstance(target_config, (str, dict)):
raise ValueError( raise ValueError(f"T.target expected a config dict or string, but got {type(target_config)}")
f"T.target expected a config dict or string, but got {type(target_config)}")
if host is not None and not isinstance(host, (str, dict, Target)): if host is not None and not isinstance(host, (str, dict, Target)):
raise ValueError("T.target expected the host to be " raise ValueError(f"T.target expected the host to be a config dict, string, or T.target, but got {type(host)}")
"a config dict, string, or T.target, "
f"but got {type(host)}")
if isinstance(target_config, dict) and "host" in target_config and host is not None: if isinstance(target_config, dict) and "host" in target_config and host is not None:
raise ValueError("T.target expects to either receive the host " raise ValueError(
"as part of the target's config dictionary, " "T.target expects to either receive the host "
"or as a separate argument, but not both.") "as part of the target's config dictionary, "
"or as a separate argument, but not both."
)
return Target(target_config, host) return Target(target_config, host)
...@@ -1742,7 +1736,6 @@ class meta_var: # pylint: disable=invalid-name ...@@ -1742,7 +1736,6 @@ class meta_var: # pylint: disable=invalid-name
self.value = value self.value = value
def __iter__(self): def __iter__(self):
def f(): def f():
for i in self.value: for i in self.value:
yield meta_var(i) yield meta_var(i)
...@@ -1754,7 +1747,6 @@ class meta_var: # pylint: disable=invalid-name ...@@ -1754,7 +1747,6 @@ class meta_var: # pylint: disable=invalid-name
def _op_wrapper(func): def _op_wrapper(func):
@functools.wraps(func) @functools.wraps(func)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
if "dtype" in kwargs: if "dtype" in kwargs:
...@@ -1874,7 +1866,6 @@ vscale = _op_wrapper(_tir_op.vscale) ...@@ -1874,7 +1866,6 @@ vscale = _op_wrapper(_tir_op.vscale)
def _dtype_forward(func): def _dtype_forward(func):
@functools.wraps(func) @functools.wraps(func)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
if "dtype" in kwargs: if "dtype" in kwargs:
......
# Copyright (c) Tile-AI Corporation. # Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License. # Licensed under the MIT License.
"""Atomic operations for tilelang.""" """Atomic operations for tilelang."""
from __future__ import annotations from __future__ import annotations
import tilelang.language as T import tilelang.language as T
...@@ -18,10 +19,7 @@ _MEMORY_ORDER_ID_MAP = { ...@@ -18,10 +19,7 @@ _MEMORY_ORDER_ID_MAP = {
} }
def atomic_max(dst: Buffer, def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False) -> PrimExpr:
value: PrimExpr,
memory_order: str | None = None,
return_prev: bool = False) -> PrimExpr:
""" """
Perform an atomic maximum on the value stored at dst with an optional memory-order. Perform an atomic maximum on the value stored at dst with an optional memory-order.
...@@ -64,10 +62,7 @@ def atomic_max(dst: Buffer, ...@@ -64,10 +62,7 @@ def atomic_max(dst: Buffer,
return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order]) return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order])
def atomic_min(dst: Buffer, def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False) -> PrimExpr:
value: PrimExpr,
memory_order: str | None = None,
return_prev: bool = False) -> PrimExpr:
""" """
Atomically update the value at dst to the minimum of its current value and value. Atomically update the value at dst to the minimum of its current value and value.
...@@ -112,11 +107,7 @@ def atomic_min(dst: Buffer, ...@@ -112,11 +107,7 @@ def atomic_min(dst: Buffer,
return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order]) return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order])
def atomic_add(dst: Buffer, def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False, use_tma: bool = False) -> PrimExpr:
value: PrimExpr,
memory_order: str | None = None,
return_prev: bool = False,
use_tma: bool = False) -> PrimExpr:
""" """
Atomically add `value` into `dst`, returning a handle to the operation. Atomically add `value` into `dst`, returning a handle to the operation.
...@@ -191,8 +182,7 @@ def atomic_add(dst: Buffer, ...@@ -191,8 +182,7 @@ def atomic_add(dst: Buffer,
if memory_order is None: if memory_order is None:
return T.call_extern(return_type, func_name, dst, value) return T.call_extern(return_type, func_name, dst, value)
else: else:
return T.call_extern(return_type, func_name, dst, value, return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order])
_MEMORY_ORDER_ID_MAP[memory_order])
if isinstance(dst, Buffer) and isinstance(value, Buffer): if isinstance(dst, Buffer) and isinstance(value, Buffer):
ir.assert_structural_equal(dst.shape, value.shape) ir.assert_structural_equal(dst.shape, value.shape)
...@@ -208,14 +198,12 @@ def atomic_add(dst: Buffer, ...@@ -208,14 +198,12 @@ def atomic_add(dst: Buffer,
# Note: tile-region-based atomic operations don't support return_prev yet # Note: tile-region-based atomic operations don't support return_prev yet
# This would need to be implemented in the tile runtime # This would need to be implemented in the tile runtime
if return_prev: if return_prev:
raise NotImplementedError( raise NotImplementedError("return_prev is not supported for tile-region-based atomic operations")
"return_prev is not supported for tile-region-based atomic operations")
if memory_order is None: if memory_order is None:
return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, 0) return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, 0)
else: else:
return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, _MEMORY_ORDER_ID_MAP[memory_order])
_MEMORY_ORDER_ID_MAP[memory_order])
def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr: def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr:
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from tilelang import tvm as tvm from tilelang import tvm as tvm
...@@ -179,38 +180,32 @@ def set_max_nreg(reg_count: int, is_inc: int): ...@@ -179,38 +180,32 @@ def set_max_nreg(reg_count: int, is_inc: int):
def inc_max_nreg(reg_count: int):
    """Increment the maximum number of registers to use.

    Parameters
    ----------
    reg_count : int
        Register count forwarded to ``set_max_nreg``.

    Returns
    -------
    The value returned by ``set_max_nreg(reg_count, 1)``.
    """
    # The second positional argument of ``set_max_nreg`` is its ``is_inc`` flag.
    increase_flag = 1
    return set_max_nreg(reg_count, increase_flag)
def dec_max_nreg(reg_count: int):
    """Decrement the maximum number of registers to use.

    Parameters
    ----------
    reg_count : int
        Register count forwarded to ``set_max_nreg``.

    Returns
    -------
    The value returned by ``set_max_nreg(reg_count, 0)``.
    """
    # The second positional argument of ``set_max_nreg`` is its ``is_inc`` flag.
    decrease_flag = 0
    return set_max_nreg(reg_count, decrease_flag)
def annotate_producer_reg_dealloc(reg_count: int = 24):
    """Annotate the producer reg dealloc.

    Parameters
    ----------
    reg_count : int
        Register count handed to ``dec_max_nreg``; defaults to 24.

    Returns
    -------
    The value returned by ``dec_max_nreg(reg_count)``.
    """
    dealloc_call = dec_max_nreg(reg_count)
    return dealloc_call
def annotate_consumer_reg_alloc(reg_count: int = 240):
    """Annotate the consumer reg alloc.

    Parameters
    ----------
    reg_count : int
        Register count handed to ``inc_max_nreg``; defaults to 240.

    Returns
    -------
    The value returned by ``inc_max_nreg(reg_count)``.
    """
    alloc_call = inc_max_nreg(reg_count)
    return alloc_call
def no_set_max_nreg():
    """Disable the maximum register limit setting.

    Returns
    -------
    The "handle"-typed intrinsic call built for the ``tl.no_set_max_nreg`` op.
    """
    intrin_op = tir.op.Op.get("tl.no_set_max_nreg")
    return tir.call_intrin("handle", intrin_op)
def disable_warp_group_reg_alloc():
    """Disable the warp group reg alloc.

    Returns
    -------
    The value returned by ``no_set_max_nreg()``; this function only delegates to it.
    """
    disabled_call = no_set_max_nreg()
    return disabled_call
...@@ -325,7 +320,9 @@ def warpgroup_wait(num_mma: int): ...@@ -325,7 +320,9 @@ def warpgroup_wait(num_mma: int):
return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma) return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma)
def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr: def get_lane_idx(
warp_size: int | PrimExpr | None = None,
) -> PrimExpr:
"""Return the logical lane index of the calling thread within a warp. """Return the logical lane index of the calling thread within a warp.
Parameters Parameters
...@@ -350,7 +347,9 @@ def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr: ...@@ -350,7 +347,9 @@ def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr) return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr)
def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr: def get_warp_idx_sync(
warp_size: int | PrimExpr | None = None,
) -> PrimExpr:
"""Return the canonical warp index, assuming the warp's threads are converged. """Return the canonical warp index, assuming the warp's threads are converged.
Parameters Parameters
...@@ -374,7 +373,9 @@ def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr: ...@@ -374,7 +373,9 @@ def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr) return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr)
def get_warp_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr: def get_warp_idx(
warp_size: int | PrimExpr | None = None,
) -> PrimExpr:
"""Return the canonical warp index without synchronizing the warp. """Return the canonical warp index without synchronizing the warp.
Parameters Parameters
...@@ -429,8 +430,7 @@ def get_warp_group_idx( ...@@ -429,8 +430,7 @@ def get_warp_group_idx(
args.append(warp_size_expr) args.append(warp_size_expr)
if warps_per_group_expr is not None: if warps_per_group_expr is not None:
if warp_size_expr is None: if warp_size_expr is None:
raise ValueError("get_warp_group_idx expects `warp_size` when specifying " raise ValueError("get_warp_group_idx expects `warp_size` when specifying `warps_per_group`.")
"`warps_per_group`.")
args.append(warps_per_group_expr) args.append(warps_per_group_expr)
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_group_idx"), *args) return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_group_idx"), *args)
...@@ -459,10 +459,9 @@ def shuffle_elect(thread_extent: int) -> PrimExpr: ...@@ -459,10 +459,9 @@ def shuffle_elect(thread_extent: int) -> PrimExpr:
return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent) return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent)
def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, def warpgroup_fence_operand(
offset: int | PrimExpr = 0, buffer_or_ptr: tir.Buffer | PrimExpr, offset: int | PrimExpr = 0, num_regs: int | PrimExpr | None = None, dtype: str | None = None
num_regs: int | PrimExpr | None = None, ):
dtype: str | None = None):
"""Insert a warpgroup fence for the destination accumulator registers. """Insert a warpgroup fence for the destination accumulator registers.
This prevents NVCC from sinking uses of accumulator fragments past the corresponding This prevents NVCC from sinking uses of accumulator fragments past the corresponding
...@@ -517,7 +516,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, ...@@ -517,7 +516,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr, data_ptr,
convert(offset), convert(offset),
convert(num_regs), convert(num_regs),
)) )
)
if isinstance(buffer_or_ptr, tir.Buffer): if isinstance(buffer_or_ptr, tir.Buffer):
data_ptr = buffer_or_ptr.data data_ptr = buffer_or_ptr.data
...@@ -531,8 +531,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, ...@@ -531,8 +531,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
if isinstance(dim, tir.IntImm): if isinstance(dim, tir.IntImm):
total_elems *= int(dim) total_elems *= int(dim)
else: else:
raise ValueError( raise ValueError("warpgroup_fence_operand requires num_regs when buffer shape is symbolic.")
"warpgroup_fence_operand requires num_regs when buffer shape is symbolic.")
bits_per_elem = DataType(dtype).bits bits_per_elem = DataType(dtype).bits
num_regs = (total_elems * bits_per_elem + 31) // 32 num_regs = (total_elems * bits_per_elem + 31) // 32
elif isinstance(buffer_or_ptr, BufferRegion): elif isinstance(buffer_or_ptr, BufferRegion):
...@@ -569,9 +568,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, ...@@ -569,9 +568,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
bits_per_elem = DataType(dtype).bits bits_per_elem = DataType(dtype).bits
num_regs = (total_elems * bits_per_elem + 31) // 32 num_regs = (total_elems * bits_per_elem + 31) // 32
else: else:
raise ValueError( raise ValueError("warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic.")
"warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic."
)
return evaluate( return evaluate(
tir.call_intrin( tir.call_intrin(
"handle", "handle",
...@@ -580,7 +577,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, ...@@ -580,7 +577,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr, data_ptr,
convert(offset), convert(offset),
convert(num_regs), convert(num_regs),
)) )
)
else: else:
data_ptr = buffer_or_ptr data_ptr = buffer_or_ptr
# Try to infer dtype from common pointer expressions when not provided # Try to infer dtype from common pointer expressions when not provided
...@@ -603,9 +601,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, ...@@ -603,9 +601,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
except Exception: except Exception:
inferred = None inferred = None
if inferred is None: if inferred is None:
raise ValueError( raise ValueError("dtype must be provided when passing a pointer expression and cannot be inferred.")
"dtype must be provided when passing a pointer expression and cannot be inferred."
)
dtype = inferred dtype = inferred
if num_regs is None: if num_regs is None:
raise ValueError("num_regs must be provided when passing a pointer expression.") raise ValueError("num_regs must be provided when passing a pointer expression.")
...@@ -618,7 +614,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr, ...@@ -618,7 +614,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr, data_ptr,
convert(offset), convert(offset),
convert(num_regs), convert(num_regs),
)) )
)
def wait_wgmma(id: int): def wait_wgmma(id: int):
...@@ -673,7 +670,7 @@ def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call ...@@ -673,7 +670,7 @@ def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call
if _IS_HIP_AVAILABLE: if _IS_HIP_AVAILABLE:
return tir.call_extern(value.dtype, "__shfl_xor", value, offset) return tir.call_extern(value.dtype, "__shfl_xor", value, offset)
else: else:
return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset) return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xFFFFFFFF, value, offset)
def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call):
...@@ -686,7 +683,7 @@ def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Cal ...@@ -686,7 +683,7 @@ def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Cal
if _IS_HIP_AVAILABLE: if _IS_HIP_AVAILABLE:
return tir.call_extern(value.dtype, "__shfl_down", value, offset) return tir.call_extern(value.dtype, "__shfl_down", value, offset)
else: else:
return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset) return tir.call_extern(value.dtype, "__shfl_down_sync", 0xFFFFFFFF, value, offset)
def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call):
...@@ -699,12 +696,11 @@ def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call) ...@@ -699,12 +696,11 @@ def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call)
if _IS_HIP_AVAILABLE: if _IS_HIP_AVAILABLE:
return tir.call_extern(value.dtype, "__shfl_up", value, offset) return tir.call_extern(value.dtype, "__shfl_up", value, offset)
else: else:
return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset) return tir.call_extern(value.dtype, "__shfl_up_sync", 0xFFFFFFFF, value, offset)
def sync_threads(barrier_id: int = None, arrive_count: int = None): def sync_threads(barrier_id: int = None, arrive_count: int = None):
"""Synchronize all threads in a block. """Synchronize all threads in a block."""
"""
args = [] args = []
if barrier_id is not None: if barrier_id is not None:
args.append(barrier_id) args.append(barrier_id)
...@@ -714,8 +710,7 @@ def sync_threads(barrier_id: int = None, arrive_count: int = None): ...@@ -714,8 +710,7 @@ def sync_threads(barrier_id: int = None, arrive_count: int = None):
def sync_global(): def sync_global():
"""Synchronize all threads in the entire grid. """Synchronize all threads in the entire grid."""
"""
tx, ty, tz = get_thread_bindings() tx, ty, tz = get_thread_bindings()
ex, ey, ez = get_block_extents() ex, ey, ez = get_block_extents()
print(tx, ty, tz, ex, ey, ez) print(tx, ty, tz, ex, ey, ez)
...@@ -724,8 +719,7 @@ def sync_global(): ...@@ -724,8 +719,7 @@ def sync_global():
def sync_grid():
    """Synchronize all threads in a grid.

    Returns
    -------
    The "handle"-typed intrinsic call built for the ``tl.sync_grid`` op.
    """
    intrin_op = tir.op.Op.get("tl.sync_grid")
    return tir.call_intrin("handle", intrin_op)
...@@ -741,12 +735,10 @@ def initialize_wgmma_descriptor( ...@@ -741,12 +735,10 @@ def initialize_wgmma_descriptor(
if not isinstance(descriptor, (BufferLoad, tir.Buffer)): if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
descriptor.shape[0] != 1):
raise ValueError("Descriptor must be a 1D buffer of size 1.") raise ValueError("Descriptor must be a 1D buffer of size 1.")
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0])
descriptor, [0])
return evaluate( return evaluate(
tir.call_intrin( tir.call_intrin(
...@@ -757,7 +749,8 @@ def initialize_wgmma_descriptor( ...@@ -757,7 +749,8 @@ def initialize_wgmma_descriptor(
layout_type_, layout_type_,
int(leading_byte_offset), int(leading_byte_offset),
int(stride_byte_offset), int(stride_byte_offset),
)) )
)
def initialize_tcgen05_descriptor( def initialize_tcgen05_descriptor(
...@@ -774,12 +767,10 @@ def initialize_tcgen05_descriptor( ...@@ -774,12 +767,10 @@ def initialize_tcgen05_descriptor(
if not isinstance(descriptor, (BufferLoad, tir.Buffer)): if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
descriptor.shape[0] != 1):
raise ValueError("Descriptor must be a 1D buffer of size 1.") raise ValueError("Descriptor must be a 1D buffer of size 1.")
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0])
descriptor, [0])
return evaluate( return evaluate(
tir.call_intrin( tir.call_intrin(
...@@ -792,7 +783,8 @@ def initialize_tcgen05_descriptor( ...@@ -792,7 +783,8 @@ def initialize_tcgen05_descriptor(
int(base_offset), int(base_offset),
tir.IntImm("int32", 1 if leading_is_absolute else 0), tir.IntImm("int32", 1 if leading_is_absolute else 0),
int(swizzle_mode), int(swizzle_mode),
)) )
)
def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimExpr: def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimExpr:
...@@ -809,27 +801,21 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx ...@@ -809,27 +801,21 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx
if not isinstance(descriptor, (BufferLoad, tir.Buffer)): if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
if isinstance(descriptor, tir.Buffer) and len( if isinstance(descriptor, tir.Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
descriptor.shape) != 1 or descriptor.shape[0] != 1:
raise ValueError("Descriptor must be a 1D buffer of size 1.") raise ValueError("Descriptor must be a 1D buffer of size 1.")
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0])
descriptor, [0])
return evaluate( return evaluate(tir.call_intrin("handle", tir.op.Op.get("tl.increase_descriptor_offset"), descriptor, offset))
tir.call_intrin("handle", tir.op.Op.get("tl.increase_descriptor_offset"), descriptor,
offset))
def loop_break(): def loop_break():
"""Break out of the innermost loop. """Break out of the innermost loop."""
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.loop_break")) return tir.call_intrin("handle", tir.op.Op.get("tl.loop_break"))
def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call): def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call):
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc."""
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from typing import Literal from typing import Literal
from tilelang import language as T from tilelang import language as T
...@@ -10,11 +11,13 @@ from tilelang.utils.language import ( ...@@ -10,11 +11,13 @@ from tilelang.utils.language import (
from tvm import ir, tir from tvm import ir, tir
def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, def copy(
dst: tir.Buffer | tir.BufferLoad, src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
coalesced_width: int | None = None, dst: tir.Buffer | tir.BufferLoad,
disable_tma: bool = False, coalesced_width: int | None = None,
eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None): disable_tma: bool = False,
eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None,
):
"""Copy data between memory regions. """Copy data between memory regions.
Args: Args:
...@@ -65,8 +68,7 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, ...@@ -65,8 +68,7 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
src_extent = get_extent(src) src_extent = get_extent(src)
dst_extent = get_extent(dst) dst_extent = get_extent(dst)
# Combine the nested if statements into a single if statement as suggested by SIM102 # Combine the nested if statements into a single if statement as suggested by SIM102
if (src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and if src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and isinstance(dst, tir.BufferLoad):
isinstance(dst, tir.BufferLoad)):
# check if the case is like this: # check if the case is like this:
# copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes # copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes
# In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i] # In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i]
...@@ -90,19 +92,20 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, ...@@ -90,19 +92,20 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
eviction_policy = 0 eviction_policy = 0
else: else:
eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy]
return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width, return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width, disable_tma, eviction_policy)
disable_tma, eviction_policy)
def c2d_im2col(
def c2d_im2col(img: tir.Buffer, img: tir.Buffer,
col: tir.Buffer, col: tir.Buffer,
nhw_step: tir.PrimExpr, nhw_step: tir.PrimExpr,
c_step: tir.PrimExpr, c_step: tir.PrimExpr,
kernel: int, kernel: int,
stride: int, stride: int,
dilation: int, dilation: int,
pad: int, pad: int,
eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None): eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None,
):
"""Perform im2col transformation for 2D convolution. """Perform im2col transformation for 2D convolution.
Args: Args:
...@@ -124,5 +127,16 @@ def c2d_im2col(img: tir.Buffer, ...@@ -124,5 +127,16 @@ def c2d_im2col(img: tir.Buffer,
eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy]
img_region = to_buffer_region(img, access_type="r") img_region = to_buffer_region(img, access_type="r")
col_region = to_buffer_region(col, access_type="w") col_region = to_buffer_region(col, access_type="w")
return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.c2d_im2col"), img_region, col_region, return tir.call_intrin(
nhw_step, c_step, kernel, stride, dilation, pad, eviction_policy) "handle",
tir.op.Op.get("tl.tileop.c2d_im2col"),
img_region,
col_region,
nhw_step,
c_step,
kernel,
stride,
dilation,
pad,
eviction_policy,
)
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
import tilelang.language as T import tilelang.language as T
from tvm.tir import PrimExpr, Buffer, op from tvm.tir import PrimExpr, Buffer, op
from tilelang.utils.language import (bits_product, prim_expr_equal) from tilelang.utils.language import bits_product, prim_expr_equal
from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401 from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401
...@@ -46,9 +47,9 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer: ...@@ -46,9 +47,9 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer:
Returns: Returns:
Buffer: A new buffer view with the specified shape Buffer: A new buffer view with the specified shape
""" """
assert prim_expr_equal( assert prim_expr_equal(bits_product(shape, src.dtype), bits_product(src.shape, src.dtype)), (
bits_product(shape, src.dtype), bits_product(src.shape, src.dtype) f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {src.dtype}"
), f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {src.dtype}" )
return T.Tensor(shape, src.dtype, src.data) return T.Tensor(shape, src.dtype, src.data)
...@@ -61,8 +62,7 @@ def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = N ...@@ -61,8 +62,7 @@ def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = N
shape = src.shape shape = src.shape
if dtype is None: if dtype is None:
dtype = src.dtype dtype = src.dtype
assert prim_expr_equal(bits_product(shape, dtype), assert prim_expr_equal(bits_product(shape, dtype), bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed."
bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed."
return T.Tensor(shape, dtype, src.data) return T.Tensor(shape, dtype, src.data)
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T import tilelang.language as T
...@@ -11,7 +12,8 @@ from tilelang.utils.language import ( ...@@ -11,7 +12,8 @@ from tilelang.utils.language import (
prim_expr_equal, prim_expr_equal,
) )
from tilelang.language.utils import ( from tilelang.language.utils import (
buffer_region_to_tile_region,) buffer_region_to_tile_region,
)
def gemm_sp( def gemm_sp(
...@@ -169,18 +171,19 @@ def gemm_sp_v2( ...@@ -169,18 +171,19 @@ def gemm_sp_v2(
assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor" assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor"
if len(A_shape) > 2: if len(A_shape) > 2:
for i in range(len(A_shape) - 2): for i in range(len(A_shape) - 2):
assert A_shape[i] == 1, \ assert A_shape[i] == 1, (
"current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
if len(B_shape) > 2: if len(B_shape) > 2:
for i in range(len(B_shape) - 2): for i in range(len(B_shape) - 2):
assert B_shape[i] == 1, \ assert B_shape[i] == 1, (
"current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
M, N = C_shape M, N = C_shape
K = 2 * (A_shape[-2] if transpose_A else A_shape[-1]) K = 2 * (A_shape[-2] if transpose_A else A_shape[-1])
K_B = B_shape[-1] if transpose_B else B_shape[-2] K_B = B_shape[-1] if transpose_B else B_shape[-2]
assert prim_expr_equal( assert prim_expr_equal(K, K_B), f"T.gemm_sp K shape check failed: K_A (wo sparse) = {K}, K_B = {K_B}"
K, K_B), f"T.gemm_sp K shape check failed: K_A (wo sparse) = {K}, K_B = {K_B}"
stride_a = A_stride[-2] stride_a = A_stride[-2]
stride_b = B_stride[-2] stride_b = B_stride[-2]
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from tvm import tir from tvm import tir
from tilelang.language import has_let_value, get_let_value from tilelang.language import has_let_value, get_let_value
...@@ -32,8 +33,7 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim ...@@ -32,8 +33,7 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim
extents = [tir.IntImm("int32", 1) for _ in buffer.indices] extents = [tir.IntImm("int32", 1) for _ in buffer.indices]
else: else:
extents = [] extents = []
return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.fill"), return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.fill"), to_buffer_region(buffer, access_type="w", extents=extents), value)
to_buffer_region(buffer, access_type="w", extents=extents), value)
def clear(buffer: tir.Buffer | tir.Var): def clear(buffer: tir.Buffer | tir.Var):
...@@ -55,8 +55,7 @@ def clear(buffer: tir.Buffer | tir.Var): ...@@ -55,8 +55,7 @@ def clear(buffer: tir.Buffer | tir.Var):
elif isinstance(buffer_region, tir.BufferLoad): elif isinstance(buffer_region, tir.BufferLoad):
region = get_buffer_region_from_load(buffer_region) region = get_buffer_region_from_load(buffer_region)
if region is None: if region is None:
raise ValueError( raise ValueError(f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}")
f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}")
return fill(region, 0) return fill(region, 0)
else: else:
raise ValueError(f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}") raise ValueError(f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}")
......
"""Override the LetFrame to print a message when entering the frame.""" """Override the LetFrame to print a message when entering the frame."""
from __future__ import annotations from __future__ import annotations
from tvm.ffi import register_object as _register_object from tvm.ffi import register_object as _register_object
from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion
...@@ -29,7 +30,7 @@ class FrameStack: ...@@ -29,7 +30,7 @@ class FrameStack:
item: The frame object to push onto the stack item: The frame object to push onto the stack
""" """
self._stack.append(item) self._stack.append(item)
if hasattr(item, 'var') and hasattr(item, 'value'): if hasattr(item, "var") and hasattr(item, "value"):
self._var_value_map[item.var] = item.value self._var_value_map[item.var] = item.value
def pop(self): def pop(self):
...@@ -43,7 +44,7 @@ class FrameStack: ...@@ -43,7 +44,7 @@ class FrameStack:
""" """
if self._stack: if self._stack:
item = self._stack.pop() item = self._stack.pop()
if hasattr(item, 'var'): if hasattr(item, "var"):
self._var_value_map.pop(item.var, None) self._var_value_map.pop(item.var, None)
return item return item
raise IndexError(f"{self.__class__.__name__} is empty") raise IndexError(f"{self.__class__.__name__} is empty")
...@@ -129,8 +130,7 @@ class LetFrame(TIRFrame): ...@@ -129,8 +130,7 @@ class LetFrame(TIRFrame):
is_block_load = True is_block_load = True
break break
if is_block_load: if is_block_load:
self.value = BufferRegion(self.value.buffer, self.value = BufferRegion(self.value.buffer, [Range(x.base, x.lanes) for x in indices])
[Range(x.base, x.lanes) for x in indices])
_get_let_stack().push(self) _get_let_stack().push(self)
return self.var return self.var
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T import tilelang.language as T
...@@ -11,7 +12,8 @@ from tilelang.utils.language import ( ...@@ -11,7 +12,8 @@ from tilelang.utils.language import (
prim_expr_equal, prim_expr_equal,
) )
from tilelang.language.utils import ( from tilelang.language.utils import (
buffer_region_to_tile_region,) buffer_region_to_tile_region,
)
from tilelang.env import env as _env from tilelang.env import env as _env
...@@ -68,12 +70,14 @@ def _gemm_impl( ...@@ -68,12 +70,14 @@ def _gemm_impl(
assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor" assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor"
if len(A_shape) > 2: if len(A_shape) > 2:
for i in range(len(A_shape) - 2): for i in range(len(A_shape) - 2):
assert A_shape[i] == 1, \ assert A_shape[i] == 1, (
"current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
if len(B_shape) > 2: if len(B_shape) > 2:
for i in range(len(B_shape) - 2): for i in range(len(B_shape) - 2):
assert B_shape[i] == 1, \ assert B_shape[i] == 1, (
"current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions" "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
M, N = C_shape M, N = C_shape
K = A_shape[-2] if transpose_A else A_shape[-1] K = A_shape[-2] if transpose_A else A_shape[-1]
...@@ -96,9 +100,29 @@ def _gemm_impl( ...@@ -96,9 +100,29 @@ def _gemm_impl(
A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape]) A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape])
B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape]) B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape])
C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape]) C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape])
return tir.call_intrin("handle", tir.op.Op.get(op_key), A_arg, B_arg, C_arg, transpose_A, return tir.call_intrin(
transpose_B, M, N, K, policy, clear_accum, stride_a, stride_b, offset_a, "handle",
offset_b, k_pack, wg_wait, mbar, C_coords[0], C_coords[1]) tir.op.Op.get(op_key),
A_arg,
B_arg,
C_arg,
transpose_A,
transpose_B,
M,
N,
K,
policy,
clear_accum,
stride_a,
stride_b,
offset_a,
offset_b,
k_pack,
wg_wait,
mbar,
C_coords[0],
C_coords[1],
)
# Public wrappers # Public wrappers
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from collections import deque from collections import deque
from tvm import tir from tvm import tir
...@@ -107,8 +108,7 @@ class KernelLaunchFrame(TIRFrame): ...@@ -107,8 +108,7 @@ class KernelLaunchFrame(TIRFrame):
_get_current_stack().push(self) _get_current_stack().push(self)
last_block_frame = self.frames[-1] last_block_frame = self.frames[-1]
assert isinstance(last_block_frame, assert isinstance(last_block_frame, BlockFrame), f"Last frame must be a block frame, got {last_block_frame}"
BlockFrame), f"Last frame must be a block frame, got {last_block_frame}"
maybe_cpu = last_block_frame.annotations.get("tilelang.is_cpu_kernel_frame", False) maybe_cpu = last_block_frame.annotations.get("tilelang.is_cpu_kernel_frame", False)
...@@ -303,56 +303,48 @@ def Kernel( ...@@ -303,56 +303,48 @@ def Kernel(
def get_thread_binding(dim: int = 0) -> Var: def get_thread_binding(dim: int = 0) -> Var:
"""Returns the thread binding for the given dimension. """Returns the thread binding for the given dimension."""
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_binding(dim) return KernelLaunchFrame.Current().get_thread_binding(dim)
def get_thread_bindings() -> list[Var]: def get_thread_bindings() -> list[Var]:
"""Returns all three thread bindings. """Returns all three thread bindings."""
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_bindings() return KernelLaunchFrame.Current().get_thread_bindings()
def get_block_binding(dim: int = 0) -> Var: def get_block_binding(dim: int = 0) -> Var:
"""Returns the block binding for the given dimension. """Returns the block binding for the given dimension."""
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_binding(dim) return KernelLaunchFrame.Current().get_block_binding(dim)
def get_block_bindings() -> list[Var]: def get_block_bindings() -> list[Var]:
"""Returns all three block bindings. """Returns all three block bindings."""
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_bindings() return KernelLaunchFrame.Current().get_block_bindings()
def get_thread_extent(dim: int = 0) -> int: def get_thread_extent(dim: int = 0) -> int:
"""Returns the thread extent for the given dimension. """Returns the thread extent for the given dimension."""
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_extent(dim) return KernelLaunchFrame.Current().get_thread_extent(dim)
def get_thread_extents() -> list[int]: def get_thread_extents() -> list[int]:
"""Returns all three thread extents. """Returns all three thread extents."""
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_extents() return KernelLaunchFrame.Current().get_thread_extents()
def get_block_extent(dim: int = 0) -> int: def get_block_extent(dim: int = 0) -> int:
"""Returns the block extent for the given dimension. """Returns the block extent for the given dimension."""
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_extent(dim) return KernelLaunchFrame.Current().get_block_extent(dim)
def get_block_extents() -> list[int]: def get_block_extents() -> list[int]:
"""Returns all three block extents. """Returns all three block extents."""
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_extents() return KernelLaunchFrame.Current().get_block_extents()
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from tilelang import language as T from tilelang import language as T
...@@ -36,8 +37,7 @@ def any_of(buffer: T.Tensor | BufferRegion): ...@@ -36,8 +37,7 @@ def any_of(buffer: T.Tensor | BufferRegion):
) )
new_region.append(r.min) new_region.append(r.min)
buffer_load = BufferLoad(buffer, new_region) buffer_load = BufferLoad(buffer, new_region)
return T.call_intrin(return_type, tir.op.Op.get("tl.any_of"), T.address_of(buffer_load), return T.call_intrin(return_type, tir.op.Op.get("tl.any_of"), T.address_of(buffer_load), extent)
extent)
else: else:
raise ValueError(f"Invalid buffer type: {type(buffer)}") raise ValueError(f"Invalid buffer type: {type(buffer)}")
...@@ -71,7 +71,6 @@ def all_of(buffer: T.Tensor | BufferRegion): ...@@ -71,7 +71,6 @@ def all_of(buffer: T.Tensor | BufferRegion):
) )
new_region.append(r.min) new_region.append(r.min)
buffer_load = BufferLoad(buffer, new_region) buffer_load = BufferLoad(buffer, new_region)
return T.call_intrin(return_type, tir.op.Op.get("tl.all_of"), T.address_of(buffer_load), return T.call_intrin(return_type, tir.op.Op.get("tl.all_of"), T.address_of(buffer_load), extent)
extent)
else: else:
raise ValueError(f"Invalid buffer type: {type(buffer)}") raise ValueError(f"Invalid buffer type: {type(buffer)}")
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import Any
from tvm import tir from tvm import tir
...@@ -94,11 +95,9 @@ def Pipelined( ...@@ -94,11 +95,9 @@ def Pipelined(
return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group) return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group)
def serial(start: tir.PrimExpr, def serial(
stop: tir.PrimExpr | None = None, start: tir.PrimExpr, stop: tir.PrimExpr | None = None, step: tir.PrimExpr | None = None, *, annotations: dict[str, Any] | None = None
step: tir.PrimExpr | None = None, ) -> frame.ForFrame:
*,
annotations: dict[str, Any] | None = None) -> frame.ForFrame:
step_is_one = False step_is_one = False
step_is_one |= isinstance(step, int) and step == 1 step_is_one |= isinstance(step, int) and step == 1
step_is_one |= isinstance(step, IntImm) and step.value == 1 step_is_one |= isinstance(step, IntImm) and step.value == 1
...@@ -111,13 +110,15 @@ def serial(start: tir.PrimExpr, ...@@ -111,13 +110,15 @@ def serial(start: tir.PrimExpr,
return SerialForWithStep(start, stop, step, annotations=annotations) return SerialForWithStep(start, stop, step, annotations=annotations)
def unroll(start: tir.PrimExpr, def unroll(
stop: tir.PrimExpr | None = None, start: tir.PrimExpr,
step: tir.PrimExpr | None = None, stop: tir.PrimExpr | None = None,
*, step: tir.PrimExpr | None = None,
explicit: bool = False, *,
unroll_factor: int | None = None, explicit: bool = False,
annotations: dict[str, Any] | None = None) -> frame.ForFrame: unroll_factor: int | None = None,
annotations: dict[str, Any] | None = None,
) -> frame.ForFrame:
"""The unrolled For statement. """The unrolled For statement.
Parameters Parameters
......
...@@ -3,7 +3,7 @@ from tvm import tir ...@@ -3,7 +3,7 @@ from tvm import tir
def _validate_rounding_mode(rounding_mode): def _validate_rounding_mode(rounding_mode):
"""Validate that the rounding mode is one of the supported IEEE modes""" """Validate that the rounding mode is one of the supported IEEE modes"""
valid_modes = {'rn', 'rz', 'ru', 'rd'} valid_modes = {"rn", "rz", "ru", "rd"}
if isinstance(rounding_mode, str) and rounding_mode in valid_modes: if isinstance(rounding_mode, str) and rounding_mode in valid_modes:
return return
raise ValueError(f"Invalid rounding mode '{rounding_mode}'. Must be one of: {valid_modes}") raise ValueError(f"Invalid rounding mode '{rounding_mode}'. Must be one of: {valid_modes}")
......
"""TVMScript parser overrides tailored for TileLang.""" """TVMScript parser overrides tailored for TileLang."""
from functools import partial from functools import partial
from tvm.script.ir_builder import tir as T from tvm.script.ir_builder import tir as T
...@@ -58,8 +59,12 @@ def tilelang_visit_assign(self, node: doc.Assign) -> None: # pylint: disable=un ...@@ -58,8 +59,12 @@ def tilelang_visit_assign(self, node: doc.Assign) -> None: # pylint: disable=un
lhs.ctx = load_ctx lhs.ctx = load_ctx
lhs_value = self.eval_expr(lhs) lhs_value = self.eval_expr(lhs)
lhs.ctx = store_ctx lhs.ctx = store_ctx
if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and if (
len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0): isinstance(lhs_value, BufferLoad)
and lhs_value.buffer.scope() == "local.var"
and len(lhs_value.indices) == 1
and lhs_value.indices[0] == 0
):
T.buffer_store(lhs_value.buffer, rhs, indices=[0]) T.buffer_store(lhs_value.buffer, rhs, indices=[0])
continue continue
...@@ -106,8 +111,12 @@ def tilelang_visit_aug_assign(self, node: doc.AugAssign) -> None: # pylint: dis ...@@ -106,8 +111,12 @@ def tilelang_visit_aug_assign(self, node: doc.AugAssign) -> None: # pylint: dis
lhs.ctx = load_ctx lhs.ctx = load_ctx
lhs_value = self.eval_expr(lhs) lhs_value = self.eval_expr(lhs)
lhs.ctx = store_ctx lhs.ctx = store_ctx
if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and if (
len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0): isinstance(lhs_value, BufferLoad)
and lhs_value.buffer.scope() == "local.var"
and len(lhs_value.indices) == 1
and lhs_value.indices[0] == 0
):
T.buffer_store(lhs_value.buffer, rhs, indices=[0]) T.buffer_store(lhs_value.buffer, rhs, indices=[0])
return return
...@@ -131,8 +140,12 @@ def tilelang_visit_ann_assign(self, node: doc.AnnAssign) -> None: # pylint: dis ...@@ -131,8 +140,12 @@ def tilelang_visit_ann_assign(self, node: doc.AnnAssign) -> None: # pylint: dis
lhs.ctx = load_ctx lhs.ctx = load_ctx
lhs_value = self.eval_expr(lhs) lhs_value = self.eval_expr(lhs)
lhs.ctx = store_ctx lhs.ctx = store_ctx
if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and if (
len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0): isinstance(lhs_value, BufferLoad)
and lhs_value.buffer.scope() == "local.var"
and len(lhs_value.indices) == 1
and lhs_value.indices[0] == 0
):
T.buffer_store(lhs_value.buffer, rhs, indices=[0]) T.buffer_store(lhs_value.buffer, rhs, indices=[0])
return return
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
# which is part of the TVM project (https://tvm.apache.org/). # which is part of the TVM project (https://tvm.apache.org/).
# ruff: noqa # ruff: noqa
"""The entry point of TVM parser for tir.""" """The entry point of TVM parser for tir."""
import inspect import inspect
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
...@@ -29,9 +30,7 @@ from tvm.script.parser._core import parse, scan_macro, utils ...@@ -29,9 +30,7 @@ from tvm.script.parser._core import parse, scan_macro, utils
from tvm.script.parser.core.parser import Parser, ScriptMacro from tvm.script.parser.core.parser import Parser, ScriptMacro
def prim_func(func: Optional[Callable] = None, def prim_func(func: Optional[Callable] = None, private: bool = False, check_well_formed=True) -> Union[PrimFunc, Callable]:
private: bool = False,
check_well_formed=True) -> Union[PrimFunc, Callable]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator. """The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters Parameters
...@@ -149,8 +148,7 @@ def macro(*args, hygienic: bool = True) -> Callable: ...@@ -149,8 +148,7 @@ def macro(*args, hygienic: bool = True) -> Callable:
if len(args) == 1 and inspect.isfunction(args[0]): if len(args) == 1 and inspect.isfunction(args[0]):
return _decorator(args[0]) return _decorator(args[0])
raise ValueError( raise ValueError("Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])")
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])")
class BufferProxy: class BufferProxy:
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# This file is modified from the original version, # This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/). # which is part of the TVM project (https://tvm.apache.org/).
"""The tir expression operation registration""" """The tir expression operation registration"""
from tvm import tir from tvm import tir
from tvm.ffi.runtime_ctypes import DataType, DataTypeCode from tvm.ffi.runtime_ctypes import DataType, DataTypeCode
from tvm.tir import IntImm from tvm.tir import IntImm
...@@ -55,11 +56,9 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name ...@@ -55,11 +56,9 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
return dtype[0:index] return dtype[0:index]
def _auto_broadcast(a, b, op): def _auto_broadcast(a, b, op):
if isinstance(a, int): if isinstance(a, int):
if hasattr(b, "dtype"): if hasattr(b, "dtype"):
if (DataType(b.dtype).type_code == DataTypeCode.INT or if DataType(b.dtype).type_code == DataTypeCode.INT or DataType(b.dtype).type_code == DataTypeCode.UINT:
DataType(b.dtype).type_code == DataTypeCode.UINT):
a = IntImm(_get_type_str(b.dtype), a) a = IntImm(_get_type_str(b.dtype), a)
elif DataType(b.dtype).type_code == DataTypeCode.FLOAT: elif DataType(b.dtype).type_code == DataTypeCode.FLOAT:
a = FloatImm(_get_type_str(b.dtype), a) a = FloatImm(_get_type_str(b.dtype), a)
...@@ -75,8 +74,7 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name ...@@ -75,8 +74,7 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
assert isinstance(a, tir.PrimExpr), "Operand should be a PrimExpr." assert isinstance(a, tir.PrimExpr), "Operand should be a PrimExpr."
if isinstance(b, int): if isinstance(b, int):
if (DataType(a.dtype).type_code == DataTypeCode.INT or if DataType(a.dtype).type_code == DataTypeCode.INT or DataType(a.dtype).type_code == DataTypeCode.UINT:
DataType(a.dtype).type_code == DataTypeCode.UINT):
b = IntImm(_get_type_str(a.dtype), b) b = IntImm(_get_type_str(a.dtype), b)
elif DataType(a.dtype).type_code == DataTypeCode.FLOAT: elif DataType(a.dtype).type_code == DataTypeCode.FLOAT:
b = FloatImm(_get_type_str(a.dtype), b) b = FloatImm(_get_type_str(a.dtype), b)
...@@ -85,10 +83,10 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name ...@@ -85,10 +83,10 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
if DataType(a.dtype).lanes == DataType(b.dtype).lanes: if DataType(a.dtype).lanes == DataType(b.dtype).lanes:
return op(a, b) return op(a, b)
elif (DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes): elif DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes:
broadcast_a = tir.Broadcast(a, DataType(b.dtype).lanes) broadcast_a = tir.Broadcast(a, DataType(b.dtype).lanes)
return op(broadcast_a, b) return op(broadcast_a, b)
elif (DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes): elif DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes:
broadcast_b = tir.Broadcast(b, DataType(a.dtype).lanes) broadcast_b = tir.Broadcast(b, DataType(a.dtype).lanes)
return op(a, broadcast_b) return op(a, broadcast_b)
else: else:
......
...@@ -146,8 +146,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - ...@@ -146,8 +146,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -
res = value.__enter__() res = value.__enter__()
IRBuilder.name(var_name, res) IRBuilder.name(var_name, res)
return res return res
elif isinstance(value, (Buffer, IterVar)) or (isinstance(value, Var) and elif isinstance(value, (Buffer, IterVar)) or (isinstance(value, Var) and not self.var_table.exist(value)):
not self.var_table.exist(value)):
IRBuilder.name(var_name, value) IRBuilder.name(var_name, value)
return value return value
else: else:
...@@ -191,8 +190,7 @@ def visit_for(self: Parser, node: doc.For) -> None: ...@@ -191,8 +190,7 @@ def visit_for(self: Parser, node: doc.For) -> None:
if not isinstance(for_frame, T.frame.ForFrame): if not isinstance(for_frame, T.frame.ForFrame):
self.report_error( self.report_error(
node.iter, node.iter,
"Expect the for loop to be one of the following: " "Expect the for loop to be one of the following: range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
) )
with self.var_table.with_frame(): with self.var_table.with_frame():
with for_frame as iters: with for_frame as iters:
...@@ -361,8 +359,7 @@ def visit_with(self: Parser, node: doc.With) -> None: ...@@ -361,8 +359,7 @@ def visit_with(self: Parser, node: doc.With) -> None:
for item in node.items: for item in node.items:
frame = self.eval_expr(item.context_expr) frame = self.eval_expr(item.context_expr)
if not isinstance(frame, Frame): if not isinstance(frame, Frame):
self.report_error(item.context_expr, self.report_error(item.context_expr, "Invalid context expression in the with-statement.")
"Invalid context expression in the with-statement.")
rhs = stack.enter_context(frame) rhs = stack.enter_context(frame)
if item.optional_vars is not None: if item.optional_vars is not None:
self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value) self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value)
...@@ -505,8 +502,7 @@ def visit_if(self: Parser, node: doc.If) -> None: ...@@ -505,8 +502,7 @@ def visit_if(self: Parser, node: doc.If) -> None:
with self.var_table.with_frame(): with self.var_table.with_frame():
self.visit_body(node.orelse) self.visit_body(node.orelse)
else: else:
self.report_error(node.test, self.report_error(node.test, f"If condition must be a boolean expression, but got {predicate}")
f"If condition must be a boolean expression, but got {predicate}")
@dispatch.register(token="tir", type_name="Assert") @dispatch.register(token="tir", type_name="Assert")
......
...@@ -26,9 +26,7 @@ def print_var(var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr: ...@@ -26,9 +26,7 @@ def print_var(var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr:
@macro @macro
def print_var_with_condition(condition: tir.PrimExpr, def print_var_with_condition(condition: tir.PrimExpr, var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr:
var: tir.PrimExpr,
msg: str = "") -> tir.PrimExpr:
""" """
Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True. Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True.
...@@ -44,10 +42,7 @@ def print_var_with_condition(condition: tir.PrimExpr, ...@@ -44,10 +42,7 @@ def print_var_with_condition(condition: tir.PrimExpr,
@macro @macro
def print_global_buffer_with_condition(condition: tir.PrimExpr, def print_global_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr:
buffer: tir.Buffer,
elems: int,
msg: str = "") -> tir.PrimExpr:
""" """
Conditionally prints the values of a flattened TIR buffer if the condition is True. Conditionally prints the values of a flattened TIR buffer if the condition is True.
""" """
...@@ -55,17 +50,13 @@ def print_global_buffer_with_condition(condition: tir.PrimExpr, ...@@ -55,17 +50,13 @@ def print_global_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one. # Iterate through the buffer elements and print each one.
for i in serial(elems): for i in serial(elems):
coords = index_to_coordinates(i, buffer.shape) coords = index_to_coordinates(i, buffer.shape)
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
buffer[coords])
else: else:
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
@macro @macro
def print_shared_buffer_with_condition(condition: tir.PrimExpr, def print_shared_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr:
buffer: tir.Buffer,
elems: int,
msg: str = "") -> tir.PrimExpr:
""" """
Conditionally prints the values of a flattened TIR buffer if the condition is True. Conditionally prints the values of a flattened TIR buffer if the condition is True.
...@@ -81,15 +72,11 @@ def print_shared_buffer_with_condition(condition: tir.PrimExpr, ...@@ -81,15 +72,11 @@ def print_shared_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one. # Iterate through the buffer elements and print each one.
for i in serial(elems): for i in serial(elems):
coords = index_to_coordinates(i, buffer.shape) coords = index_to_coordinates(i, buffer.shape)
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
buffer[coords])
@macro @macro
def print_fragment_buffer_with_condition(condition: tir.PrimExpr, def print_fragment_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr:
buffer: tir.Buffer,
elems: int,
msg: str = "") -> tir.PrimExpr:
""" """
Conditionally prints the values of a flattened TIR buffer if the condition is True. Conditionally prints the values of a flattened TIR buffer if the condition is True.
...@@ -111,10 +98,7 @@ def print_fragment_buffer_with_condition(condition: tir.PrimExpr, ...@@ -111,10 +98,7 @@ def print_fragment_buffer_with_condition(condition: tir.PrimExpr,
@macro @macro
def print_local_buffer_with_condition(condition: tir.PrimExpr, def print_local_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr:
buffer: tir.Buffer,
elems: int,
msg: str = "") -> tir.PrimExpr:
""" """
Conditionally prints the values of a flattened TIR buffer if the condition is True. Conditionally prints the values of a flattened TIR buffer if the condition is True.
...@@ -130,8 +114,7 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr, ...@@ -130,8 +114,7 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one. # Iterate through the buffer elements and print each one.
for i in serial(elems): for i in serial(elems):
coords = index_to_coordinates(i, buffer.shape) coords = index_to_coordinates(i, buffer.shape)
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
buffer[coords])
from tilelang.utils.target import check_cuda_availability from tilelang.utils.target import check_cuda_availability
...@@ -201,7 +184,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> ...@@ -201,7 +184,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
elems *= dim elems *= dim
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print. # Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
condition = (tx == main_lane and ty == 0 and tz == 0) condition = tx == main_lane and ty == 0 and tz == 0
if not msg: if not msg:
msg = f"buffer<{buffer.name}, {buffer.dtype}>" msg = f"buffer<{buffer.name}, {buffer.dtype}>"
return print_fragment_buffer_with_condition(condition, buffer, elems, msg) return print_fragment_buffer_with_condition(condition, buffer, elems, msg)
...@@ -212,7 +195,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> ...@@ -212,7 +195,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
elems *= dim elems *= dim
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print. # Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
condition = (tx == main_lane and ty == 0 and tz == 0) condition = tx == main_lane and ty == 0 and tz == 0
if not msg: if not msg:
msg = f"buffer<{buffer.name}, {buffer.dtype}>" msg = f"buffer<{buffer.name}, {buffer.dtype}>"
return print_shared_buffer_with_condition(condition, buffer, elems, msg) return print_shared_buffer_with_condition(condition, buffer, elems, msg)
...@@ -234,5 +217,4 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> ...@@ -234,5 +217,4 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
else: else:
# Unsupported object type. # Unsupported object type.
raise ValueError( raise ValueError(f"Unexpected type: {type(obj)}. Supported types are tir.Buffer and tir.PrimExpr.")
f"Unexpected type: {type(obj)}. Supported types are tir.Buffer and tir.PrimExpr.")
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from typing import Any, SupportsIndex, TYPE_CHECKING, Generic, TypeVar from typing import Any, SupportsIndex, TYPE_CHECKING, Generic, TypeVar
...@@ -51,11 +52,9 @@ class BufferProxy: ...@@ -51,11 +52,9 @@ class BufferProxy:
return self(keys) return self(keys)
return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member
def from_ptr(self, def from_ptr(
pointer_var: Var, self, pointer_var: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None
shape: tuple[PrimExpr, ...], ) -> Buffer:
dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> Buffer:
"""Create a buffer from a pointer, shape, and data type. """Create a buffer from a pointer, shape, and data type.
Args: Args:
...@@ -76,6 +75,7 @@ class BaseTensorProxy: ...@@ -76,6 +75,7 @@ class BaseTensorProxy:
customizable default values for scope, alignment, and offset factors. It implements customizable default values for scope, alignment, and offset factors. It implements
the core functionality for creating TIR buffers with specific memory configurations. the core functionality for creating TIR buffers with specific memory configurations.
""" """
default_scope = "global" default_scope = "global"
default_align = 0 default_align = 0
default_offset_factor = 0 default_offset_factor = 0
...@@ -118,11 +118,9 @@ class BaseTensorProxy: ...@@ -118,11 +118,9 @@ class BaseTensorProxy:
keys = (keys,) keys = (keys,)
return self(*keys) return self(*keys)
def from_ptr(self, def from_ptr(
pointer_var: Var, self, pointer_var: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None
shape: tuple[PrimExpr, ...], ) -> tir.Buffer:
dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> tir.Buffer:
"""Create a buffer from a pointer, shape, and data type. """Create a buffer from a pointer, shape, and data type.
Args: Args:
...@@ -151,19 +149,10 @@ class TensorProxy(BaseTensorProxy): ...@@ -151,19 +149,10 @@ class TensorProxy(BaseTensorProxy):
strides.append(s) strides.append(s)
return tuple(reversed(strides)) return tuple(reversed(strides))
def __call__(self, def __call__(self, shape: tuple[Any] | PrimExpr | int, dtype: str = "float32", data=None, scope=None) -> tir.Buffer:
shape: tuple[Any] | PrimExpr | int,
dtype: str = "float32",
data=None,
scope=None) -> tir.Buffer:
if isinstance(shape, (int, PrimExpr)): if isinstance(shape, (int, PrimExpr)):
shape = (shape,) shape = (shape,)
return super().__call__( return super().__call__(shape, dtype=dtype, strides=TensorProxy._construct_strides(shape), data=data, scope=scope)
shape,
dtype=dtype,
strides=TensorProxy._construct_strides(shape),
data=data,
scope=scope)
class StridedTensorProxy(BaseTensorProxy): class StridedTensorProxy(BaseTensorProxy):
...@@ -172,11 +161,7 @@ class StridedTensorProxy(BaseTensorProxy): ...@@ -172,11 +161,7 @@ class StridedTensorProxy(BaseTensorProxy):
This class implements the default tensor proxy with global memory scope, with the stride information required. This class implements the default tensor proxy with global memory scope, with the stride information required.
""" """
def __call__(self, def __call__(self, shape: tuple[Any], strides: tuple[Any], dtype: str = "float32", scope=None) -> tir.Buffer:
shape: tuple[Any],
strides: tuple[Any],
dtype: str = "float32",
scope=None) -> tir.Buffer:
if len(shape) != len(strides): if len(shape) != len(strides):
raise ValueError("Invalid shape/strides' dimensions") raise ValueError("Invalid shape/strides' dimensions")
return super().__call__(shape, dtype=dtype, strides=strides, scope=scope) return super().__call__(shape, dtype=dtype, strides=strides, scope=scope)
...@@ -188,6 +173,7 @@ class FragmentBufferProxy(BaseTensorProxy): ...@@ -188,6 +173,7 @@ class FragmentBufferProxy(BaseTensorProxy):
This class represents tensor proxies specifically for local fragment memory, This class represents tensor proxies specifically for local fragment memory,
typically used in GPU tensor core operations. typically used in GPU tensor core operations.
""" """
default_scope = "local.fragment" default_scope = "local.fragment"
...@@ -197,6 +183,7 @@ class SharedBufferProxy(BaseTensorProxy): ...@@ -197,6 +183,7 @@ class SharedBufferProxy(BaseTensorProxy):
This class represents tensor proxies for dynamic shared memory, This class represents tensor proxies for dynamic shared memory,
commonly used in GPU shared memory operations. commonly used in GPU shared memory operations.
""" """
default_scope = "shared.dyn" default_scope = "shared.dyn"
...@@ -206,6 +193,7 @@ class LocalBufferProxy(BaseTensorProxy): ...@@ -206,6 +193,7 @@ class LocalBufferProxy(BaseTensorProxy):
This class represents tensor proxies for local memory scope, This class represents tensor proxies for local memory scope,
typically used for temporary computations in GPU kernels. typically used for temporary computations in GPU kernels.
""" """
default_scope = "local" default_scope = "local"
...@@ -216,15 +204,12 @@ Buffer = BufferProxy() # pylint: disable=invalid-name ...@@ -216,15 +204,12 @@ Buffer = BufferProxy() # pylint: disable=invalid-name
if TYPE_CHECKING: if TYPE_CHECKING:
class BaseTensor: class BaseTensor:
def __class_getitem__(cls, key): def __class_getitem__(cls, key):
return cls return cls
def __getitem__(self, key) -> Any: def __getitem__(self, key) -> Any: ...
...
def __setitem__(self, key, value) -> None: def __setitem__(self, key, value) -> None: ...
...
def __init__( def __init__(
self, self,
...@@ -238,36 +223,26 @@ if TYPE_CHECKING: ...@@ -238,36 +223,26 @@ if TYPE_CHECKING:
offset_factor=None, offset_factor=None,
buffer_type="", buffer_type="",
axis_separators=None, axis_separators=None,
): ): ...
...
@classmethod @classmethod
def from_ptr(cls, def from_ptr(
pointer_var: Var, cls, pointer_var: Var, shape: Sequence[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None
shape: Sequence[PrimExpr, ...], ) -> Self: ...
dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> Self:
...
class Tensor(BaseTensor): class Tensor(BaseTensor): ...
...
class StridedTensor(BaseTensor): class StridedTensor(BaseTensor): ...
...
class FragmentBuffer(BaseTensor): class FragmentBuffer(BaseTensor): ...
...
class SharedBuffer(BaseTensor): class SharedBuffer(BaseTensor): ...
...
class LocalBuffer(BaseTensor): class LocalBuffer(BaseTensor): ...
...
_T = TypeVar('_T') _T = TypeVar("_T")
class Ref(Generic[_T], tir.Var): class Ref(Generic[_T], tir.Var): ...
...
else: else:
Tensor = TensorProxy() # pylint: disable=invalid-name Tensor = TensorProxy() # pylint: disable=invalid-name
StridedTensor = StridedTensorProxy() # pylint: disable=invalid-name StridedTensor = StridedTensorProxy() # pylint: disable=invalid-name
...@@ -275,14 +250,10 @@ else: ...@@ -275,14 +250,10 @@ else:
SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name
LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name
class Ref: class Ref: ...
...
def ptr(dtype: str | None = None, def ptr(dtype: str | None = None, storage_scope: str = "global", *, is_size_var: bool = False) -> Var:
storage_scope: str = "global",
*,
is_size_var: bool = False) -> Var:
"""Create a TIR var that represents a pointer. """Create a TIR var that represents a pointer.
Parameters Parameters
...@@ -304,8 +275,5 @@ def ptr(dtype: str | None = None, ...@@ -304,8 +275,5 @@ def ptr(dtype: str | None = None,
return handle(dtype=dtype, storage_scope=storage_scope, is_size_var=is_size_var) return handle(dtype=dtype, storage_scope=storage_scope, is_size_var=is_size_var)
def make_tensor(ptr: Var, def make_tensor(ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None) -> tir.Buffer:
shape: tuple[PrimExpr, ...],
dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> tir.Buffer:
return Tensor.from_ptr(ptr, shape, dtype, strides) return Tensor.from_ptr(ptr, shape, dtype, strides)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment