Unverified Commit 29051439, authored by Lei Wang and committed by GitHub

[Lint] Phase out Yapf format and embrace ruff format (#1417)

parent e84b24bc
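
The diff below is almost entirely mechanical re-wrapping: call sites and signatures that yapf had split with hanging indentation are collapsed onto one line where they fit the configured line length, or expanded to one argument per line with a trailing comma and a dedented closing parenthesis; string quotes and hex literals are normalized as well. A minimal, self-contained sketch of the pattern follows; the names call_extern and atomic_update here are illustrative stand-ins, not taken from the patch.

    # Illustrative sketch only: a stand-in for a call_extern-style helper.
    def call_extern(*args):
        return args


    # Before (yapf): continuation arguments aligned under the opening parenthesis.
    def atomic_update_yapf(dst, value, memory_order):
        return call_extern("handle", "AtomicUpdate", dst, value,
                           memory_order)


    # After (ruff format): kept on one line when it fits the line-length limit,
    # otherwise one argument per line with a magic trailing comma and a
    # closing parenthesis dedented to the call's indentation.
    def atomic_update_ruff(dst, value, memory_order):
        return call_extern("handle", "AtomicUpdate", dst, value, memory_order)


    # Both spellings produce the same call; only the layout differs.
    assert atomic_update_yapf(0, 1.0, "relaxed") == atomic_update_ruff(0, 1.0, "relaxed")
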
......@@ -17,6 +17,7 @@
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""FFI APIs"""
import tvm.ffi
tvm.ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access
......@@ -558,7 +558,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
"""
return _ffi_api.AxisSpatial( # type: ignore[attr-defined] # pylint: disable=no-member
_as_range(dom), binding, dtype)
_as_range(dom), binding, dtype
)
@staticmethod
def reduce(
......@@ -585,7 +586,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
"""
return _ffi_api.AxisReduce( # type: ignore[attr-defined] # pylint: disable=no-member
_as_range(dom), binding, dtype)
_as_range(dom), binding, dtype
)
@staticmethod
def scan(
......@@ -612,7 +614,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
"""
return _ffi_api.AxisScan( # type: ignore[attr-defined] # pylint: disable=no-member
_as_range(dom), binding, dtype)
_as_range(dom), binding, dtype
)
@staticmethod
def opaque(
......@@ -639,7 +642,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
"""
return _ffi_api.AxisOpaque( # type: ignore[attr-defined] # pylint: disable=no-member
_as_range(dom), binding, dtype)
_as_range(dom), binding, dtype
)
@staticmethod
def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]:
......@@ -662,17 +666,15 @@ class axis: # pylint: disable=invalid-name
The iteration variables.
"""
iter_vars = _ffi_api.AxisRemap( # type: ignore[attr-defined] # pylint: disable=no-member
kinds, bindings, dtype)
kinds, bindings, dtype
)
return iter_vars[0] if len(iter_vars) == 1 else iter_vars
S = spatial # pylint: disable=invalid-name
R = reduce # pylint: disable=invalid-name
def serial(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
def serial(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The serial For statement.
Parameters
......@@ -700,10 +702,7 @@ def serial(start: PrimExpr,
return _ffi_api.Serial(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
def parallel(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
def parallel(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The parallel For statement.
Parameters
......@@ -731,10 +730,7 @@ def parallel(start: PrimExpr,
return _ffi_api.Parallel(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
def vectorized(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
def vectorized(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The vectorized For statement.
Parameters
......@@ -762,10 +758,7 @@ def vectorized(start: PrimExpr,
return _ffi_api.Vectorized(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
def unroll(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
def unroll(start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None) -> frame.ForFrame:
"""The unrolled For statement.
Parameters
......@@ -837,7 +830,8 @@ def thread_binding(
else:
start = 0
return _ffi_api.ThreadBinding( # type: ignore[attr-defined] # pylint: disable=no-member
start, stop, thread, annotations)
start, stop, thread, annotations
)
def grid(*extents: PrimExpr) -> frame.ForFrame:
......@@ -878,10 +872,10 @@ def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: d
def LetStmt( # pylint: disable=invalid-name
value: PrimExpr,
type_annotation: Optional[Type] = None, # pylint: disable=redefined-outer-name
*,
var: Optional[Var] = None, # pylint: disable=redefined-outer-name
value: PrimExpr,
type_annotation: Optional[Type] = None, # pylint: disable=redefined-outer-name
*,
var: Optional[Var] = None, # pylint: disable=redefined-outer-name
) -> frame.LetFrame:
"""Create a LetStmt binding
......@@ -909,8 +903,8 @@ def LetStmt( # pylint: disable=invalid-name
def Let( # pylint: disable=invalid-name
expr: PrimExpr,
where: Dict[Var, PrimExpr], # pylint: disable=redefined-outer-name
expr: PrimExpr,
where: Dict[Var, PrimExpr], # pylint: disable=redefined-outer-name
) -> PrimExpr:
"""Create a Let expression binding"""
assert len(where) == 1, "T.Let only allows `where` to have exactly one element"
......@@ -980,7 +974,8 @@ def realize(
The result RealizeFrame.
"""
return _ffi_api.Realize( # type: ignore[attr-defined] # pylint: disable=no-member
buffer_slice, storage_scope, condition)
buffer_slice, storage_scope, condition
)
def allocate(
......@@ -1012,7 +1007,8 @@ def allocate(
if isinstance(condition, bool):
condition = IntImm("bool", condition)
return _ffi_api.Allocate( # type: ignore[attr-defined] # pylint: disable=no-member
extents, dtype, scope, condition, annotations)
extents, dtype, scope, condition, annotations
)
def allocate_const(
......@@ -1048,7 +1044,8 @@ def allocate_const(
np_data = np_data.reshape(extents)
return _ffi_api.AllocateConst( # type: ignore[attr-defined] # pylint: disable=no-member
ndarray.array(np_data), dtype, extents, annotations)
ndarray.array(np_data), dtype, extents, annotations
)
def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame:
......@@ -1297,7 +1294,8 @@ def buffer_store(
if isinstance(value, bool) and buffer.dtype == "bool":
value = IntImm("bool", value)
return _ffi_api.BufferStore( # type: ignore[attr-defined] # pylint: disable=no-member
buffer, value, expr_indices)
buffer, value, expr_indices
)
def prefetch(
......@@ -1464,10 +1462,7 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE
return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member
def handle(dtype: Optional[str] = None,
storage_scope: str = "global",
*,
is_size_var: bool = False) -> Var:
def handle(dtype: Optional[str] = None, storage_scope: str = "global", *, is_size_var: bool = False) -> Var:
"""Create a TIR var that represents a pointer.
Parameters
......@@ -1667,7 +1662,7 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer:
res = combiner(*args)
if not isinstance(res, tuple):
res = (res,)
return CommReducer(args[:num_args // 2], args[num_args // 2:], res, identity)
return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity)
def index_map(
......@@ -1700,16 +1695,15 @@ def target(
The target.
"""
if not isinstance(target_config, (str, dict)):
raise ValueError(
f"T.target expected a config dict or string, but got {type(target_config)}")
raise ValueError(f"T.target expected a config dict or string, but got {type(target_config)}")
if host is not None and not isinstance(host, (str, dict, Target)):
raise ValueError("T.target expected the host to be "
"a config dict, string, or T.target, "
f"but got {type(host)}")
raise ValueError(f"T.target expected the host to be a config dict, string, or T.target, but got {type(host)}")
if isinstance(target_config, dict) and "host" in target_config and host is not None:
raise ValueError("T.target expects to either receive the host "
"as part of the target's config dictionary, "
"or as a separate argument, but not both.")
raise ValueError(
"T.target expects to either receive the host "
"as part of the target's config dictionary, "
"or as a separate argument, but not both."
)
return Target(target_config, host)
......@@ -1742,7 +1736,6 @@ class meta_var: # pylint: disable=invalid-name
self.value = value
def __iter__(self):
def f():
for i in self.value:
yield meta_var(i)
......@@ -1754,7 +1747,6 @@ class meta_var: # pylint: disable=invalid-name
def _op_wrapper(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if "dtype" in kwargs:
......@@ -1874,7 +1866,6 @@ vscale = _op_wrapper(_tir_op.vscale)
def _dtype_forward(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if "dtype" in kwargs:
......
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
"""Atomic operations for tilelang."""
from __future__ import annotations
import tilelang.language as T
......@@ -18,10 +19,7 @@ _MEMORY_ORDER_ID_MAP = {
}
def atomic_max(dst: Buffer,
value: PrimExpr,
memory_order: str | None = None,
return_prev: bool = False) -> PrimExpr:
def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False) -> PrimExpr:
"""
Perform an atomic maximum on the value stored at dst with an optional memory-order.
......@@ -64,10 +62,7 @@ def atomic_max(dst: Buffer,
return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order])
def atomic_min(dst: Buffer,
value: PrimExpr,
memory_order: str | None = None,
return_prev: bool = False) -> PrimExpr:
def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False) -> PrimExpr:
"""
Atomically update the value at dst to the minimum of its current value and value.
......@@ -112,11 +107,7 @@ def atomic_min(dst: Buffer,
return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order])
def atomic_add(dst: Buffer,
value: PrimExpr,
memory_order: str | None = None,
return_prev: bool = False,
use_tma: bool = False) -> PrimExpr:
def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False, use_tma: bool = False) -> PrimExpr:
"""
Atomically add `value` into `dst`, returning a handle to the operation.
......@@ -191,8 +182,7 @@ def atomic_add(dst: Buffer,
if memory_order is None:
return T.call_extern(return_type, func_name, dst, value)
else:
return T.call_extern(return_type, func_name, dst, value,
_MEMORY_ORDER_ID_MAP[memory_order])
return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order])
if isinstance(dst, Buffer) and isinstance(value, Buffer):
ir.assert_structural_equal(dst.shape, value.shape)
......@@ -208,14 +198,12 @@ def atomic_add(dst: Buffer,
# Note: tile-region-based atomic operations don't support return_prev yet
# This would need to be implemented in the tile runtime
if return_prev:
raise NotImplementedError(
"return_prev is not supported for tile-region-based atomic operations")
raise NotImplementedError("return_prev is not supported for tile-region-based atomic operations")
if memory_order is None:
return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, 0)
else:
return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma,
_MEMORY_ORDER_ID_MAP[memory_order])
return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, _MEMORY_ORDER_ID_MAP[memory_order])
def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr:
......
"""The language interface for tl programs."""
from __future__ import annotations
from tilelang import tvm as tvm
......@@ -179,38 +180,32 @@ def set_max_nreg(reg_count: int, is_inc: int):
def inc_max_nreg(reg_count: int):
"""Increment the maximum number of registers to use.
"""
"""Increment the maximum number of registers to use."""
return set_max_nreg(reg_count, 1)
def dec_max_nreg(reg_count: int):
"""Decrement the maximum number of registers to use.
"""
"""Decrement the maximum number of registers to use."""
return set_max_nreg(reg_count, 0)
def annotate_producer_reg_dealloc(reg_count: int = 24):
"""Annotate the producer reg dealloc.
"""
"""Annotate the producer reg dealloc."""
return dec_max_nreg(reg_count)
def annotate_consumer_reg_alloc(reg_count: int = 240):
"""Annotate the consumer reg alloc.
"""
"""Annotate the consumer reg alloc."""
return inc_max_nreg(reg_count)
def no_set_max_nreg():
"""Disable the maximum register limit setting.
"""
"""Disable the maximum register limit setting."""
return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg"))
def disable_warp_group_reg_alloc():
"""Disable the warp group reg alloc.
"""
"""Disable the warp group reg alloc."""
return no_set_max_nreg()
......@@ -325,7 +320,9 @@ def warpgroup_wait(num_mma: int):
return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma)
def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
def get_lane_idx(
warp_size: int | PrimExpr | None = None,
) -> PrimExpr:
"""Return the logical lane index of the calling thread within a warp.
Parameters
......@@ -350,7 +347,9 @@ def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr)
def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
def get_warp_idx_sync(
warp_size: int | PrimExpr | None = None,
) -> PrimExpr:
"""Return the canonical warp index, assuming the warp's threads are converged.
Parameters
......@@ -374,7 +373,9 @@ def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr)
def get_warp_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
def get_warp_idx(
warp_size: int | PrimExpr | None = None,
) -> PrimExpr:
"""Return the canonical warp index without synchronizing the warp.
Parameters
......@@ -429,8 +430,7 @@ def get_warp_group_idx(
args.append(warp_size_expr)
if warps_per_group_expr is not None:
if warp_size_expr is None:
raise ValueError("get_warp_group_idx expects `warp_size` when specifying "
"`warps_per_group`.")
raise ValueError("get_warp_group_idx expects `warp_size` when specifying `warps_per_group`.")
args.append(warps_per_group_expr)
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_group_idx"), *args)
......@@ -459,10 +459,9 @@ def shuffle_elect(thread_extent: int) -> PrimExpr:
return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent)
def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
offset: int | PrimExpr = 0,
num_regs: int | PrimExpr | None = None,
dtype: str | None = None):
def warpgroup_fence_operand(
buffer_or_ptr: tir.Buffer | PrimExpr, offset: int | PrimExpr = 0, num_regs: int | PrimExpr | None = None, dtype: str | None = None
):
"""Insert a warpgroup fence for the destination accumulator registers.
This prevents NVCC from sinking uses of accumulator fragments past the corresponding
......@@ -517,7 +516,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr,
convert(offset),
convert(num_regs),
))
)
)
if isinstance(buffer_or_ptr, tir.Buffer):
data_ptr = buffer_or_ptr.data
......@@ -531,8 +531,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
if isinstance(dim, tir.IntImm):
total_elems *= int(dim)
else:
raise ValueError(
"warpgroup_fence_operand requires num_regs when buffer shape is symbolic.")
raise ValueError("warpgroup_fence_operand requires num_regs when buffer shape is symbolic.")
bits_per_elem = DataType(dtype).bits
num_regs = (total_elems * bits_per_elem + 31) // 32
elif isinstance(buffer_or_ptr, BufferRegion):
......@@ -569,9 +568,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
bits_per_elem = DataType(dtype).bits
num_regs = (total_elems * bits_per_elem + 31) // 32
else:
raise ValueError(
"warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic."
)
raise ValueError("warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic.")
return evaluate(
tir.call_intrin(
"handle",
......@@ -580,7 +577,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr,
convert(offset),
convert(num_regs),
))
)
)
else:
data_ptr = buffer_or_ptr
# Try to infer dtype from common pointer expressions when not provided
......@@ -603,9 +601,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
except Exception:
inferred = None
if inferred is None:
raise ValueError(
"dtype must be provided when passing a pointer expression and cannot be inferred."
)
raise ValueError("dtype must be provided when passing a pointer expression and cannot be inferred.")
dtype = inferred
if num_regs is None:
raise ValueError("num_regs must be provided when passing a pointer expression.")
......@@ -618,7 +614,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr,
convert(offset),
convert(num_regs),
))
)
)
def wait_wgmma(id: int):
......@@ -673,7 +670,7 @@ def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call
if _IS_HIP_AVAILABLE:
return tir.call_extern(value.dtype, "__shfl_xor", value, offset)
else:
return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xFFFFFFFF, value, offset)
def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call):
......@@ -686,7 +683,7 @@ def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Cal
if _IS_HIP_AVAILABLE:
return tir.call_extern(value.dtype, "__shfl_down", value, offset)
else:
return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
return tir.call_extern(value.dtype, "__shfl_down_sync", 0xFFFFFFFF, value, offset)
def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call):
......@@ -699,12 +696,11 @@ def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call)
if _IS_HIP_AVAILABLE:
return tir.call_extern(value.dtype, "__shfl_up", value, offset)
else:
return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
return tir.call_extern(value.dtype, "__shfl_up_sync", 0xFFFFFFFF, value, offset)
def sync_threads(barrier_id: int = None, arrive_count: int = None):
"""Synchronize all threads in a block.
"""
"""Synchronize all threads in a block."""
args = []
if barrier_id is not None:
args.append(barrier_id)
......@@ -714,8 +710,7 @@ def sync_threads(barrier_id: int = None, arrive_count: int = None):
def sync_global():
"""Synchronize all threads in the entire grid.
"""
"""Synchronize all threads in the entire grid."""
tx, ty, tz = get_thread_bindings()
ex, ey, ez = get_block_extents()
print(tx, ty, tz, ex, ey, ez)
......@@ -724,8 +719,7 @@ def sync_global():
def sync_grid():
"""Synchronize all threads in a grid.
"""
"""Synchronize all threads in a grid."""
return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid"))
......@@ -741,12 +735,10 @@ def initialize_wgmma_descriptor(
if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or
descriptor.shape[0] != 1):
if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
raise ValueError("Descriptor must be a 1D buffer of size 1.")
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(
descriptor, [0])
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0])
return evaluate(
tir.call_intrin(
......@@ -757,7 +749,8 @@ def initialize_wgmma_descriptor(
layout_type_,
int(leading_byte_offset),
int(stride_byte_offset),
))
)
)
def initialize_tcgen05_descriptor(
......@@ -774,12 +767,10 @@ def initialize_tcgen05_descriptor(
if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or
descriptor.shape[0] != 1):
if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
raise ValueError("Descriptor must be a 1D buffer of size 1.")
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(
descriptor, [0])
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0])
return evaluate(
tir.call_intrin(
......@@ -792,7 +783,8 @@ def initialize_tcgen05_descriptor(
int(base_offset),
tir.IntImm("int32", 1 if leading_is_absolute else 0),
int(swizzle_mode),
))
)
)
def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimExpr:
......@@ -809,27 +801,21 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx
if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
if isinstance(descriptor, tir.Buffer) and len(
descriptor.shape) != 1 or descriptor.shape[0] != 1:
if isinstance(descriptor, tir.Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
raise ValueError("Descriptor must be a 1D buffer of size 1.")
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(
descriptor, [0])
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(descriptor, [0])
return evaluate(
tir.call_intrin("handle", tir.op.Op.get("tl.increase_descriptor_offset"), descriptor,
offset))
return evaluate(tir.call_intrin("handle", tir.op.Op.get("tl.increase_descriptor_offset"), descriptor, offset))
def loop_break():
"""Break out of the innermost loop.
"""
"""Break out of the innermost loop."""
return tir.call_intrin("handle", tir.op.Op.get("tl.loop_break"))
def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call):
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
"""
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc."""
return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
......
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Literal
from tilelang import language as T
......@@ -10,11 +11,13 @@ from tilelang.utils.language import (
from tvm import ir, tir
def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
dst: tir.Buffer | tir.BufferLoad,
coalesced_width: int | None = None,
disable_tma: bool = False,
eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None):
def copy(
src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
dst: tir.Buffer | tir.BufferLoad,
coalesced_width: int | None = None,
disable_tma: bool = False,
eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None,
):
"""Copy data between memory regions.
Args:
......@@ -65,8 +68,7 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
src_extent = get_extent(src)
dst_extent = get_extent(dst)
# Combine the nested if statements into a single if statement as suggested by SIM102
if (src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and
isinstance(dst, tir.BufferLoad)):
if src_extent is None and dst_extent is None and isinstance(src, tir.BufferLoad) and isinstance(dst, tir.BufferLoad):
# check if the case is like this:
# copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes
# In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i]
......@@ -90,19 +92,20 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
eviction_policy = 0
else:
eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy]
return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width,
disable_tma, eviction_policy)
def c2d_im2col(img: tir.Buffer,
col: tir.Buffer,
nhw_step: tir.PrimExpr,
c_step: tir.PrimExpr,
kernel: int,
stride: int,
dilation: int,
pad: int,
eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None):
return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width, disable_tma, eviction_policy)
def c2d_im2col(
img: tir.Buffer,
col: tir.Buffer,
nhw_step: tir.PrimExpr,
c_step: tir.PrimExpr,
kernel: int,
stride: int,
dilation: int,
pad: int,
eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None,
):
"""Perform im2col transformation for 2D convolution.
Args:
......@@ -124,5 +127,16 @@ def c2d_im2col(img: tir.Buffer,
eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy]
img_region = to_buffer_region(img, access_type="r")
col_region = to_buffer_region(col, access_type="w")
return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.c2d_im2col"), img_region, col_region,
nhw_step, c_step, kernel, stride, dilation, pad, eviction_policy)
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.tileop.c2d_im2col"),
img_region,
col_region,
nhw_step,
c_step,
kernel,
stride,
dilation,
pad,
eviction_policy,
)
"""The language interface for tl programs."""
from __future__ import annotations
import tilelang.language as T
from tvm.tir import PrimExpr, Buffer, op
from tilelang.utils.language import (bits_product, prim_expr_equal)
from tilelang.utils.language import bits_product, prim_expr_equal
from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401
......@@ -46,9 +47,9 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer:
Returns:
Buffer: A new buffer view with the specified shape
"""
assert prim_expr_equal(
bits_product(shape, src.dtype), bits_product(src.shape, src.dtype)
), f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {src.dtype}"
assert prim_expr_equal(bits_product(shape, src.dtype), bits_product(src.shape, src.dtype)), (
f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {src.dtype}"
)
return T.Tensor(shape, src.dtype, src.data)
......@@ -61,8 +62,7 @@ def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = N
shape = src.shape
if dtype is None:
dtype = src.dtype
assert prim_expr_equal(bits_product(shape, dtype),
bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed."
assert prim_expr_equal(bits_product(shape, dtype), bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed."
return T.Tensor(shape, dtype, src.data)
......
"""The language interface for tl programs."""
from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
......@@ -11,7 +12,8 @@ from tilelang.utils.language import (
prim_expr_equal,
)
from tilelang.language.utils import (
buffer_region_to_tile_region,)
buffer_region_to_tile_region,
)
def gemm_sp(
......@@ -169,18 +171,19 @@ def gemm_sp_v2(
assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor"
if len(A_shape) > 2:
for i in range(len(A_shape) - 2):
assert A_shape[i] == 1, \
assert A_shape[i] == 1, (
"current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
if len(B_shape) > 2:
for i in range(len(B_shape) - 2):
assert B_shape[i] == 1, \
assert B_shape[i] == 1, (
"current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
M, N = C_shape
K = 2 * (A_shape[-2] if transpose_A else A_shape[-1])
K_B = B_shape[-1] if transpose_B else B_shape[-2]
assert prim_expr_equal(
K, K_B), f"T.gemm_sp K shape check failed: K_A (wo sparse) = {K}, K_B = {K_B}"
assert prim_expr_equal(K, K_B), f"T.gemm_sp K shape check failed: K_A (wo sparse) = {K}, K_B = {K_B}"
stride_a = A_stride[-2]
stride_b = B_stride[-2]
......
"""The language interface for tl programs."""
from __future__ import annotations
from tvm import tir
from tilelang.language import has_let_value, get_let_value
......@@ -32,8 +33,7 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim
extents = [tir.IntImm("int32", 1) for _ in buffer.indices]
else:
extents = []
return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.fill"),
to_buffer_region(buffer, access_type="w", extents=extents), value)
return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.fill"), to_buffer_region(buffer, access_type="w", extents=extents), value)
def clear(buffer: tir.Buffer | tir.Var):
......@@ -55,8 +55,7 @@ def clear(buffer: tir.Buffer | tir.Var):
elif isinstance(buffer_region, tir.BufferLoad):
region = get_buffer_region_from_load(buffer_region)
if region is None:
raise ValueError(
f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}")
raise ValueError(f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}")
return fill(region, 0)
else:
raise ValueError(f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}")
......
"""Override the LetFrame to print a message when entering the frame."""
from __future__ import annotations
from tvm.ffi import register_object as _register_object
from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion
......@@ -29,7 +30,7 @@ class FrameStack:
item: The frame object to push onto the stack
"""
self._stack.append(item)
if hasattr(item, 'var') and hasattr(item, 'value'):
if hasattr(item, "var") and hasattr(item, "value"):
self._var_value_map[item.var] = item.value
def pop(self):
......@@ -43,7 +44,7 @@ class FrameStack:
"""
if self._stack:
item = self._stack.pop()
if hasattr(item, 'var'):
if hasattr(item, "var"):
self._var_value_map.pop(item.var, None)
return item
raise IndexError(f"{self.__class__.__name__} is empty")
......@@ -129,8 +130,7 @@ class LetFrame(TIRFrame):
is_block_load = True
break
if is_block_load:
self.value = BufferRegion(self.value.buffer,
[Range(x.base, x.lanes) for x in indices])
self.value = BufferRegion(self.value.buffer, [Range(x.base, x.lanes) for x in indices])
_get_let_stack().push(self)
return self.var
......
"""The language interface for tl programs."""
from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
......@@ -11,7 +12,8 @@ from tilelang.utils.language import (
prim_expr_equal,
)
from tilelang.language.utils import (
buffer_region_to_tile_region,)
buffer_region_to_tile_region,
)
from tilelang.env import env as _env
......@@ -68,12 +70,14 @@ def _gemm_impl(
assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor"
if len(A_shape) > 2:
for i in range(len(A_shape) - 2):
assert A_shape[i] == 1, \
assert A_shape[i] == 1, (
"current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
if len(B_shape) > 2:
for i in range(len(B_shape) - 2):
assert B_shape[i] == 1, \
assert B_shape[i] == 1, (
"current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
M, N = C_shape
K = A_shape[-2] if transpose_A else A_shape[-1]
......@@ -96,9 +100,29 @@ def _gemm_impl(
A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape])
B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape])
C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape])
return tir.call_intrin("handle", tir.op.Op.get(op_key), A_arg, B_arg, C_arg, transpose_A,
transpose_B, M, N, K, policy, clear_accum, stride_a, stride_b, offset_a,
offset_b, k_pack, wg_wait, mbar, C_coords[0], C_coords[1])
return tir.call_intrin(
"handle",
tir.op.Op.get(op_key),
A_arg,
B_arg,
C_arg,
transpose_A,
transpose_B,
M,
N,
K,
policy,
clear_accum,
stride_a,
stride_b,
offset_a,
offset_b,
k_pack,
wg_wait,
mbar,
C_coords[0],
C_coords[1],
)
# Public wrappers
......
"""The language interface for tl programs."""
from __future__ import annotations
from collections import deque
from tvm import tir
......@@ -107,8 +108,7 @@ class KernelLaunchFrame(TIRFrame):
_get_current_stack().push(self)
last_block_frame = self.frames[-1]
assert isinstance(last_block_frame,
BlockFrame), f"Last frame must be a block frame, got {last_block_frame}"
assert isinstance(last_block_frame, BlockFrame), f"Last frame must be a block frame, got {last_block_frame}"
maybe_cpu = last_block_frame.annotations.get("tilelang.is_cpu_kernel_frame", False)
......@@ -303,56 +303,48 @@ def Kernel(
def get_thread_binding(dim: int = 0) -> Var:
"""Returns the thread binding for the given dimension.
"""
"""Returns the thread binding for the given dimension."""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_binding(dim)
def get_thread_bindings() -> list[Var]:
"""Returns all three thread bindings.
"""
"""Returns all three thread bindings."""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_bindings()
def get_block_binding(dim: int = 0) -> Var:
"""Returns the block binding for the given dimension.
"""
"""Returns the block binding for the given dimension."""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_binding(dim)
def get_block_bindings() -> list[Var]:
"""Returns all three block bindings.
"""
"""Returns all three block bindings."""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_bindings()
def get_thread_extent(dim: int = 0) -> int:
"""Returns the thread extent for the given dimension.
"""
"""Returns the thread extent for the given dimension."""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_extent(dim)
def get_thread_extents() -> list[int]:
"""Returns all three thread extents.
"""
"""Returns all three thread extents."""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_extents()
def get_block_extent(dim: int = 0) -> int:
"""Returns the block extent for the given dimension.
"""
"""Returns the block extent for the given dimension."""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_extent(dim)
def get_block_extents() -> list[int]:
"""Returns all three block extents.
"""
"""Returns all three block extents."""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_extents()
"""The language interface for tl programs."""
from __future__ import annotations
from tilelang import language as T
......@@ -36,8 +37,7 @@ def any_of(buffer: T.Tensor | BufferRegion):
)
new_region.append(r.min)
buffer_load = BufferLoad(buffer, new_region)
return T.call_intrin(return_type, tir.op.Op.get("tl.any_of"), T.address_of(buffer_load),
extent)
return T.call_intrin(return_type, tir.op.Op.get("tl.any_of"), T.address_of(buffer_load), extent)
else:
raise ValueError(f"Invalid buffer type: {type(buffer)}")
......@@ -71,7 +71,6 @@ def all_of(buffer: T.Tensor | BufferRegion):
)
new_region.append(r.min)
buffer_load = BufferLoad(buffer, new_region)
return T.call_intrin(return_type, tir.op.Op.get("tl.all_of"), T.address_of(buffer_load),
extent)
return T.call_intrin(return_type, tir.op.Op.get("tl.all_of"), T.address_of(buffer_load), extent)
else:
raise ValueError(f"Invalid buffer type: {type(buffer)}")
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Any
from tvm import tir
......@@ -94,11 +95,9 @@ def Pipelined(
return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group)
def serial(start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
annotations: dict[str, Any] | None = None) -> frame.ForFrame:
def serial(
start: tir.PrimExpr, stop: tir.PrimExpr | None = None, step: tir.PrimExpr | None = None, *, annotations: dict[str, Any] | None = None
) -> frame.ForFrame:
step_is_one = False
step_is_one |= isinstance(step, int) and step == 1
step_is_one |= isinstance(step, IntImm) and step.value == 1
......@@ -111,13 +110,15 @@ def serial(start: tir.PrimExpr,
return SerialForWithStep(start, stop, step, annotations=annotations)
def unroll(start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
explicit: bool = False,
unroll_factor: int | None = None,
annotations: dict[str, Any] | None = None) -> frame.ForFrame:
def unroll(
start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
explicit: bool = False,
unroll_factor: int | None = None,
annotations: dict[str, Any] | None = None,
) -> frame.ForFrame:
"""The unrolled For statement.
Parameters
......
......@@ -3,7 +3,7 @@ from tvm import tir
def _validate_rounding_mode(rounding_mode):
"""Validate that the rounding mode is one of the supported IEEE modes"""
valid_modes = {'rn', 'rz', 'ru', 'rd'}
valid_modes = {"rn", "rz", "ru", "rd"}
if isinstance(rounding_mode, str) and rounding_mode in valid_modes:
return
raise ValueError(f"Invalid rounding mode '{rounding_mode}'. Must be one of: {valid_modes}")
......
"""TVMScript parser overrides tailored for TileLang."""
from functools import partial
from tvm.script.ir_builder import tir as T
......@@ -58,8 +59,12 @@ def tilelang_visit_assign(self, node: doc.Assign) -> None: # pylint: disable=un
lhs.ctx = load_ctx
lhs_value = self.eval_expr(lhs)
lhs.ctx = store_ctx
if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and
len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0):
if (
isinstance(lhs_value, BufferLoad)
and lhs_value.buffer.scope() == "local.var"
and len(lhs_value.indices) == 1
and lhs_value.indices[0] == 0
):
T.buffer_store(lhs_value.buffer, rhs, indices=[0])
continue
......@@ -106,8 +111,12 @@ def tilelang_visit_aug_assign(self, node: doc.AugAssign) -> None: # pylint: dis
lhs.ctx = load_ctx
lhs_value = self.eval_expr(lhs)
lhs.ctx = store_ctx
if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and
len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0):
if (
isinstance(lhs_value, BufferLoad)
and lhs_value.buffer.scope() == "local.var"
and len(lhs_value.indices) == 1
and lhs_value.indices[0] == 0
):
T.buffer_store(lhs_value.buffer, rhs, indices=[0])
return
......@@ -131,8 +140,12 @@ def tilelang_visit_ann_assign(self, node: doc.AnnAssign) -> None: # pylint: dis
lhs.ctx = load_ctx
lhs_value = self.eval_expr(lhs)
lhs.ctx = store_ctx
if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and
len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0):
if (
isinstance(lhs_value, BufferLoad)
and lhs_value.buffer.scope() == "local.var"
and len(lhs_value.indices) == 1
and lhs_value.indices[0] == 0
):
T.buffer_store(lhs_value.buffer, rhs, indices=[0])
return
......
......@@ -18,6 +18,7 @@
# which is part of the TVM project (https://tvm.apache.org/).
# ruff: noqa
"""The entry point of TVM parser for tir."""
import inspect
from typing import Callable, Optional, Union
......@@ -29,9 +30,7 @@ from tvm.script.parser._core import parse, scan_macro, utils
from tvm.script.parser.core.parser import Parser, ScriptMacro
def prim_func(func: Optional[Callable] = None,
private: bool = False,
check_well_formed=True) -> Union[PrimFunc, Callable]:
def prim_func(func: Optional[Callable] = None, private: bool = False, check_well_formed=True) -> Union[PrimFunc, Callable]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters
......@@ -149,8 +148,7 @@ def macro(*args, hygienic: bool = True) -> Callable:
if len(args) == 1 and inspect.isfunction(args[0]):
return _decorator(args[0])
raise ValueError(
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])")
raise ValueError("Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])")
class BufferProxy:
......
......@@ -17,6 +17,7 @@
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""The tir expression operation registration"""
from tvm import tir
from tvm.ffi.runtime_ctypes import DataType, DataTypeCode
from tvm.tir import IntImm
......@@ -55,11 +56,9 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
return dtype[0:index]
def _auto_broadcast(a, b, op):
if isinstance(a, int):
if hasattr(b, "dtype"):
if (DataType(b.dtype).type_code == DataTypeCode.INT or
DataType(b.dtype).type_code == DataTypeCode.UINT):
if DataType(b.dtype).type_code == DataTypeCode.INT or DataType(b.dtype).type_code == DataTypeCode.UINT:
a = IntImm(_get_type_str(b.dtype), a)
elif DataType(b.dtype).type_code == DataTypeCode.FLOAT:
a = FloatImm(_get_type_str(b.dtype), a)
......@@ -75,8 +74,7 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
assert isinstance(a, tir.PrimExpr), "Operand should be a PrimExpr."
if isinstance(b, int):
if (DataType(a.dtype).type_code == DataTypeCode.INT or
DataType(a.dtype).type_code == DataTypeCode.UINT):
if DataType(a.dtype).type_code == DataTypeCode.INT or DataType(a.dtype).type_code == DataTypeCode.UINT:
b = IntImm(_get_type_str(a.dtype), b)
elif DataType(a.dtype).type_code == DataTypeCode.FLOAT:
b = FloatImm(_get_type_str(a.dtype), b)
......@@ -85,10 +83,10 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
if DataType(a.dtype).lanes == DataType(b.dtype).lanes:
return op(a, b)
elif (DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes):
elif DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes:
broadcast_a = tir.Broadcast(a, DataType(b.dtype).lanes)
return op(broadcast_a, b)
elif (DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes):
elif DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes:
broadcast_b = tir.Broadcast(b, DataType(a.dtype).lanes)
return op(a, broadcast_b)
else:
......
......@@ -146,8 +146,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -
res = value.__enter__()
IRBuilder.name(var_name, res)
return res
elif isinstance(value, (Buffer, IterVar)) or (isinstance(value, Var) and
not self.var_table.exist(value)):
elif isinstance(value, (Buffer, IterVar)) or (isinstance(value, Var) and not self.var_table.exist(value)):
IRBuilder.name(var_name, value)
return value
else:
......@@ -191,8 +190,7 @@ def visit_for(self: Parser, node: doc.For) -> None:
if not isinstance(for_frame, T.frame.ForFrame):
self.report_error(
node.iter,
"Expect the for loop to be one of the following: "
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
"Expect the for loop to be one of the following: range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
)
with self.var_table.with_frame():
with for_frame as iters:
......@@ -361,8 +359,7 @@ def visit_with(self: Parser, node: doc.With) -> None:
for item in node.items:
frame = self.eval_expr(item.context_expr)
if not isinstance(frame, Frame):
self.report_error(item.context_expr,
"Invalid context expression in the with-statement.")
self.report_error(item.context_expr, "Invalid context expression in the with-statement.")
rhs = stack.enter_context(frame)
if item.optional_vars is not None:
self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value)
......@@ -505,8 +502,7 @@ def visit_if(self: Parser, node: doc.If) -> None:
with self.var_table.with_frame():
self.visit_body(node.orelse)
else:
self.report_error(node.test,
f"If condition must be a boolean expression, but got {predicate}")
self.report_error(node.test, f"If condition must be a boolean expression, but got {predicate}")
@dispatch.register(token="tir", type_name="Assert")
......
......@@ -26,9 +26,7 @@ def print_var(var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr:
@macro
def print_var_with_condition(condition: tir.PrimExpr,
var: tir.PrimExpr,
msg: str = "") -> tir.PrimExpr:
def print_var_with_condition(condition: tir.PrimExpr, var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr:
"""
Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True.
......@@ -44,10 +42,7 @@ def print_var_with_condition(condition: tir.PrimExpr,
@macro
def print_global_buffer_with_condition(condition: tir.PrimExpr,
buffer: tir.Buffer,
elems: int,
msg: str = "") -> tir.PrimExpr:
def print_global_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
"""
......@@ -55,17 +50,13 @@ def print_global_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one.
for i in serial(elems):
coords = index_to_coordinates(i, buffer.shape)
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i,
buffer[coords])
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
else:
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
@macro
def print_shared_buffer_with_condition(condition: tir.PrimExpr,
buffer: tir.Buffer,
elems: int,
msg: str = "") -> tir.PrimExpr:
def print_shared_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
......@@ -81,15 +72,11 @@ def print_shared_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one.
for i in serial(elems):
coords = index_to_coordinates(i, buffer.shape)
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i,
buffer[coords])
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
@macro
def print_fragment_buffer_with_condition(condition: tir.PrimExpr,
buffer: tir.Buffer,
elems: int,
msg: str = "") -> tir.PrimExpr:
def print_fragment_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
......@@ -111,10 +98,7 @@ def print_fragment_buffer_with_condition(condition: tir.PrimExpr,
@macro
def print_local_buffer_with_condition(condition: tir.PrimExpr,
buffer: tir.Buffer,
elems: int,
msg: str = "") -> tir.PrimExpr:
def print_local_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
......@@ -130,8 +114,7 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one.
for i in serial(elems):
coords = index_to_coordinates(i, buffer.shape)
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i,
buffer[coords])
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
from tilelang.utils.target import check_cuda_availability
......@@ -201,7 +184,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
elems *= dim
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
condition = (tx == main_lane and ty == 0 and tz == 0)
condition = tx == main_lane and ty == 0 and tz == 0
if not msg:
msg = f"buffer<{buffer.name}, {buffer.dtype}>"
return print_fragment_buffer_with_condition(condition, buffer, elems, msg)
......@@ -212,7 +195,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
elems *= dim
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
condition = (tx == main_lane and ty == 0 and tz == 0)
condition = tx == main_lane and ty == 0 and tz == 0
if not msg:
msg = f"buffer<{buffer.name}, {buffer.dtype}>"
return print_shared_buffer_with_condition(condition, buffer, elems, msg)
......@@ -234,5 +217,4 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
else:
# Unsupported object type.
raise ValueError(
f"Unexpected type: {type(obj)}. Supported types are tir.Buffer and tir.PrimExpr.")
raise ValueError(f"Unexpected type: {type(obj)}. Supported types are tir.Buffer and tir.PrimExpr.")
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Any, SupportsIndex, TYPE_CHECKING, Generic, TypeVar
......@@ -51,11 +52,9 @@ class BufferProxy:
return self(keys)
return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member
def from_ptr(self,
pointer_var: Var,
shape: tuple[PrimExpr, ...],
dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> Buffer:
def from_ptr(
self, pointer_var: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None
) -> Buffer:
"""Create a buffer from a pointer, shape, and data type.
Args:
......@@ -76,6 +75,7 @@ class BaseTensorProxy:
customizable default values for scope, alignment, and offset factors. It implements
the core functionality for creating TIR buffers with specific memory configurations.
"""
default_scope = "global"
default_align = 0
default_offset_factor = 0
......@@ -118,11 +118,9 @@ class BaseTensorProxy:
keys = (keys,)
return self(*keys)
def from_ptr(self,
pointer_var: Var,
shape: tuple[PrimExpr, ...],
dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> tir.Buffer:
def from_ptr(
self, pointer_var: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None
) -> tir.Buffer:
"""Create a buffer from a pointer, shape, and data type.
Args:
......@@ -151,19 +149,10 @@ class TensorProxy(BaseTensorProxy):
strides.append(s)
return tuple(reversed(strides))
def __call__(self,
shape: tuple[Any] | PrimExpr | int,
dtype: str = "float32",
data=None,
scope=None) -> tir.Buffer:
def __call__(self, shape: tuple[Any] | PrimExpr | int, dtype: str = "float32", data=None, scope=None) -> tir.Buffer:
if isinstance(shape, (int, PrimExpr)):
shape = (shape,)
return super().__call__(
shape,
dtype=dtype,
strides=TensorProxy._construct_strides(shape),
data=data,
scope=scope)
return super().__call__(shape, dtype=dtype, strides=TensorProxy._construct_strides(shape), data=data, scope=scope)
class StridedTensorProxy(BaseTensorProxy):
......@@ -172,11 +161,7 @@ class StridedTensorProxy(BaseTensorProxy):
This class implements the default tensor proxy with global memory scope, with the stride information required.
"""
def __call__(self,
shape: tuple[Any],
strides: tuple[Any],
dtype: str = "float32",
scope=None) -> tir.Buffer:
def __call__(self, shape: tuple[Any], strides: tuple[Any], dtype: str = "float32", scope=None) -> tir.Buffer:
if len(shape) != len(strides):
raise ValueError("Invalid shape/strides' dimensions")
return super().__call__(shape, dtype=dtype, strides=strides, scope=scope)
......@@ -188,6 +173,7 @@ class FragmentBufferProxy(BaseTensorProxy):
This class represents tensor proxies specifically for local fragment memory,
typically used in GPU tensor core operations.
"""
default_scope = "local.fragment"
......@@ -197,6 +183,7 @@ class SharedBufferProxy(BaseTensorProxy):
This class represents tensor proxies for dynamic shared memory,
commonly used in GPU shared memory operations.
"""
default_scope = "shared.dyn"
......@@ -206,6 +193,7 @@ class LocalBufferProxy(BaseTensorProxy):
This class represents tensor proxies for local memory scope,
typically used for temporary computations in GPU kernels.
"""
default_scope = "local"
......@@ -216,15 +204,12 @@ Buffer = BufferProxy() # pylint: disable=invalid-name
if TYPE_CHECKING:
class BaseTensor:
def __class_getitem__(cls, key):
return cls
def __getitem__(self, key) -> Any:
...
def __getitem__(self, key) -> Any: ...
def __setitem__(self, key, value) -> None:
...
def __setitem__(self, key, value) -> None: ...
def __init__(
self,
......@@ -238,36 +223,26 @@ if TYPE_CHECKING:
offset_factor=None,
buffer_type="",
axis_separators=None,
):
...
): ...
@classmethod
def from_ptr(cls,
pointer_var: Var,
shape: Sequence[PrimExpr, ...],
dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> Self:
...
def from_ptr(
cls, pointer_var: Var, shape: Sequence[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None
) -> Self: ...
class Tensor(BaseTensor):
...
class Tensor(BaseTensor): ...
class StridedTensor(BaseTensor):
...
class StridedTensor(BaseTensor): ...
class FragmentBuffer(BaseTensor):
...
class FragmentBuffer(BaseTensor): ...
class SharedBuffer(BaseTensor):
...
class SharedBuffer(BaseTensor): ...
class LocalBuffer(BaseTensor):
...
class LocalBuffer(BaseTensor): ...
_T = TypeVar('_T')
_T = TypeVar("_T")
class Ref(Generic[_T], tir.Var):
...
class Ref(Generic[_T], tir.Var): ...
else:
Tensor = TensorProxy() # pylint: disable=invalid-name
StridedTensor = StridedTensorProxy() # pylint: disable=invalid-name
......@@ -275,14 +250,10 @@ else:
SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name
LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name
class Ref:
...
class Ref: ...
def ptr(dtype: str | None = None,
storage_scope: str = "global",
*,
is_size_var: bool = False) -> Var:
def ptr(dtype: str | None = None, storage_scope: str = "global", *, is_size_var: bool = False) -> Var:
"""Create a TIR var that represents a pointer.
Parameters
......@@ -304,8 +275,5 @@ def ptr(dtype: str | None = None,
return handle(dtype=dtype, storage_scope=storage_scope, is_size_var=is_size_var)
def make_tensor(ptr: Var,
shape: tuple[PrimExpr, ...],
dtype: str = "float32",
strides: tuple[PrimExpr, ...] = None) -> tir.Buffer:
def make_tensor(ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32", strides: tuple[PrimExpr, ...] = None) -> tir.Buffer:
return Tensor.from_ptr(ptr, shape, dtype, strides)