Unverified Commit 29051439 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
"""The language interface for tl programs."""
from __future__ import annotations
from tvm import tir
from tilelang.language import copy, macro, alloc_shared, alloc_fragment
......@@ -30,15 +31,13 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
tir.Call: Handle to the reduction operation
"""
# input shape: [X, d, Y], expected output shape: [X, Y] or [X, 1, Y]
expected_shapes = [
buffer.shape[:dim] + buffer.shape[dim + 1:],
buffer.shape[:dim] + [1] + buffer.shape[dim + 1:]
]
expected_shapes = [buffer.shape[:dim] + buffer.shape[dim + 1 :], buffer.shape[:dim] + [1] + buffer.shape[dim + 1 :]]
if list(out.shape) not in expected_shapes:
expected_shapes_str = ' or '.join(map(str, expected_shapes))
expected_shapes_str = " or ".join(map(str, expected_shapes))
raise ValueError(
f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, "
f"output shape is {out.shape}, expected shapes are {expected_shapes_str}")
f"output shape is {out.shape}, expected shapes are {expected_shapes_str}"
)
@macro
def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
......
......@@ -7,9 +7,7 @@ from tvm.tir.function import PrimFunc
from tvm.script.parser._core import parse, scan_macro, utils
def prim_func(func: Callable | None = None,
private: bool = False,
check_well_formed: bool = False) -> PrimFunc | Callable:
def prim_func(func: Callable | None = None, private: bool = False, check_well_formed: bool = False) -> PrimFunc | Callable:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters
......@@ -113,8 +111,7 @@ def macro(*args, hygienic: bool = True) -> Callable:
if len(args) == 1 and inspect.isfunction(args[0]):
return _decorator(args[0])
raise ValueError(
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])")
raise ValueError("Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])")
setattr(macro, "dispatch_token", "tir") # noqa: B010
......@@ -6,10 +6,7 @@ import tilelang.language.tir.op as _tir_op
import functools
def serial(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: dict[str, Any] = None) -> frame.ForFrame:
def serial(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The serial For statement.
Parameters
......@@ -31,10 +28,7 @@ def serial(start: PrimExpr,
return _ir.serial(start=start, stop=stop, annotations=annotations)
def parallel(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: dict[str, Any] = None) -> frame.ForFrame:
def parallel(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The parallel For statement.
Parameters
......@@ -56,10 +50,7 @@ def parallel(start: PrimExpr,
return _ir.parallel(start=start, stop=stop, annotations=annotations)
def vectorized(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: dict[str, Any] = None) -> frame.ForFrame:
def vectorized(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The vectorized For statement.
Parameters
......@@ -81,10 +72,7 @@ def vectorized(start: PrimExpr,
return _ir.vectorized(start=start, stop=stop, annotations=annotations)
def unroll(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: dict[str, Any] = None) -> frame.ForFrame:
def unroll(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The unrolled For statement.
Parameters
......@@ -161,7 +149,6 @@ def grid(*extents: PrimExpr) -> frame.ForFrame:
def _dtype_forward(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if "dtype" in kwargs:
......@@ -172,7 +159,6 @@ def _dtype_forward(func):
def _op_wrapper(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if "dtype" in kwargs:
......
from typing import TypeVar, Literal
from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm
_T = TypeVar('_T')
_T = TypeVar("_T")
def abs(x: _T, span: Span | None=None) -> _T: ...
def abs(x: _T, span: Span | None = None) -> _T: ...
def acos(x: _T) -> _T: ...
def acosh(x: _T) -> _T: ...
def address_of(buffer_load: BufferLoad, span: Span | None=None) -> PrimExpr: ...
def address_of(buffer_load: BufferLoad, span: Span | None = None) -> PrimExpr: ...
def asin(x: _T) -> _T: ...
def asinh(x: _T) -> _T: ...
def atan(x: _T) -> _T: ...
def atan2(x1: _T, x2: _T) -> _T: ...
def atanh(x: _T) -> _T: ...
def bitwise_and(x: _T, y: _T, span: Span | None=None) -> _T: ...
def bitwise_not(x: _T, span: Span | None=None) -> _T: ...
def bitwise_or(x: _T, y: _T, span: Span | None=None) -> _T: ...
def bitwise_xor(x: _T, y: _T, span: Span | None=None) -> _T: ...
def ceil(x: _T, span: Span | None=None) -> _T: ...
def bitwise_and(x: _T, y: _T, span: Span | None = None) -> _T: ...
def bitwise_not(x: _T, span: Span | None = None) -> _T: ...
def bitwise_or(x: _T, y: _T, span: Span | None = None) -> _T: ...
def bitwise_xor(x: _T, y: _T, span: Span | None = None) -> _T: ...
def ceil(x: _T, span: Span | None = None) -> _T: ...
def clz(x: _T) -> _T: ...
def copysign(x1: _T, x2: _T) -> _T: ...
def cos(x: _T) -> _T: ...
......@@ -25,35 +25,37 @@ def erf(x: _T) -> _T: ...
def exp(x: _T) -> _T: ...
def exp2(x: _T) -> _T: ...
def exp10(x: _T) -> _T: ...
def floor(x: _T, span: Span | None=None) -> _T: ...
def ceildiv(lhs: _T, rhs: _T, span: Span | None=None) -> _T: ...
def floordiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
def floormod(a: _T, b: _T, span: Span | None=None) -> _T: ...
def floor(x: _T, span: Span | None = None) -> _T: ...
def ceildiv(lhs: _T, rhs: _T, span: Span | None = None) -> _T: ...
def floordiv(a: _T, b: _T, span: Span | None = None) -> _T: ...
def floormod(a: _T, b: _T, span: Span | None = None) -> _T: ...
def fmod(x: _T, y: _T) -> _T: ...
def hypot(x1: _T, x2: _T) -> _T: ...
def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None=None) -> _T: ...
def infinity(dtype: _T, span: Span | None=None) -> _T: ...
def isfinite(x: _T, span: Span | None=None) -> _T: ...
def isinf(x: _T, span: Span | None=None) -> _T: ...
def isnan(x: _T, span: Span | None=None) -> _T: ...
def isnullptr(x: _T, span: Span | None=None) -> _T: ...
def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None = None) -> _T: ...
def infinity(dtype: _T, span: Span | None = None) -> _T: ...
def isfinite(x: _T, span: Span | None = None) -> _T: ...
def isinf(x: _T, span: Span | None = None) -> _T: ...
def isnan(x: _T, span: Span | None = None) -> _T: ...
def isnullptr(x: _T, span: Span | None = None) -> _T: ...
def ldexp(x1: _T, x2: _T) -> _T: ...
def likely(cond: _T, span: Span | None=None) -> _T: ...
def likely(cond: _T, span: Span | None = None) -> _T: ...
def log(x: _T) -> _T: ...
def log1p(x: _T) -> _T: ...
def log2(x: _T) -> _T: ...
def log10(x: _T) -> _T: ...
def lookup_param(param_name: str, span: Span | None=None) -> PrimExpr: ...
def max_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
def min_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
def nearbyint(x: _T, span: Span | None=None) -> _T: ...
def lookup_param(param_name: str, span: Span | None = None) -> PrimExpr: ...
def max_value(dtype: str, span: Span | None = None) -> PrimExpr: ...
def min_value(dtype: str, span: Span | None = None) -> PrimExpr: ...
def nearbyint(x: _T, span: Span | None = None) -> _T: ...
def nextafter(x1: _T, x2: _T) -> _T: ...
def popcount(x: _T) -> _T: ...
def pow(x: _T, y: _T, span: Span | None=None) -> _T: ...
def pow(x: _T, y: _T, span: Span | None = None) -> _T: ...
def q_multiply_shift(x: _T, y: _T, q: _T, s: _T) -> _T: ...
def q_multiply_shift_per_axis(x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm) -> PrimExpr: ...
def q_multiply_shift_per_axis(
x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm
) -> PrimExpr: ...
def ret(val: _T) -> _T: ...
def round(x: _T, span: Span | None=None) -> _T: ...
def round(x: _T, span: Span | None = None) -> _T: ...
def rsqrt(x: _T) -> _T: ...
def shift_left(x: _T, y: _T, span=None) -> _T: ...
def shift_right(x: _T, y: _T, span=None) -> _T: ...
......@@ -63,14 +65,16 @@ def sinh(x: _T) -> _T: ...
def sqrt(x: _T) -> _T: ...
def tan(x: _T) -> _T: ...
def tanh(x: _T) -> _T: ...
def trunc(x: _T, span: Span | None=None) -> _T: ...
def truncdiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
def truncmod(a: _T, b: _T, span: Span | None=None) -> _T: ...
def trunc(x: _T, span: Span | None = None) -> _T: ...
def truncdiv(a: _T, b: _T, span: Span | None = None) -> _T: ...
def truncmod(a: _T, b: _T, span: Span | None = None) -> _T: ...
def tvm_access_ptr(ptype: PrimExpr, data, offset: int, extent: int, rw_mask: int) -> PrimExpr: ...
def tvm_throw_last_error() -> _T: ...
def tvm_stack_alloca(dtype_str: str, num: int) -> PrimExpr: ...
def tvm_stack_make_shape(*args) -> _T: ...
def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset) -> PrimExpr: ...
def tvm_stack_make_array(
data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset
) -> PrimExpr: ...
def tvm_check_return(expected: int, return_unexpected: int, nested_call: PrimExpr) -> PrimExpr: ...
def call_packed(*args, span=None) -> _T: ...
def call_cpacked(*args, span=None) -> _T: ...
......@@ -80,11 +84,47 @@ def tvm_tuple(*value) -> _T: ...
def tvm_struct_set(arr, index: int, field: int, value: PrimExpr) -> PrimExpr: ...
def tvm_thread_invariant(cond: _T) -> _T: ...
def tvm_thread_allreduce(*freduce_args) -> _T: ...
def tvm_load_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
def tvm_mma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
def tvm_bmma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
def tvm_load_matrix_sync(
fragment: Var,
m: IntImm,
n: IntImm,
k: IntImm,
index: PrimExpr,
buffer_ptr: PrimExpr,
stride: PrimExpr,
layout: Literal["row_major", "column_major"],
) -> PrimExpr: ...
def tvm_mma_sync(
fragment_d: Var,
index_d: PrimExpr,
fragment_a: Var,
index_a: PrimExpr,
fragment_b: Var,
index_b: PrimExpr,
fragment_c: Var,
index_c: PrimExpr,
) -> PrimExpr: ...
def tvm_bmma_sync(
fragment_d: Var,
index_d: PrimExpr,
fragment_a: Var,
index_a: PrimExpr,
fragment_b: Var,
index_b: PrimExpr,
fragment_c: Var,
index_c: PrimExpr,
) -> PrimExpr: ...
def tvm_fill_fragment(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, value: PrimExpr) -> PrimExpr: ...
def tvm_store_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
def tvm_store_matrix_sync(
fragment: Var,
m: IntImm,
n: IntImm,
k: IntImm,
index: PrimExpr,
buffer_ptr: PrimExpr,
stride: PrimExpr,
layout: Literal["row_major", "column_major"],
) -> PrimExpr: ...
def ptx_wait_group(num: int) -> PrimExpr: ...
def ptx_commit_group() -> _T: ...
def ptx_cp_async_barrier(barrier_id: int) -> PrimExpr: ...
......@@ -93,7 +133,7 @@ def ptx_arrive_barrier(barrier_id: int) -> PrimExpr: ...
def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ...
def ptx_wait_barrier(barrier_id: int) -> PrimExpr: ...
def create_barriers(barrier_count: int) -> PrimExpr: ...
def assume(cond: _T=None) -> _T: ...
def assume(cond: _T = None) -> _T: ...
def undef() -> _T: ...
def TVMBackendAllocWorkspace(device_type: int, device_id: int, nbytes: int, dtype_code_hint: int, dtype_bits_hint: int) -> PrimExpr: ...
def TVMBackendFreeWorkspace(device_type: int, device_id: int, ptr: Var) -> PrimExpr: ...
......
......@@ -724,8 +724,7 @@ def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
return _tvm_op.tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)
def tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c,
index_c):
def tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c):
"""TVM intrinsic for tensor core mma_sync operators
Parameters
......@@ -759,12 +758,10 @@ def tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b,
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b,
fragment_c, index_c)
return _tvm_op.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)
def tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c,
index_c):
def tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c):
"""TVM intrinsic for tensor core bmma_sync operators
Parameters
......@@ -798,8 +795,7 @@ def tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b,
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b,
fragment_c, index_c)
return _tvm_op.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)
def tvm_fill_fragment(fragment, m, n, k, index, value):
......@@ -1121,7 +1117,6 @@ def ptx_wgmma_rs(
scale_in_a,
scale_in_b,
):
return call_intrin(
dtype,
_tvm_op.Op.get("tl.ptx_wgmma_rs"),
......@@ -1345,8 +1340,7 @@ def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, sme
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr,
smem_offset)
return _tvm_op.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset)
def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
......@@ -1381,8 +1375,7 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, by
return _tvm_op.ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes)
def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes,
barrier_id):
def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id):
"""TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
......@@ -1414,8 +1407,7 @@ def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offse
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset,
bytes, barrier_id)
return _tvm_op.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id)
def ptx_commit_group():
......@@ -2951,8 +2943,7 @@ def q_multiply_shift_per_axis(
z : PrimExpr
The result.
"""
return _tvm_op.q_multiply_shift_per_axis(x, y, ls, rs, q, is_lshift_required,
is_rshift_required)
return _tvm_op.q_multiply_shift_per_axis(x, y, ls, rs, q, is_lshift_required, is_rshift_required)
def shift_left(x, y, span=None):
......@@ -3302,8 +3293,7 @@ def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dt
call : PrimExpr
The call expression.
"""
return _tvm_op.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint,
dtype_bits_hint)
return _tvm_op.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)
def TVMBackendFreeWorkspace(device_type, device_id, ptr):
......
......@@ -14,23 +14,18 @@ def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: list
"""Convert a BufferLoad to a tl.region call with explicit extents."""
indices = list(load.indices)
if len(indices) > len(extents):
extents = [tir.IntImm("int32", 1) for _ in range(len(indices) - len(extents))
] + list(extents)
extents = [tir.IntImm("int32", 1) for _ in range(len(indices) - len(extents))] + list(extents)
assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
return region(load, access_type, *extents)
def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str,
extents: list[tir.PrimExpr]):
def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str, extents: list[tir.PrimExpr]):
"""Clamp extents and return a tl.region call."""
mins = [r.min for r in buffer_region.region]
region_extents = [r.extent for r in buffer_region.region]
assert len(region_extents) >= len(extents), (
f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
)
assert len(region_extents) >= len(extents), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
clamped_extents = [
tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i]
for i in range(len(region_extents))
tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i] for i in range(len(region_extents))
]
return region(tir.BufferLoad(buffer_region.buffer, mins), access_type, *clamped_extents)
......
......@@ -5,6 +5,7 @@ from tvm import tir
from tvm.ir.expr import PrimExpr
from tvm.script.ir_builder.tir import buffer
from typing import Any, Callable, Literal, TypeVar, Generic, TYPE_CHECKING
# Python 3.9 compatibility for advanced typing features
try:
from typing import ParamSpec, TypeVarTuple, Unpack, Self # type: ignore[attr-defined]
......@@ -37,16 +38,16 @@ from tvm.script.ir_builder import IRBuilder
import torch
import inspect
_Shapes = TypeVarTuple('_Shapes')
_Shape = ParamSpec('_Shape')
_Stride = ParamSpec('_Stride')
_DType = TypeVar('_DType')
_Shapes = TypeVarTuple("_Shapes")
_Shape = ParamSpec("_Shape")
_Stride = ParamSpec("_Stride")
_DType = TypeVar("_DType")
Scope = Literal['global', 'shared.dyn', 'local', 'local.fragment']
Scope = Literal["global", "shared.dyn", "local", "local.fragment"]
class Annot(ABC):
'''
"""
Base class for tilelang kernel annotations
Tilelang kernel annotations are used to specify how to interpret each argument of the jit kernel
......@@ -54,12 +55,12 @@ class Annot(ABC):
1. determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time)
2. parse the argument value into a hash key for jit caching
3. convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation
'''
"""
def is_kernel_arg(self) -> bool:
'''
"""
Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time)
'''
"""
return False
@abstractmethod
......@@ -68,29 +69,29 @@ class Annot(ABC):
@abstractmethod
def get_key_parser(self) -> Callable[[str, Any], tuple[Any, ...]]:
'''
"""
Return a parser function that converts the argument value into a hash key for jit caching
'''
"""
@abstractmethod
def create_prim_func_arg(self, name: str, value: Any, vt: ArgVarTable) -> tir.Var | tir.Buffer:
'''
"""
Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation
'''
"""
def promote(self) -> TIRAnnot | None:
'''
"""
Try to promote the annotation into a FixedAnnot if possible
Return None if not promotable
'''
"""
return None
@dataclass
class ArgVarTable:
'''
"""
ArgVarTable is used to manage the mapping from argument names to tir.Var objects
'''
"""
var_tab: dict[str, tir.Var] = field(default_factory=dict)
tmp_name_idx: int = 0
......@@ -103,50 +104,49 @@ class ArgVarTable:
return self.var_tab[name]
def create_tmp_name(self) -> str:
name = f'varg_{self.tmp_name_idx}'
name = f"varg_{self.tmp_name_idx}"
self.tmp_name_idx += 1
return name
@dataclass
class Value(Annot):
kind: Literal['static', 'dynamic'] = 'dynamic'
kind: Literal["static", "dynamic"] = "dynamic"
name: str | None = None
dtype: dt.dtype | None = dt.int32
value: int | tir.Var | None = None
creator: Callable[[], Any] | None = None
def is_kernel_arg(self) -> bool:
return self.kind == 'dynamic'
return self.kind == "dynamic"
@classmethod
def from_value(cls, value: Any, prefer_name: str = None) -> Value:
if isinstance(value, int):
# handle A: T.Tensor[[1024, 1024], ...]
return Value(kind='static', name=prefer_name, dtype=dt.int32, value=value)
return Value(kind="static", name=prefer_name, dtype=dt.int32, value=value)
elif isinstance(value, float):
return Value(kind='static', name=prefer_name, dtype=dt.float32, value=value)
return Value(kind="static", name=prefer_name, dtype=dt.float32, value=value)
elif isinstance(value, dt.dtype):
# handle A: T.float32
return Value(kind='dynamic', name=prefer_name, dtype=value, value=None)
return Value(kind="dynamic", name=prefer_name, dtype=value, value=None)
elif isinstance(value, Value):
# handle A: T.dyn
return value
elif isinstance(value, TypeVar):
return Value(kind='static', name=value.__name__, value=None)
return Value(kind="static", name=value.__name__, value=None)
elif isinstance(value, (tir.Var, PrimExpr)):
# handle A: T.Tensor[[M, N, K], ...]
# or primexpr annotation like A: T.Tensor[[M, N * 4 +1]]
name = value.name if isinstance(value, tir.Var) else prefer_name
return Value(kind='dynamic', name=name, dtype=value.dtype, value=value)
elif value is Any or value is None or value is dt.dtype or isinstance(
value, (type,) + _GenericAliasTypes):
return Value(kind="dynamic", name=name, dtype=value.dtype, value=value)
elif value is Any or value is None or value is dt.dtype or isinstance(value, (type,) + _GenericAliasTypes):
# A # no annotation
# A: Any
# A: _T
# A: dt.dtype
# A: tuple[...]
return Value(kind='static', name=prefer_name, value=None)
return Value(kind="static", name=prefer_name, value=None)
else:
raise TypeError(f"Unsupported Value annotation: {value!r}, type: {type(value)}")
......@@ -154,7 +154,7 @@ class Value(Annot):
return Value(kind=self.kind, name=self.name or name, dtype=self.dtype, value=self.value)
def get_key_parser(self):
if self.kind == 'static':
if self.kind == "static":
if self.value is not None:
expected_value = self.value
......@@ -172,7 +172,7 @@ class Value(Annot):
return self.get_key_parser()(target)
def create_prim_func_arg(self, name: str, value: Any, vt: ArgVarTable, create_arg: bool = True):
if self.kind == 'static':
if self.kind == "static":
if self.value:
assert self.value == value, f"static value mismatch for {name}: expected {self.value}, got {value}"
return value
......@@ -187,18 +187,18 @@ class Value(Annot):
return tb_tir.arg(name, arg) if create_arg else arg
def __repr__(self):
if self.kind == 'static':
if self.kind == "static":
if self.value is not None:
return repr(self.value)
else:
return (str(self.name) or '$unnamed') + '$'
return (str(self.name) or "$unnamed") + "$"
else:
if self.value is not None:
return repr(self.value)
elif self.creator is not None:
return repr(self.creator())
else:
return (str(self.name) or '$unnamed') + '$dyn'
return (str(self.name) or "$unnamed") + "$dyn"
def _canonicalize_dtype(val: Any) -> dt.dtype | None:
......@@ -226,7 +226,7 @@ def _shape_with_name(shape: Sequence[Value], base_name: str) -> list[Value]:
return None
res = []
for i, dim in enumerate(shape):
dim = dim.with_name(f'{base_name}_{i}')
dim = dim.with_name(f"{base_name}_{i}")
res.append(dim)
return res
......@@ -236,7 +236,7 @@ def _try_convert_static_shape(shape: Sequence[Value]):
return None
res = []
for s in shape:
if s.kind == 'static' and s.value is not None or s.kind == 'dynamic' and s.value is not None:
if s.kind == "static" and s.value is not None or s.kind == "dynamic" and s.value is not None:
res.append(s.value)
if len(res) == len(shape):
return res
......@@ -253,7 +253,7 @@ class BufferAnnot(Annot):
@property
def scope(self):
return 'global'
return "global"
def __call__(
self,
......@@ -290,8 +290,8 @@ class BufferAnnot(Annot):
return self.__class__(shape, strides=self.strides, dtype=dtype)
def with_name(self, name: str):
shape = _shape_with_name(self.shape, base_name=f'{name}_shape')
strides = _shape_with_name(self.strides, base_name=f'{name}_stride')
shape = _shape_with_name(self.shape, base_name=f"{name}_shape")
strides = _shape_with_name(self.strides, base_name=f"{name}_stride")
return self.__class__(shape, strides, self.dtype)
def get_key_parser(self):
......@@ -299,14 +299,14 @@ class BufferAnnot(Annot):
if self.shape is not None:
raw_shapes = False
shape_len = len(self.shape)
static_shape_idx = [i for i, dim in enumerate(self.shape) if dim.kind == 'static']
static_shape_idx = [i for i, dim in enumerate(self.shape) if dim.kind == "static"]
# static_fixed_shape_idx = [i for i, dim in enumerate(self.shape) if dim.kind == 'static' and dim.value is not None]
# static_fixed_shape_values = [dim.value for dim in self.shape if dim.kind == 'static' and dim.value is not None]
raw_strides = True
if self.strides is not None:
raw_strides = False
strides_len = len(self.strides)
strides_shape_idx = [i for i, dim in enumerate(self.strides) if dim.kind == 'static']
strides_shape_idx = [i for i, dim in enumerate(self.strides) if dim.kind == "static"]
# static_fixed_strides_idx = [i for i, dim in enumerate(self.strides) if dim.kind == 'static' and dim.value is not None]
# static_fixed_strides_values = [dim.value for dim in self.strides if dim.kind == 'static' and dim.value is not None]
raw_dtype = True
......@@ -340,9 +340,7 @@ class BufferAnnot(Annot):
if not raw_dtype:
dtype = dt.dtype(dtype)
if dtype != expected_dtype:
raise TypeError(
f"Tensor dtype mismatch for argument `{name}`, expected {expected_dtype}, got {dtype}"
)
raise TypeError(f"Tensor dtype mismatch for argument `{name}`, expected {expected_dtype}, got {dtype}")
return shape, strides, dtype
return key_parser
......@@ -384,7 +382,6 @@ class BufferAnnot(Annot):
class TensorAnnot(BufferAnnot):
@staticmethod
def _construct_strides(shape: tuple[Any]):
s, strides = 1, [1]
......@@ -419,7 +416,8 @@ class TensorAnnot(BufferAnnot):
align=align,
offset_factor=offset_factor,
buffer_type=buffer_type,
axis_separators=axis_separators)
axis_separators=axis_separators,
)
def promote(self):
shape = _try_convert_static_shape(self.shape)
......@@ -430,7 +428,6 @@ class TensorAnnot(BufferAnnot):
class StridedTensorAnnot(BufferAnnot):
def __call__(
self,
shape,
......@@ -466,30 +463,27 @@ class StridedTensorAnnot(BufferAnnot):
class FragmentBufferAnnot(BufferAnnot):
@property
def scope(self):
return 'local.fragment'
return "local.fragment"
class SharedBufferAnnot(BufferAnnot):
@property
def scope(self):
return 'shared.dyn'
return "shared.dyn"
class LocalBufferAnnot(BufferAnnot):
@property
def scope(self):
return 'local'
return "local"
class DynAnnot(Value):
'''
"""
Dynamic variable annotation represents a tvm tir.Var argument
'''
"""
def __call__(self, dtype: AnyDType = dt.float32, name: str | None = None) -> DynAnnot:
return tir.Var(name, dtype)
......@@ -499,16 +493,16 @@ class DynAnnot(Value):
params = (params,)
dtype = None
if len(params) == 1:
name, = params
(name,) = params
if len(params) == 2:
dtype, name = params
dtype = _canonicalize_dtype(dtype) or dt.int32
return DynAnnot(kind='dynamic', dtype=dtype, name=name)
return DynAnnot(kind="dynamic", dtype=dtype, name=name)
@dataclass
class DTypeAnnot(Annot):
'''
"""
Data type annotation ensures automatically conversion from AnyDType to dtype
>>> def foo(A: T.dtype): print(A)
>>> foo(torch.float32)
......@@ -517,7 +511,8 @@ class DTypeAnnot(Annot):
dtype('float32')
>>> foo('float32')
dtype('float32')
'''
"""
name: str | None = None
def is_kernel_arg(self) -> bool:
......@@ -533,15 +528,16 @@ class DTypeAnnot(Annot):
return dt.dtype(value)
def __repr__(self):
return self.name + '$dtype'
return self.name + "$dtype"
@dataclass
class TIRAnnot(Annot):
'''
"""
TIR annotation is used to directly pass tir.Buffer or tir.Var as kernel arguments
>>> def foo(A: T.Buffer((128,), T.float32)): ...
'''
"""
data: tir.Buffer | tir.Var
def is_kernel_arg(self) -> bool:
......@@ -564,7 +560,6 @@ class TIRAnnot(Annot):
if TYPE_CHECKING:
class Buffer(Generic[_Shape, _DType]):
def __init__(
shape: tuple[Unpack[_Shapes]],
dtype: _DType = "float32",
......@@ -576,26 +571,20 @@ if TYPE_CHECKING:
offset_factor=0,
buffer_type="",
axis_separators=None,
) -> Buffer[Callable[[Unpack[_Shapes]]], _DType]:
...
) -> Buffer[Callable[[Unpack[_Shapes]]], _DType]: ...
@property
def shape(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> tuple[Unpack[_Shapes]]:
...
def shape(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> tuple[Unpack[_Shapes]]: ...
@property
def dtype(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> dt.dtype[_DType]:
...
def dtype(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> dt.dtype[_DType]: ...
@property
def strides(self) -> tuple[tir.PrimExpr]:
...
def strides(self) -> tuple[tir.PrimExpr]: ...
def scope(self) -> Scope:
...
def scope(self) -> Scope: ...
class Tensor(Generic[_Shape, _DType], Buffer[_Shape, _DType]):
def __new__(
shape: tuple[Unpack[_Shapes]],
dtype: _DType = "float32",
......@@ -607,11 +596,9 @@ if TYPE_CHECKING:
offset_factor=0,
buffer_type="",
axis_separators=None,
) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
...
) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ...
class StridedTensor(Generic[_Shape, _Stride, _DType], Buffer[_Shape, _DType]):
def __new__(
shape: tuple[Unpack[_Shapes]],
strides=None,
......@@ -623,8 +610,7 @@ if TYPE_CHECKING:
offset_factor=0,
buffer_type="",
axis_separators=None,
) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
...
) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ...
class FragmentBuffer(Generic[_Shape, _DType], Buffer[_Shape, _DType]):
pass
......@@ -636,16 +622,12 @@ if TYPE_CHECKING:
pass
class dyn(tir.Var):
def __new__(cls, dtype: _DType = "float32", name: str | None = None) -> dyn[_DType]:
...
def __new__(cls, dtype: _DType = "float32", name: str | None = None) -> dyn[_DType]: ...
@property
def dtype(self: dyn[_DType]) -> dt.dtype[_DType]:
...
def dtype(self: dyn[_DType]) -> dt.dtype[_DType]: ...
else:
Buffer = BufferAnnot()
Tensor = TensorAnnot()
StridedTensor = StridedTensorAnnot()
......@@ -670,7 +652,7 @@ class FuncAnnot:
ker_arg_names = []
for param in sig.parameters.values():
name = param.name
annot = func_annots.get(name, Value('static', name))
annot = func_annots.get(name, Value("static", name))
if not isinstance(annot, Annot):
if not isinstance(annot, type) and callable(annot):
annot = annot()
......@@ -679,7 +661,7 @@ class FuncAnnot:
elif isinstance(annot, (tir.Buffer, tir.Var)):
annot = TIRAnnot(data=annot)
else:
annot = Value(kind='static', name=name)
annot = Value(kind="static", name=name)
annot = annot.promote() or annot
annots[name] = annot.with_name(name)
if annot.is_kernel_arg():
......@@ -689,9 +671,9 @@ class FuncAnnot:
return FuncAnnot(sig, arg_names, annots, arg_parser, ker_arg_names)
def parse_key(self, *args, **kws):
'''
"""
Parse arguments and generates the cache key for jit caching
'''
"""
args = {name: arg for name, arg in zip(self.arg_names, args)}
arg_dict = dict(**args, **kws)
parsed = []
......@@ -706,15 +688,15 @@ class FuncAnnot:
return [arg_dict[name] for name in self.ker_arg_names]
def create_argument(self, name: str, value: Any, vt: ArgVarTable):
'''
"""
Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation
'''
"""
return self.annots[name].create_prim_func_arg(name, value, vt)
def is_all_static(self):
'''
"""
Check if all arguments are static (i.e., can be fully determined at compile time)
'''
"""
return all(isinstance(annot, TIRAnnot) for annot in self.annots.values())
def get_all_static_args(self):
......
......@@ -4,16 +4,18 @@ from dataclasses import dataclass
from typing import Callable, Generic, Any, Literal, TypeVar
from contextlib import AbstractContextManager
from collections.abc import Iterable
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
import inspect
# from .utils import get_ast, get_compiled_object
from . import utils
_span_attrs = ['lineno', 'col_offset', 'end_lineno', 'end_col_offset']
_span_attrs = ["lineno", "col_offset", "end_lineno", "end_col_offset"]
def ast_has_span(ast: ast.AST) -> bool:
......@@ -34,7 +36,6 @@ def ast_set_span(ast: ast.AST, span: tuple[int, int, int, int]):
class QuoteVisitor(ast.NodeTransformer):
def __init__(self, names: dict[str, ast.AST], passes: list[Any] | None = None, span=None):
self.names = names
self.passes = passes or []
......@@ -76,9 +77,8 @@ def quote_expr(expr: str, **kws) -> ast.expr:
return res.value
# Class names of ast.operator (binary operators) and ast.boolop subclasses,
# as returned by get_operator_name / get_boolop_name.
Operator = Literal["Add", "Sub", "Mult", "MatMult", "Div", "Mod", "Pow", "LShift", "RShift", "BitOr", "BitXor", "BitAnd", "FloorDiv"]
BoolOp = Literal["And", "Or", "Not"]
def get_operator_name(operator: ast.operator) -> Operator:
......@@ -89,84 +89,83 @@ def get_boolop_name(boolop: ast.boolop) -> BoolOp:
return boolop.__class__.__name__
_T = TypeVar('_T')
_T = TypeVar("_T")
# Dispatch table for Python's binary operators, keyed by ast.operator class name.
_BIN_OPS = {
    "Add": lambda a, b: a + b,
    "Sub": lambda a, b: a - b,
    "Mult": lambda a, b: a * b,
    "MatMult": lambda a, b: a @ b,
    "Div": lambda a, b: a / b,
    "Mod": lambda a, b: a % b,
    "Pow": lambda a, b: a**b,
    "LShift": lambda a, b: a << b,
    "RShift": lambda a, b: a >> b,
    "BitOr": lambda a, b: a | b,
    "BitXor": lambda a, b: a ^ b,
    "BitAnd": lambda a, b: a & b,
    "FloorDiv": lambda a, b: a // b,
}


def eval_op(op: Operator, left: Any, right: Any) -> Any:
    """Evaluate the binary operator named ``op`` on ``left`` and ``right``.

    ``op`` is an ``ast.operator`` class name (see ``get_operator_name``).

    Raises:
        ValueError: if ``op`` is not a known binary operator name.
    """
    handler = _BIN_OPS.get(op)
    if handler is None:
        raise ValueError(f"Unknown operator: {op}")
    return handler(left, right)
def eval_aug_assign(op: Operator, left: Any, sl: slice, right: Any) -> Any:
    """Apply the augmented assignment ``left[sl] <op>= right`` and return ``left``.

    Uses the real augmented-assignment statement so in-place semantics
    (``__iadd__`` etc. on the sliced value) are preserved.

    Raises:
        ValueError: if ``op`` is not a known binary operator name.
    """
    if op == "Add":
        left[sl] += right
        return left
    if op == "Sub":
        left[sl] -= right
        return left
    if op == "Mult":
        left[sl] *= right
        return left
    if op == "MatMult":
        left[sl] @= right
        return left
    if op == "Div":
        left[sl] /= right
        return left
    if op == "Mod":
        left[sl] %= right
        return left
    if op == "Pow":
        left[sl] **= right
        return left
    if op == "LShift":
        left[sl] <<= right
        return left
    if op == "RShift":
        left[sl] >>= right
        return left
    if op == "BitOr":
        left[sl] |= right
        return left
    if op == "BitXor":
        left[sl] ^= right
        return left
    if op == "BitAnd":
        left[sl] &= right
        return left
    if op == "FloorDiv":
        left[sl] //= right
        return left
    raise ValueError(f"Unknown operator: {op}")
class _empty:
...
class _empty: ...
class BaseBuilder:
......@@ -218,13 +217,13 @@ class BaseBuilder:
eval_aug_assign(op, target, sl, aug_value)
def boolop(self, op: BoolOp, left: Any, right: Callable[[], Any] | None = None) -> Any:
    """Evaluate a boolean op with Python short-circuit semantics.

    ``right`` is a thunk so the right operand is only evaluated when needed
    ("And" with truthy left, "Or" with falsy left); "Not" ignores it.

    Raises:
        ValueError: if ``op`` is not "And", "Or", or "Not".
    """
    if op == "And":
        return left and right()
    if op == "Or":
        return left or right()
    if op == "Not":
        return not left
    raise ValueError(f"Unknown boolop: {op}")
def ifexp(self, cond: Any, then: Callable[[], Any], otherwise: Callable[[], Any]) -> Any:
    """Evaluate a conditional expression: call and return `then()` when
    `cond` is truthy, otherwise call and return `otherwise()`."""
    if cond:
        return then()
    return otherwise()
......@@ -249,7 +248,6 @@ class BaseBuilder:
class DSLMutator(ast.NodeTransformer):
def __init__(self, closure_names: list[str]):
self.tmp_counter = 0
self.closure_names = closure_names
......@@ -264,19 +262,13 @@ class DSLMutator(ast.NodeTransformer):
br = self.get_tmp()
if len(node.orelse) == 0:
return quote(
f"for {br} in __tb.ctx_if(cond):\n"
f" for _ in __tb.ctx_then({br}):\n"
" pass\n",
f"for {br} in __tb.ctx_if(cond):\n for _ in __tb.ctx_then({br}):\n pass\n",
cond=node.test,
passes=[node.body],
span=node,
)
return quote(
f"for {br} in __tb.ctx_if(cond):\n"
f" for _ in __tb.ctx_then({br}):\n"
f" pass\n"
f" for _ in __tb.ctx_else({br}):\n"
f" pass\n",
f"for {br} in __tb.ctx_if(cond):\n for _ in __tb.ctx_then({br}):\n pass\n for _ in __tb.ctx_else({br}):\n pass\n",
cond=node.test,
passes=[node.body, node.orelse],
span=node,
......@@ -290,7 +282,7 @@ class DSLMutator(ast.NodeTransformer):
if isinstance(target, ast.Name):
return f"'{target.id}'"
elif isinstance(target, ast.Tuple):
return ("(" + ",".join([self._parse_names(elt) for elt in target.elts]) + ",)")
return "(" + ",".join([self._parse_names(elt) for elt in target.elts]) + ",)"
else:
s = ast.unparse(target)
raise NotImplementedError(f"Unsupported for target `{s}`")
......@@ -303,8 +295,7 @@ class DSLMutator(ast.NodeTransformer):
ast_set_span(var, ast_get_span(node.target))
stmts = self._emit_assign_target(node.target, var)
return quote(
f"for {tmp} in __tb.ctx_for(range):\n"
" pass\n",
f"for {tmp} in __tb.ctx_for(range):\n pass\n",
target=node.target,
range=node.iter,
passes=[stmts + node.body],
......@@ -319,24 +310,15 @@ class DSLMutator(ast.NodeTransformer):
node = self.generic_visit(node)
return quote("if __tb.ctx_break(): break", span=node)
def _emit_assign_target(self,
target: ast.expr,
rval: ast.expr,
annot: ast.expr = None) -> list[ast.AST]:
def _emit_assign_target(self, target: ast.expr, rval: ast.expr, annot: ast.expr = None) -> list[ast.AST]:
if isinstance(target, ast.Name):
if annot is None:
return quote(
f"name = __tb.bind('{target.id}', value)", name=target, value=rval, span=target)
return quote(f"name = __tb.bind('{target.id}', value)", name=target, value=rval, span=target)
else:
return quote(
f'name = __tb.bind("{target.id}", value, annot)',
name=target,
value=rval,
annot=annot,
span=target)
return quote(f'name = __tb.bind("{target.id}", value, annot)', name=target, value=rval, annot=annot, span=target)
elif isinstance(target, ast.Attribute):
s = ast.unparse(target)
raise NotImplementedError(f'Attribute assignment not supported yet, `{s}`')
raise NotImplementedError(f"Attribute assignment not supported yet, `{s}`")
elif isinstance(target, ast.Subscript):
if annot is None:
return quote(
......@@ -356,7 +338,6 @@ class DSLMutator(ast.NodeTransformer):
span=target,
)
else:
# flatten nested tuple into a list of (tmp_name, target)
unpacked = []
......@@ -374,11 +355,9 @@ class DSLMutator(ast.NodeTransformer):
return res
else:
s = ast.unparse(target)
raise NotImplementedError(f'Attribute assignment not supported yet, `{s}`')
raise NotImplementedError(f"Attribute assignment not supported yet, `{s}`")
unpack_stmt = ast.Assign(
targets=[_visit_target(target)],
value=quote_expr('__tb.unwrap_value(rval)', rval=rval, span=rval))
unpack_stmt = ast.Assign(targets=[_visit_target(target)], value=quote_expr("__tb.unwrap_value(rval)", rval=rval, span=rval))
ast_set_span(unpack_stmt, ast_get_span(target))
stmts = [unpack_stmt]
bind_lvals = []
......@@ -386,8 +365,7 @@ class DSLMutator(ast.NodeTransformer):
def flush_binds():
if bind_lvals:
stmts.append(
quote1(f'{", ".join(bind_lvals)}, = {", ".join(bind_rvals)},', span=target))
stmts.append(quote1(f"{', '.join(bind_lvals)}, = {', '.join(bind_rvals)},", span=target))
bind_lvals.clear()
bind_rvals.clear()
......@@ -417,15 +395,10 @@ class DSLMutator(ast.NodeTransformer):
bind_rvals.append(f'__tb.bind("{target.id}", {tmp})')
elif isinstance(target, ast.Subscript):
flush_binds()
stmts.append(
quote1(
f'__tb.assign_slice(lval, slice, {tmp})',
lval=target.value,
slice=target.slice,
span=target))
stmts.append(quote1(f"__tb.assign_slice(lval, slice, {tmp})", lval=target.value, slice=target.slice, span=target))
else:
s = ast.unparse(target)
raise NotImplementedError(f'Unsupported target: {s}')
raise NotImplementedError(f"Unsupported target: {s}")
flush_binds()
return stmts
......@@ -450,11 +423,7 @@ class DSLMutator(ast.NodeTransformer):
target, rval = node.target, node.value
op = get_operator_name(node.op)
if isinstance(target, ast.Name):
return quote(
f"name = __tb.aug_assign('{op}', {target.id}, value)",
name=target,
value=rval,
span=node)
return quote(f"name = __tb.aug_assign('{op}', {target.id}, value)", name=target, value=rval, span=node)
elif isinstance(target, ast.Subscript):
return quote(
f"__tb.aug_assign_slice('{op}', lval, slice, value)",
......@@ -468,16 +437,12 @@ class DSLMutator(ast.NodeTransformer):
def visit_AnnAssign(self, node: ast.AnnAssign):
node = self.generic_visit(node)
rval = node.value or quote_expr('__tb.empty', span=node, annot=node)
rval = node.value or quote_expr("__tb.empty", span=node, annot=node)
return self._emit_assign_target(node.target, rval, annot=node.annotation)
def visit_While(self, node):
node = self.generic_visit(node)
return quote1(
"for _ in __tb.ctx_while(lambda: cond):\n pass",
cond=node.test,
passes=[node.body],
span=node)
return quote1("for _ in __tb.ctx_while(lambda: cond):\n pass", cond=node.test, passes=[node.body], span=node)
def visit_FunctionDef(self, node: ast.FunctionDef):
node = self.generic_visit(node)
......@@ -536,18 +501,14 @@ class DSLMutator(ast.NodeTransformer):
left = comp
last = split[-1]
for i in reversed(range(len(split) - 1)):
last = quote_expr(
"__tb.boolop('And', left, lambda: right)", left=split[i], right=last, span=node)
last = quote_expr("__tb.boolop('And', left, lambda: right)", left=split[i], right=last, span=node)
return last
def visit_IfExp(self, node: ast.IfExp) -> ast.Expr:
node = self.generic_visit(node)
return quote_expr(
'__tb.ifexp(cond, lambda: then, lambda: otherwise)',
cond=node.test,
then=node.body,
otherwise=node.orelse,
span=node)
"__tb.ifexp(cond, lambda: then, lambda: otherwise)", cond=node.test, then=node.body, otherwise=node.orelse, span=node
)
def visit_Return(self, node: ast.Return):
node = self.generic_visit(node)
......@@ -569,7 +530,7 @@ class DSLMutator(ast.NodeTransformer):
return node
_P = ParamSpec('_P')
_P = ParamSpec("_P")
@dataclass
......@@ -626,7 +587,7 @@ def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]:
make_closure = utils.get_compiled_object(
tree,
'make_closure',
"make_closure",
filename,
func.__globals__, # use the original globalns
)
......
......@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, U
from collections.abc import Sequence
from .annot import FuncAnnot, ArgVarTable, Annot
import pprint
# Python 3.9 compatibility for ParamSpec and Self
try:
from typing import ParamSpec, Self
......@@ -32,9 +33,9 @@ logger = logging.getLogger(__name__)
def unwrap_expr(expr) -> PrimExpr | int | float:
'''
"""
unwrap expr and convert it into PrimExpr like
'''
"""
if isinstance(expr, tir.meta_var):
expr = expr.value
elif isinstance(expr, Ref):
......@@ -47,9 +48,9 @@ def unwrap_expr(expr) -> PrimExpr | int | float:
def unwrap_cond(expr):
'''
"""
unwrap expr and convert to bool condition
'''
"""
expr = unwrap_expr(expr)
if isinstance(expr, (IntImm, FloatImm, StringImm)):
return bool(expr.value)
......@@ -61,10 +62,10 @@ def unwrap_cond(expr):
return bool(expr)
else:
logger.warning(
f"Python expression `{expr}` is used as condition in TileLang, \n"
"this is treated as a constant expression. ",
f"Python expression `{expr}` is used as condition in TileLang, \nthis is treated as a constant expression. ",
stack_info=True,
stacklevel=3)
stacklevel=3,
)
return bool(expr)
......@@ -72,44 +73,35 @@ thread_local_storage = threading.local()
class Frame:
    """
    Frames are virtual context managers used in the frontend only.
    They do not have any runtime representation in the generated TIR.
    """

    def __enter__(self): ...

    def __exit__(self, exc_type, exc_value, traceback): ...
# Marker subclasses of Frame: each tags a distinct frontend context
# (macro body, exited macro, boolean expression, constant-if, block,
# continue, break) without adding behavior of its own.
class MacroFrame(Frame): ...


class ExitedMacroFrame(Frame): ...


class BoolOpFrame(Frame): ...


class ConstIfFrame(Frame): ...


class BlockFrame(Frame): ...


class ContinueFrame(Frame): ...


class BreakFrame(Frame): ...
@dataclass
......@@ -145,8 +137,7 @@ class Ref:
return self.bufload
class UnrollForWithStep(SerialForWithStep):
...
class UnrollForWithStep(SerialForWithStep): ...
# Python 3.9 compatibility: avoid PEP 604 unions at runtime
......@@ -172,11 +163,10 @@ TIR_VAR_SCOPE_FRAME = (
def is_var(v: Any) -> bool:
    """Return True if ``v`` is a Buffer with storage scope ``"local.var"``
    (the scope used for frontend variable allocations)."""
    return isinstance(v, Buffer) and v.scope() == "local.var"
class Builder(BaseBuilder):
def __init__(self, func_annot: FuncAnnot = None):
self.frames: list[AnyFrame] = []
self.ir_builder = IRBuilder()
......@@ -189,7 +179,7 @@ class Builder(BaseBuilder):
@classmethod
def current(cls) -> Self:
    """Return the builder bound to the current thread, or None if none is active."""
    return getattr(thread_local_storage, "builder", None)
@contextmanager
......@@ -199,14 +189,15 @@ class Builder(BaseBuilder):
tir.func_name(name)
yield
if len(self.out_idx) != self.out_tensor_cnt:
raise RuntimeError('Not all tensor allocated from `T.empty` are returned')
raise RuntimeError("Not all tensor allocated from `T.empty` are returned")
@contextmanager
def macro(self, name=None, annotations=None):
if self.find_frame_idx(BoolOpFrame) is not None:
raise RuntimeError(
f"Macro `{name}` is used inside boolean expressions, "
"please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs")
"please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs"
)
save = self.name_inside_frame, self.macro_arg_annot
self.name_inside_frame = {}
self.macro_arg_annot = annotations or {}
......@@ -244,10 +235,7 @@ class Builder(BaseBuilder):
def check_continue_break(self):
idx = self.find_frame_idx(ContinueOrBreak)
if idx is not None:
logger.warning(
'Writing code after continue/break may cause undefined behavior in tilelang.',
stack_info=True,
stacklevel=3)
logger.warning("Writing code after continue/break may cause undefined behavior in tilelang.", stack_info=True, stacklevel=3)
@contextmanager
def with_frame(self, frame: AbstractContextManager[Any] | None):
......@@ -256,8 +244,7 @@ class Builder(BaseBuilder):
while len(self.frames) > pop_idx:
self.frames.pop().__exit__(None, None, None)
class _has_if_frame:
...
class _has_if_frame: ...
def ctx_if(self, cond):
self.check_continue_break()
......@@ -294,7 +281,7 @@ class Builder(BaseBuilder):
elif isinstance(val, tir.frame.IRBuilderFrame):
if isinstance(val, tir.frame.ForFrame):
logger.warning(
'Evaluating a for frame may cause undefined behavior in tilelang.',
"Evaluating a for frame may cause undefined behavior in tilelang.",
stack_info=True,
stacklevel=1,
)
......@@ -310,8 +297,7 @@ class Builder(BaseBuilder):
elif isinstance(val, (Buffer, Var)):
pass
else:
logger.warning(
f"Unused return value: {val}({type(val)})", stack_info=True, stacklevel=2)
logger.warning(f"Unused return value: {val}({type(val)})", stack_info=True, stacklevel=2)
def ctx_for(self, it):
self.check_continue_break()
......@@ -321,15 +307,13 @@ class Builder(BaseBuilder):
if isinstance(it.step, (int, IntImm)):
step_value = it.step if isinstance(it.step, int) else it.step.value
if step_value == 0:
raise ValueError('Invalid stepped serial: step must be non-zero')
raise ValueError("Invalid stepped serial: step must be non-zero")
if step_value > 0:
real_stop = tir.ceildiv(it.stop - it.start, step_value)
else:
real_stop = tir.ceildiv(it.start - it.stop, -step_value)
else:
logger.warning(
f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang'
)
logger.warning(f"Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang")
real_stop = tir.ceildiv(it.stop - it.start, it.step)
if isinstance(it, UnrollForWithStep):
real_frame = tir.unroll(real_stop, annotations=it.annotations)
......@@ -338,15 +322,17 @@ class Builder(BaseBuilder):
else:
raise TypeError(
f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
"range, T.serial, T.unroll, T.grid, T.parallel, T.vectorized, T.thread_binding")
"range, T.serial, T.unroll, T.grid, T.parallel, T.vectorized, T.thread_binding"
)
with self.with_frame(real_frame) as v:
IRBuilder.name('_tmp', v)
IRBuilder.name("_tmp", v)
yield it.start + v * it.step
else:
if not isinstance(it, tir.frame.ForFrame):
raise TypeError(
f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding")
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding"
)
with self.with_frame(it) as v:
yield v
......@@ -369,15 +355,16 @@ class Builder(BaseBuilder):
if not isinstance(cond_v_unwrap, PrimExpr):
if cond_v_unwrap:
raise RuntimeError(
f'Infinite while loop detected in TileLang\n'
f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n'
f"Infinite while loop detected in TileLang\n"
f"Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n"
)
else:
logger.warning(
'While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n',
f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n',
"While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n",
f"Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n",
stack_info=True,
stacklevel=2)
stacklevel=2,
)
with self.with_frame(tir.While(cond_v_unwrap)):
yield None
......@@ -406,14 +393,14 @@ class Builder(BaseBuilder):
# 2. Quick return for trivil types
if isinstance(value, (tuple, list, tvm.ffi.Array, int, float, str)):
return value
if isinstance(value, tir.IntImm) and value.dtype == 'int32':
if isinstance(value, tir.IntImm) and value.dtype == "int32":
return value.value
if isinstance(value, (Var, Buffer)):
# Bind TVM Var/Buffer names and also record scope so reusing the same
# Python name (e.g., loop vars like `i`) across different for-frames
# works without triggering out-of-scope errors.
IRBuilder.name(name, value)
if name != '_':
if name != "_":
frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME)
assert frame is not None, f"Variable `{name}` is not defined inside any control flow."
self.name_inside_frame[name] = self.frames[frame]
......@@ -423,12 +410,12 @@ class Builder(BaseBuilder):
res = self.bind_immutable(name, value)
# 4. Check variable scope and shadowing
if name != '_':
if name != "_":
frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME)
assert frame is not None, f"Variable `{name}` is not defined inside any control flow."
if name in self.name_inside_frame and self.name_inside_frame[name] in self.frames:
logger.warning(
f'Variable `{name}` is declared twice, are you looking for a T.alloc_var?',
f"Variable `{name}` is declared twice, are you looking for a T.alloc_var?",
stack_info=True,
stacklevel=2,
)
......@@ -436,9 +423,9 @@ class Builder(BaseBuilder):
return res
def unwrap_value(self, value):
'''
"""
Unwrap some tilelang objects to get their inner value
'''
"""
value = unwrap_expr(value)
# handle bx, by = tl.Kernel(128, 128), rval is frame
if isinstance(value, tir.frame.IRBuilderFrame):
......@@ -447,11 +434,11 @@ class Builder(BaseBuilder):
return value
def bind_immutable(self, name, value):
'''
"""
Bind an immutable tilelang objects.
The immutability means the result is usually not changed or re-assigned in a python block.
'''
if name == '_':
"""
if name == "_":
# use _tmp to make the generated tir more readable
name = "_tmp"
if isinstance(value, tir.meta_var):
......@@ -459,18 +446,20 @@ class Builder(BaseBuilder):
elif isinstance(value, tir.frame.IRBuilderFrame):
if isinstance(value, tir.frame.ForFrame):
logger.warning(
'Binding a for frame to variable may cause undefined behavior in tilelang.',
"Binding a for frame to variable may cause undefined behavior in tilelang.",
stack_info=True,
stacklevel=2,
)
return self.enter_frame(value)
elif isinstance(value, OutTensor):
arg = tir.arg(name,
arg = tir.arg(
name,
tir.buffer(
shape=value.shape,
dtype=value.dtype,
strides=value.strides,
))
),
)
arg._out_idx = self.out_tensor_cnt
self.out_tensor_cnt += 1
return arg
......@@ -490,8 +479,7 @@ class Builder(BaseBuilder):
def assign_slice(self, lval: Any, sl: slice, value: Any, annot=BaseBuilder.empty):
self.check_continue_break()
if annot is not self.empty:
logger.warning(
"Type annotation in slice assignment has no effect", stack_info=True, stacklevel=2)
logger.warning("Type annotation in slice assignment has no effect", stack_info=True, stacklevel=2)
if isinstance(lval, Buffer):
tir.buffer_store(lval, value, sl)
else:
......@@ -521,11 +509,11 @@ class Builder(BaseBuilder):
left = unwrap_cond(left)
if isinstance(left, PrimExpr):
with self.with_frame(BoolOpFrame()):
if op == 'And':
if op == "And":
return tir.And(left, right())
if op == 'Or':
if op == "Or":
return tir.Or(left, right())
if op == 'Not':
if op == "Not":
return tir.Not(left)
raise RuntimeError(f"Unsupported boolean operator: {op}")
else:
......@@ -557,7 +545,7 @@ class Builder(BaseBuilder):
"You should allocate a var before the control flow, assign value inside the blocks, \n"
"and return the var after the control flow. i.e.\n"
"```\n"
"@T.macro\n" \
"@T.macro\n"
"def my_macro(cond):\n"
" a = T.alloc_var(T.float16)\n"
" if cond:\n"
......@@ -570,14 +558,12 @@ class Builder(BaseBuilder):
if not isinstance(value, tuple):
value = (value,)
for v in value:
if not isinstance(v, Buffer) or not hasattr(v, '_out_idx'):
raise RuntimeError(
f'Only tensor allocated from `T.empty` can be returned in a prim_func, got {v}({type(v)})'
)
if not isinstance(v, Buffer) or not hasattr(v, "_out_idx"):
raise RuntimeError(f"Only tensor allocated from `T.empty` can be returned in a prim_func, got {v}({type(v)})")
# convert 0, 1, 2 => -3, -2, -1 as the out tensor index
self.out_idx.append(v._out_idx - self.out_tensor_cnt)
if len(self.out_idx) != self.out_tensor_cnt:
raise RuntimeError(f'Not all tensor from `T.empty` are returned, only got {value}')
raise RuntimeError(f"Not all tensor from `T.empty` are returned, only got {value}")
return NotImplemented
def ctx_with(self, ctx):
......@@ -591,7 +577,7 @@ class Builder(BaseBuilder):
self.check_continue_break()
cond = unwrap_cond(cond)
if msg is None:
msg = 'Assertion failed'
msg = "Assertion failed"
if isinstance(cond, PrimExpr):
self.enter_frame(tir.Assert(cond, msg))
elif not cond:
......@@ -611,23 +597,18 @@ class Builder(BaseBuilder):
annot_value = self.macro_arg_annot.get(name, None)
if annot_value is Var or annot_value is Ref:
if annot_value is Var:
logger.warning('Use `T.Var` as macro annotations is deprecated, please use `T.Ref`')
logger.warning("Use `T.Var` as macro annotations is deprecated, please use `T.Ref`")
if isinstance(value, BufferLoad):
if is_var(value.buffer):
return value.buffer
idx = [self.bind('_', idx) for idx in value.indices]
idx = [self.bind("_", idx) for idx in value.indices]
# indices = self.bind(f'_', value.indices)
return Ref(BufferLoad(value.buffer, indices=idx))
if isinstance(value, BufferRegion):
region = [
Range(
self.bind('_', x.begin),
end=self.bind('_', x.end) if x.end is not None else None)
for x in value.region
]
region = [Range(self.bind("_", x.begin), end=self.bind("_", x.end) if x.end is not None else None) for x in value.region]
return BufferRegion(value.buffer, region=region)
raise ValueError(
f'To pass as reference, argument `{name}` is expected to be a variable or a buffer region, but got {value}({type(value)})'
f"To pass as reference, argument `{name}` is expected to be a variable or a buffer region, but got {value}({type(value)})"
)
elif isinstance(value, (PrimExpr, int, float)):
return self.bind(name, value)
......@@ -652,13 +633,14 @@ class Builder(BaseBuilder):
def override(self, name: str):
from tilelang.language import serial
if name == 'range':
if name == "range":
return serial
raise ValueError(f'Unknown override: {name}')
raise ValueError(f"Unknown override: {name}")
_P = ParamSpec('_P')
_T = TypeVar('_T')
_P = ParamSpec("_P")
_T = TypeVar("_T")
@dataclass
......@@ -683,14 +665,8 @@ class PrimFuncCreater(Generic[_P, _T]):
return res
def __repr__(self):
fmt = pprint.pformat(
{
'annot': self.func_annot.annots,
'ir_gen': self.ir_gen,
'orig_func': self.orig_func
},
indent=2)
return f'{self.__class__.__name__}(\n{fmt}\n)'
fmt = pprint.pformat({"annot": self.func_annot.annots, "ir_gen": self.ir_gen, "orig_func": self.orig_func}, indent=2)
return f"{self.__class__.__name__}(\n{fmt}\n)"
if TYPE_CHECKING:
......@@ -769,8 +745,7 @@ def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]:
def impl(func: Callable[_P, _T]) -> Macro[_P, _T]:
annotations = get_type_hints(func)
return Macro(
name=func.__name__, orig_func=func, ir_gen=mutate(func), annotations=annotations)
return Macro(name=func.__name__, orig_func=func, ir_gen=mutate(func), annotations=annotations)
return impl(func) if func is not None else impl
......@@ -779,9 +754,9 @@ from typing import _eval_type
def get_type_hints(func):
annot = getattr(func, '__annotations__', None)
annot = getattr(func, "__annotations__", None)
if annot is None:
raise TypeError(f'Failed to get function type hints, {func} is not a function')
raise TypeError(f"Failed to get function type hints, {func} is not a function")
hints = {}
# Build eval namespaces from function globals plus captured closure variables
# This lets annotations reference symbols like `n`, `h`, or dtype vars
......@@ -808,7 +783,7 @@ def get_type_hints(func):
# ... # empty function, do not use `n`
localns = utils.get_func_nonlocals(func)
for name, value in annot.items():
if name == 'return':
if name == "return":
continue
if isinstance(value, tvm.DataType):
hints[name] = value
......@@ -821,7 +796,7 @@ def get_type_hints(func):
# typing see: T.float32 is str('float32'), and there is no object named `flaot32` and give a NameError
# here we manually interpret it to return T.float32 object
try:
_, v = value.split('.', maxsplit=1)
_, v = value.split(".", maxsplit=1)
except ValueError:
v = value
if v in dt._all_dtypes:
......@@ -837,9 +812,7 @@ def get_type_hints(func):
return hints
def prim_func(func: Callable[_P, _T] = None,
*,
generator: bool = False) -> PrimFunc[_P, _T] | PrimFuncCreater[_P, _T]:
def prim_func(func: Callable[_P, _T] = None, *, generator: bool = False) -> PrimFunc[_P, _T] | PrimFuncCreater[_P, _T]:
"""
Decorator to create a primitive function (PrimFunc) for TileLang IR generation.
This decorator transforms a Python function into a TileLang primitive function by analyzing
......@@ -903,7 +876,8 @@ def prim_func(func: Callable[_P, _T] = None,
raise ValueError(
f"Cannot create PrimFunc for `{func.__name__}`, some arguments are not compile-time known, \n"
f"Annotations:\n{func_annot.annots}"
f"Unknown Args: {unknown_args}")
f"Unknown Args: {unknown_args}"
)
return prim_func_generator
return impl(func) if func is not None else impl
......@@ -6,14 +6,12 @@ from tvm import tir
import tvm.script.ir_builder.tir._ffi_api as tb_ffi
import numpy as np
_T = TypeVar('_T')
_T = TypeVar("_T")
if TYPE_CHECKING:
class dtype(Generic[_T]):
def torch(self) -> torch.dtype:
...
def torch(self) -> torch.dtype: ...
else:
dtype = tvm.DataType
......@@ -21,53 +19,53 @@ else:
AnyDType = Union[ir.Type, str, type, torch.dtype, dtype]
_PYTHON_DTYPE_TO_STR = {
bool: 'bool',
int: 'int32',
float: 'float32',
bool: "bool",
int: "int32",
float: "float32",
}
_NUMPY_DTYPE_TO_STR = {
np.bool_: 'bool',
np.short: 'int16',
np.int_: 'int64',
np.longlong: 'int64',
np.half: 'float16',
np.double: 'float64',
np.int8: 'int8',
np.int16: 'int16',
np.int32: 'int32',
np.int64: 'int64',
np.uint8: 'uint8',
np.uint16: 'uint16',
np.uint32: 'uint32',
np.uint64: 'uint64',
np.float16: 'float16',
np.float32: 'float32',
np.float64: 'float64',
np.bool_: "bool",
np.short: "int16",
np.int_: "int64",
np.longlong: "int64",
np.half: "float16",
np.double: "float64",
np.int8: "int8",
np.int16: "int16",
np.int32: "int32",
np.int64: "int64",
np.uint8: "uint8",
np.uint16: "uint16",
np.uint32: "uint32",
np.uint64: "uint64",
np.float16: "float16",
np.float32: "float32",
np.float64: "float64",
}
_NUMPY_DTYPE_TO_STR.update({np.dtype(k): v for k, v in _NUMPY_DTYPE_TO_STR.items()})
_TORCH_DTYPE_TO_STR = {
torch.bool: 'bool',
torch.short: 'int16',
torch.int: 'int32',
torch.long: 'int64',
torch.half: 'float16',
torch.float: 'float32',
torch.double: 'float64',
torch.int8: 'int8',
torch.int16: 'int16',
torch.int32: 'int32',
torch.int64: 'int64',
torch.uint8: 'uint8',
torch.uint16: 'uint16',
torch.uint32: 'uint32',
torch.uint64: 'uint64',
torch.float16: 'float16',
torch.float32: 'float32',
torch.float64: 'float64',
torch.bfloat16: 'bfloat16',
torch.bool: "bool",
torch.short: "int16",
torch.int: "int32",
torch.long: "int64",
torch.half: "float16",
torch.float: "float32",
torch.double: "float64",
torch.int8: "int8",
torch.int16: "int16",
torch.int32: "int32",
torch.int64: "int64",
torch.uint8: "uint8",
torch.uint16: "uint16",
torch.uint32: "uint32",
torch.uint64: "uint64",
torch.float16: "float16",
torch.float32: "float32",
torch.float64: "float64",
torch.bfloat16: "bfloat16",
}
# _STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()}
......@@ -77,24 +75,24 @@ _TORCH_DTYPE_TO_STR = {
_DTYPE_TO_STR = {**_PYTHON_DTYPE_TO_STR, **_NUMPY_DTYPE_TO_STR, **_TORCH_DTYPE_TO_STR}
_STR_TO_TVM_DTYPE_CALL = {
'bool': 'Boolean',
'int8': 'Int8',
'int32': 'Int32',
'int64': 'Int64',
'uint8': 'UInt8',
'uint16': 'UInt16',
'uint32': 'UInt32',
'uint64': 'UInt64',
'float16': 'Float16',
'float32': 'Float32',
'float64': 'Float64',
'bfloat16': 'BFloat16',
'float8_e4m3': 'Float8E4M3',
'float8_e4m3fn': 'Float8E4M3FN',
'float8_e4m3fnuz': 'Float8E4M3FNUZ',
'float8_e5m2': 'Float8E5M2',
'float8_e5m2fnuz': 'Float8E5M2FNUZ',
'float8_e8m0fnu': 'Float8E8M0FNU'
"bool": "Boolean",
"int8": "Int8",
"int32": "Int32",
"int64": "Int64",
"uint8": "UInt8",
"uint16": "UInt16",
"uint32": "UInt32",
"uint64": "UInt64",
"float16": "Float16",
"float32": "Float32",
"float64": "Float64",
"bfloat16": "BFloat16",
"float8_e4m3": "Float8E4M3",
"float8_e4m3fn": "Float8E4M3FN",
"float8_e4m3fnuz": "Float8E4M3FNUZ",
"float8_e5m2": "Float8E5M2",
"float8_e5m2fnuz": "Float8E5M2FNUZ",
"float8_e8m0fnu": "Float8E8M0FNU",
}
int_ = int
......@@ -108,23 +106,24 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var
call = getattr(tb_ffi, attr, None)
return call(expr, is_size_var)
# try to construct the ffi call
if self.startswith('uint'):
val = 'UInt' + self[4:]
elif self.startswith('int'):
val = 'Int' + self[3:]
elif self.startswith('float'):
val = 'Float' + self[5:]
elif self.startswith('bfloat'):
val = 'BFloat' + self[6:]
if self.startswith("uint"):
val = "UInt" + self[4:]
elif self.startswith("int"):
val = "Int" + self[3:]
elif self.startswith("float"):
val = "Float" + self[5:]
elif self.startswith("bfloat"):
val = "BFloat" + self[6:]
else:
raise TypeError(f'Invalid type {self}')
if '_' in val:
first, second = val.split('_', maxsplit=1)
raise TypeError(f"Invalid type {self}")
if "_" in val:
first, second = val.split("_", maxsplit=1)
val = first + second.upper()
call = getattr(tb_ffi, val, None)
if call is None:
raise TypeError(f"Convert to datatype `{self}` is not supported by tvm\n"
f"calling failed on `tvm.script.ir_builder.tir._ffi_api.{val}`")
raise TypeError(
f"Convert to datatype `{self}` is not supported by tvm\ncalling failed on `tvm.script.ir_builder.tir._ffi_api.{val}`"
)
return call(expr, is_size_var)
......@@ -152,7 +151,6 @@ def get_tvm_dtype(value: AnyDType) -> dtype:
if TYPE_CHECKING:
# yapf: disable
class bool(dtype): ...
class short(dtype): ...
......@@ -319,336 +317,336 @@ if TYPE_CHECKING:
# yapf: enable
else:
bool = dtype('bool')
short = dtype('int16')
int = dtype('int32')
long = dtype('int64')
half = dtype('float16')
float = dtype('float32')
double = dtype('float64')
int8 = dtype('int8')
int16 = dtype('int16')
int32 = dtype('int32')
int64 = dtype('int64')
int8x2 = dtype('int8x2')
int16x2 = dtype('int16x2')
int32x2 = dtype('int32x2')
int64x2 = dtype('int64x2')
int8x4 = dtype('int8x4')
int16x4 = dtype('int16x4')
int32x4 = dtype('int32x4')
int64x4 = dtype('int64x4')
int8x8 = dtype('int8x8')
int16x8 = dtype('int16x8')
int32x8 = dtype('int32x8')
int64x8 = dtype('int64x8')
int8x16 = dtype('int8x16')
int16x16 = dtype('int16x16')
int32x16 = dtype('int32x16')
int64x16 = dtype('int64x16')
int8x32 = dtype('int8x32')
int16x32 = dtype('int16x32')
int32x32 = dtype('int32x32')
int64x32 = dtype('int64x32')
int8x64 = dtype('int8x64')
int16x64 = dtype('int16x64')
int32x64 = dtype('int32x64')
int64x64 = dtype('int64x64')
uint8 = dtype('uint8')
uint16 = dtype('uint16')
uint32 = dtype('uint32')
uint64 = dtype('uint64')
uint8x2 = dtype('uint8x2')
uint16x2 = dtype('uint16x2')
uint32x2 = dtype('uint32x2')
uint64x2 = dtype('uint64x2')
uint8x4 = dtype('uint8x4')
uint16x4 = dtype('uint16x4')
uint32x4 = dtype('uint32x4')
uint64x4 = dtype('uint64x4')
uint8x8 = dtype('uint8x8')
uint16x8 = dtype('uint16x8')
uint32x8 = dtype('uint32x8')
uint64x8 = dtype('uint64x8')
uint8x16 = dtype('uint8x16')
uint16x16 = dtype('uint16x16')
uint32x16 = dtype('uint32x16')
uint64x16 = dtype('uint64x16')
uint8x32 = dtype('uint8x32')
uint16x32 = dtype('uint16x32')
uint32x32 = dtype('uint32x32')
uint64x32 = dtype('uint64x32')
uint8x64 = dtype('uint8x64')
uint16x64 = dtype('uint16x64')
uint32x64 = dtype('uint32x64')
uint64x64 = dtype('uint64x64')
float16 = dtype('float16')
float32 = dtype('float32')
float64 = dtype('float64')
float16x2 = dtype('float16x2')
float32x2 = dtype('float32x2')
float64x2 = dtype('float64x2')
float16x4 = dtype('float16x4')
float32x4 = dtype('float32x4')
float64x4 = dtype('float64x4')
float16x8 = dtype('float16x8')
float32x8 = dtype('float32x8')
float64x8 = dtype('float64x8')
float16x16 = dtype('float16x16')
float32x16 = dtype('float32x16')
float64x16 = dtype('float64x16')
float16x32 = dtype('float16x32')
float32x32 = dtype('float32x32')
float64x32 = dtype('float64x32')
float16x64 = dtype('float16x64')
float32x64 = dtype('float32x64')
float64x64 = dtype('float64x64')
float8_e3m4 = dtype('float8_e3m4')
float8_e3m4x2 = dtype('float8_e3m4x2')
float8_e3m4x4 = dtype('float8_e3m4x4')
float8_e3m4x8 = dtype('float8_e3m4x8')
float8_e3m4x16 = dtype('float8_e3m4x16')
float8_e3m4x32 = dtype('float8_e3m4x32')
float8_e3m4x64 = dtype('float8_e3m4x64')
float8_e4m3 = dtype('float8_e4m3')
float8_e4m3x2 = dtype('float8_e4m3x2')
float8_e4m3x4 = dtype('float8_e4m3x4')
float8_e4m3x8 = dtype('float8_e4m3x8')
float8_e4m3x16 = dtype('float8_e4m3x16')
float8_e4m3x32 = dtype('float8_e4m3x32')
float8_e4m3x64 = dtype('float8_e4m3x64')
float8_e4m3b11fnuz = dtype('float8_e4m3b11fnuz')
float8_e4m3b11fnuzx2 = dtype('float8_e4m3b11fnuzx2')
float8_e4m3b11fnuzx4 = dtype('float8_e4m3b11fnuzx4')
float8_e4m3b11fnuzx8 = dtype('float8_e4m3b11fnuzx8')
float8_e4m3b11fnuzx16 = dtype('float8_e4m3b11fnuzx16')
float8_e4m3b11fnuzx32 = dtype('float8_e4m3b11fnuzx32')
float8_e4m3b11fnuzx64 = dtype('float8_e4m3b11fnuzx64')
float8_e4m3fn = dtype('float8_e4m3fn')
float8_e4m3fnx2 = dtype('float8_e4m3fnx2')
float8_e4m3fnx4 = dtype('float8_e4m3fnx4')
float8_e4m3fnx8 = dtype('float8_e4m3fnx8')
float8_e4m3fnx16 = dtype('float8_e4m3fnx16')
float8_e4m3fnx32 = dtype('float8_e4m3fnx32')
float8_e4m3fnx64 = dtype('float8_e4m3fnx64')
float8_e4m3fnuz = dtype('float8_e4m3fnuz')
float8_e4m3fnuzx2 = dtype('float8_e4m3fnuzx2')
float8_e4m3fnuzx4 = dtype('float8_e4m3fnuzx4')
float8_e4m3fnuzx8 = dtype('float8_e4m3fnuzx8')
float8_e4m3fnuzx16 = dtype('float8_e4m3fnuzx16')
float8_e4m3fnuzx32 = dtype('float8_e4m3fnuzx32')
float8_e4m3fnuzx64 = dtype('float8_e4m3fnuzx64')
float8_e5m2 = dtype('float8_e5m2')
float8_e5m2x2 = dtype('float8_e5m2x2')
float8_e5m2x4 = dtype('float8_e5m2x4')
float8_e5m2x8 = dtype('float8_e5m2x8')
float8_e5m2x16 = dtype('float8_e5m2x16')
float8_e5m2x32 = dtype('float8_e5m2x32')
float8_e5m2x64 = dtype('float8_e5m2x64')
float8_e5m2fnuz = dtype('float8_e5m2fnuz')
float8_e5m2fnuzx2 = dtype('float8_e5m2fnuzx2')
float8_e5m2fnuzx4 = dtype('float8_e5m2fnuzx4')
float8_e5m2fnuzx8 = dtype('float8_e5m2fnuzx8')
float8_e5m2fnuzx16 = dtype('float8_e5m2fnuzx16')
float8_e5m2fnuzx32 = dtype('float8_e5m2fnuzx32')
float8_e5m2fnuzx64 = dtype('float8_e5m2fnuzx64')
float8_e8m0fnu = dtype('float8_e8m0fnu')
float8_e8m0fnux2 = dtype('float8_e8m0fnux2')
float8_e8m0fnux4 = dtype('float8_e8m0fnux4')
float8_e8m0fnux8 = dtype('float8_e8m0fnux8')
float8_e8m0fnux16 = dtype('float8_e8m0fnux16')
float8_e8m0fnux32 = dtype('float8_e8m0fnux32')
float8_e8m0fnux64 = dtype('float8_e8m0fnux64')
float6_e2m3fn = dtype('float6_e2m3fn')
float6_e2m3fnx2 = dtype('float6_e2m3fnx2')
float6_e2m3fnx4 = dtype('float6_e2m3fnx4')
float6_e2m3fnx8 = dtype('float6_e2m3fnx8')
float6_e2m3fnx16 = dtype('float6_e2m3fnx16')
float6_e2m3fnx32 = dtype('float6_e2m3fnx32')
float6_e2m3fnx64 = dtype('float6_e2m3fnx64')
float6_e3m2fn = dtype('float6_e3m2fn')
float6_e3m2fnx2 = dtype('float6_e3m2fnx2')
float6_e3m2fnx4 = dtype('float6_e3m2fnx4')
float6_e3m2fnx8 = dtype('float6_e3m2fnx8')
float6_e3m2fnx16 = dtype('float6_e3m2fnx16')
float6_e3m2fnx32 = dtype('float6_e3m2fnx32')
float6_e3m2fnx64 = dtype('float6_e3m2fnx64')
float4_e2m1fn = dtype('float4_e2m1fn')
float4_e2m1fnx2 = dtype('float4_e2m1fnx2')
float4_e2m1fnx4 = dtype('float4_e2m1fnx4')
float4_e2m1fnx8 = dtype('float4_e2m1fnx8')
float4_e2m1fnx16 = dtype('float4_e2m1fnx16')
float4_e2m1fnx32 = dtype('float4_e2m1fnx32')
float4_e2m1fnx64 = dtype('float4_e2m1fnx64')
bfloat16 = dtype('bfloat16')
bool = dtype("bool")
short = dtype("int16")
int = dtype("int32")
long = dtype("int64")
half = dtype("float16")
float = dtype("float32")
double = dtype("float64")
int8 = dtype("int8")
int16 = dtype("int16")
int32 = dtype("int32")
int64 = dtype("int64")
int8x2 = dtype("int8x2")
int16x2 = dtype("int16x2")
int32x2 = dtype("int32x2")
int64x2 = dtype("int64x2")
int8x4 = dtype("int8x4")
int16x4 = dtype("int16x4")
int32x4 = dtype("int32x4")
int64x4 = dtype("int64x4")
int8x8 = dtype("int8x8")
int16x8 = dtype("int16x8")
int32x8 = dtype("int32x8")
int64x8 = dtype("int64x8")
int8x16 = dtype("int8x16")
int16x16 = dtype("int16x16")
int32x16 = dtype("int32x16")
int64x16 = dtype("int64x16")
int8x32 = dtype("int8x32")
int16x32 = dtype("int16x32")
int32x32 = dtype("int32x32")
int64x32 = dtype("int64x32")
int8x64 = dtype("int8x64")
int16x64 = dtype("int16x64")
int32x64 = dtype("int32x64")
int64x64 = dtype("int64x64")
uint8 = dtype("uint8")
uint16 = dtype("uint16")
uint32 = dtype("uint32")
uint64 = dtype("uint64")
uint8x2 = dtype("uint8x2")
uint16x2 = dtype("uint16x2")
uint32x2 = dtype("uint32x2")
uint64x2 = dtype("uint64x2")
uint8x4 = dtype("uint8x4")
uint16x4 = dtype("uint16x4")
uint32x4 = dtype("uint32x4")
uint64x4 = dtype("uint64x4")
uint8x8 = dtype("uint8x8")
uint16x8 = dtype("uint16x8")
uint32x8 = dtype("uint32x8")
uint64x8 = dtype("uint64x8")
uint8x16 = dtype("uint8x16")
uint16x16 = dtype("uint16x16")
uint32x16 = dtype("uint32x16")
uint64x16 = dtype("uint64x16")
uint8x32 = dtype("uint8x32")
uint16x32 = dtype("uint16x32")
uint32x32 = dtype("uint32x32")
uint64x32 = dtype("uint64x32")
uint8x64 = dtype("uint8x64")
uint16x64 = dtype("uint16x64")
uint32x64 = dtype("uint32x64")
uint64x64 = dtype("uint64x64")
float16 = dtype("float16")
float32 = dtype("float32")
float64 = dtype("float64")
float16x2 = dtype("float16x2")
float32x2 = dtype("float32x2")
float64x2 = dtype("float64x2")
float16x4 = dtype("float16x4")
float32x4 = dtype("float32x4")
float64x4 = dtype("float64x4")
float16x8 = dtype("float16x8")
float32x8 = dtype("float32x8")
float64x8 = dtype("float64x8")
float16x16 = dtype("float16x16")
float32x16 = dtype("float32x16")
float64x16 = dtype("float64x16")
float16x32 = dtype("float16x32")
float32x32 = dtype("float32x32")
float64x32 = dtype("float64x32")
float16x64 = dtype("float16x64")
float32x64 = dtype("float32x64")
float64x64 = dtype("float64x64")
float8_e3m4 = dtype("float8_e3m4")
float8_e3m4x2 = dtype("float8_e3m4x2")
float8_e3m4x4 = dtype("float8_e3m4x4")
float8_e3m4x8 = dtype("float8_e3m4x8")
float8_e3m4x16 = dtype("float8_e3m4x16")
float8_e3m4x32 = dtype("float8_e3m4x32")
float8_e3m4x64 = dtype("float8_e3m4x64")
float8_e4m3 = dtype("float8_e4m3")
float8_e4m3x2 = dtype("float8_e4m3x2")
float8_e4m3x4 = dtype("float8_e4m3x4")
float8_e4m3x8 = dtype("float8_e4m3x8")
float8_e4m3x16 = dtype("float8_e4m3x16")
float8_e4m3x32 = dtype("float8_e4m3x32")
float8_e4m3x64 = dtype("float8_e4m3x64")
float8_e4m3b11fnuz = dtype("float8_e4m3b11fnuz")
float8_e4m3b11fnuzx2 = dtype("float8_e4m3b11fnuzx2")
float8_e4m3b11fnuzx4 = dtype("float8_e4m3b11fnuzx4")
float8_e4m3b11fnuzx8 = dtype("float8_e4m3b11fnuzx8")
float8_e4m3b11fnuzx16 = dtype("float8_e4m3b11fnuzx16")
float8_e4m3b11fnuzx32 = dtype("float8_e4m3b11fnuzx32")
float8_e4m3b11fnuzx64 = dtype("float8_e4m3b11fnuzx64")
float8_e4m3fn = dtype("float8_e4m3fn")
float8_e4m3fnx2 = dtype("float8_e4m3fnx2")
float8_e4m3fnx4 = dtype("float8_e4m3fnx4")
float8_e4m3fnx8 = dtype("float8_e4m3fnx8")
float8_e4m3fnx16 = dtype("float8_e4m3fnx16")
float8_e4m3fnx32 = dtype("float8_e4m3fnx32")
float8_e4m3fnx64 = dtype("float8_e4m3fnx64")
float8_e4m3fnuz = dtype("float8_e4m3fnuz")
float8_e4m3fnuzx2 = dtype("float8_e4m3fnuzx2")
float8_e4m3fnuzx4 = dtype("float8_e4m3fnuzx4")
float8_e4m3fnuzx8 = dtype("float8_e4m3fnuzx8")
float8_e4m3fnuzx16 = dtype("float8_e4m3fnuzx16")
float8_e4m3fnuzx32 = dtype("float8_e4m3fnuzx32")
float8_e4m3fnuzx64 = dtype("float8_e4m3fnuzx64")
float8_e5m2 = dtype("float8_e5m2")
float8_e5m2x2 = dtype("float8_e5m2x2")
float8_e5m2x4 = dtype("float8_e5m2x4")
float8_e5m2x8 = dtype("float8_e5m2x8")
float8_e5m2x16 = dtype("float8_e5m2x16")
float8_e5m2x32 = dtype("float8_e5m2x32")
float8_e5m2x64 = dtype("float8_e5m2x64")
float8_e5m2fnuz = dtype("float8_e5m2fnuz")
float8_e5m2fnuzx2 = dtype("float8_e5m2fnuzx2")
float8_e5m2fnuzx4 = dtype("float8_e5m2fnuzx4")
float8_e5m2fnuzx8 = dtype("float8_e5m2fnuzx8")
float8_e5m2fnuzx16 = dtype("float8_e5m2fnuzx16")
float8_e5m2fnuzx32 = dtype("float8_e5m2fnuzx32")
float8_e5m2fnuzx64 = dtype("float8_e5m2fnuzx64")
float8_e8m0fnu = dtype("float8_e8m0fnu")
float8_e8m0fnux2 = dtype("float8_e8m0fnux2")
float8_e8m0fnux4 = dtype("float8_e8m0fnux4")
float8_e8m0fnux8 = dtype("float8_e8m0fnux8")
float8_e8m0fnux16 = dtype("float8_e8m0fnux16")
float8_e8m0fnux32 = dtype("float8_e8m0fnux32")
float8_e8m0fnux64 = dtype("float8_e8m0fnux64")
float6_e2m3fn = dtype("float6_e2m3fn")
float6_e2m3fnx2 = dtype("float6_e2m3fnx2")
float6_e2m3fnx4 = dtype("float6_e2m3fnx4")
float6_e2m3fnx8 = dtype("float6_e2m3fnx8")
float6_e2m3fnx16 = dtype("float6_e2m3fnx16")
float6_e2m3fnx32 = dtype("float6_e2m3fnx32")
float6_e2m3fnx64 = dtype("float6_e2m3fnx64")
float6_e3m2fn = dtype("float6_e3m2fn")
float6_e3m2fnx2 = dtype("float6_e3m2fnx2")
float6_e3m2fnx4 = dtype("float6_e3m2fnx4")
float6_e3m2fnx8 = dtype("float6_e3m2fnx8")
float6_e3m2fnx16 = dtype("float6_e3m2fnx16")
float6_e3m2fnx32 = dtype("float6_e3m2fnx32")
float6_e3m2fnx64 = dtype("float6_e3m2fnx64")
float4_e2m1fn = dtype("float4_e2m1fn")
float4_e2m1fnx2 = dtype("float4_e2m1fnx2")
float4_e2m1fnx4 = dtype("float4_e2m1fnx4")
float4_e2m1fnx8 = dtype("float4_e2m1fnx8")
float4_e2m1fnx16 = dtype("float4_e2m1fnx16")
float4_e2m1fnx32 = dtype("float4_e2m1fnx32")
float4_e2m1fnx64 = dtype("float4_e2m1fnx64")
bfloat16 = dtype("bfloat16")
_all_dtypes = {
'bool',
'short',
'int',
'long',
'half',
'float',
'double',
'int8',
'int16',
'int32',
'int64',
'int8x2',
'int16x2',
'int32x2',
'int64x2',
'int8x4',
'int16x4',
'int32x4',
'int64x4',
'int8x8',
'int16x8',
'int32x8',
'int64x8',
'int8x16',
'int16x16',
'int32x16',
'int64x16',
'int8x32',
'int16x32',
'int32x32',
'int64x32',
'int8x64',
'int16x64',
'int32x64',
'int64x64',
'uint8',
'uint16',
'uint32',
'uint64',
'uint8x2',
'uint16x2',
'uint32x2',
'uint64x2',
'uint8x4',
'uint16x4',
'uint32x4',
'uint64x4',
'uint8x8',
'uint16x8',
'uint32x8',
'uint64x8',
'uint8x16',
'uint16x16',
'uint32x16',
'uint64x16',
'uint8x32',
'uint16x32',
'uint32x32',
'uint64x32',
'uint8x64',
'uint16x64',
'uint32x64',
'uint64x64',
'float16',
'float32',
'float64',
'float16x2',
'float32x2',
'float64x2',
'float16x4',
'float32x4',
'float64x4',
'float16x8',
'float32x8',
'float64x8',
'float16x16',
'float32x16',
'float64x16',
'float16x32',
'float32x32',
'float64x32',
'float16x64',
'float32x64',
'float64x64',
'float8_e3m4',
'float8_e3m4x2',
'float8_e3m4x4',
'float8_e3m4x8',
'float8_e3m4x16',
'float8_e3m4x32',
'float8_e3m4x64',
'float8_e4m3',
'float8_e4m3x2',
'float8_e4m3x4',
'float8_e4m3x8',
'float8_e4m3x16',
'float8_e4m3x32',
'float8_e4m3x64',
'float8_e4m3b11fnuz',
'float8_e4m3b11fnuzx2',
'float8_e4m3b11fnuzx4',
'float8_e4m3b11fnuzx8',
'float8_e4m3b11fnuzx16',
'float8_e4m3b11fnuzx32',
'float8_e4m3b11fnuzx64',
'float8_e4m3fn',
'float8_e4m3fnx2',
'float8_e4m3fnx4',
'float8_e4m3fnx8',
'float8_e4m3fnx16',
'float8_e4m3fnx32',
'float8_e4m3fnx64',
'float8_e4m3fnuz',
'float8_e4m3fnuzx2',
'float8_e4m3fnuzx4',
'float8_e4m3fnuzx8',
'float8_e4m3fnuzx16',
'float8_e4m3fnuzx32',
'float8_e4m3fnuzx64',
'float8_e5m2',
'float8_e5m2x2',
'float8_e5m2x4',
'float8_e5m2x8',
'float8_e5m2x16',
'float8_e5m2x32',
'float8_e5m2x64',
'float8_e5m2fnuz',
'float8_e5m2fnuzx2',
'float8_e5m2fnuzx4',
'float8_e5m2fnuzx8',
'float8_e5m2fnuzx16',
'float8_e5m2fnuzx32',
'float8_e5m2fnuzx64',
'float8_e8m0fnu',
'float8_e8m0fnux2',
'float8_e8m0fnux4',
'float8_e8m0fnux8',
'float8_e8m0fnux16',
'float8_e8m0fnux32',
'float8_e8m0fnux64',
'float6_e2m3fn',
'float6_e2m3fnx2',
'float6_e2m3fnx4',
'float6_e2m3fnx8',
'float6_e2m3fnx16',
'float6_e2m3fnx32',
'float6_e2m3fnx64',
'float6_e3m2fn',
'float6_e3m2fnx2',
'float6_e3m2fnx4',
'float6_e3m2fnx8',
'float6_e3m2fnx16',
'float6_e3m2fnx32',
'float6_e3m2fnx64',
'float4_e2m1fn',
'float4_e2m1fnx2',
'float4_e2m1fnx4',
'float4_e2m1fnx8',
'float4_e2m1fnx16',
'float4_e2m1fnx32',
'float4_e2m1fnx64',
'bfloat16',
"bool",
"short",
"int",
"long",
"half",
"float",
"double",
"int8",
"int16",
"int32",
"int64",
"int8x2",
"int16x2",
"int32x2",
"int64x2",
"int8x4",
"int16x4",
"int32x4",
"int64x4",
"int8x8",
"int16x8",
"int32x8",
"int64x8",
"int8x16",
"int16x16",
"int32x16",
"int64x16",
"int8x32",
"int16x32",
"int32x32",
"int64x32",
"int8x64",
"int16x64",
"int32x64",
"int64x64",
"uint8",
"uint16",
"uint32",
"uint64",
"uint8x2",
"uint16x2",
"uint32x2",
"uint64x2",
"uint8x4",
"uint16x4",
"uint32x4",
"uint64x4",
"uint8x8",
"uint16x8",
"uint32x8",
"uint64x8",
"uint8x16",
"uint16x16",
"uint32x16",
"uint64x16",
"uint8x32",
"uint16x32",
"uint32x32",
"uint64x32",
"uint8x64",
"uint16x64",
"uint32x64",
"uint64x64",
"float16",
"float32",
"float64",
"float16x2",
"float32x2",
"float64x2",
"float16x4",
"float32x4",
"float64x4",
"float16x8",
"float32x8",
"float64x8",
"float16x16",
"float32x16",
"float64x16",
"float16x32",
"float32x32",
"float64x32",
"float16x64",
"float32x64",
"float64x64",
"float8_e3m4",
"float8_e3m4x2",
"float8_e3m4x4",
"float8_e3m4x8",
"float8_e3m4x16",
"float8_e3m4x32",
"float8_e3m4x64",
"float8_e4m3",
"float8_e4m3x2",
"float8_e4m3x4",
"float8_e4m3x8",
"float8_e4m3x16",
"float8_e4m3x32",
"float8_e4m3x64",
"float8_e4m3b11fnuz",
"float8_e4m3b11fnuzx2",
"float8_e4m3b11fnuzx4",
"float8_e4m3b11fnuzx8",
"float8_e4m3b11fnuzx16",
"float8_e4m3b11fnuzx32",
"float8_e4m3b11fnuzx64",
"float8_e4m3fn",
"float8_e4m3fnx2",
"float8_e4m3fnx4",
"float8_e4m3fnx8",
"float8_e4m3fnx16",
"float8_e4m3fnx32",
"float8_e4m3fnx64",
"float8_e4m3fnuz",
"float8_e4m3fnuzx2",
"float8_e4m3fnuzx4",
"float8_e4m3fnuzx8",
"float8_e4m3fnuzx16",
"float8_e4m3fnuzx32",
"float8_e4m3fnuzx64",
"float8_e5m2",
"float8_e5m2x2",
"float8_e5m2x4",
"float8_e5m2x8",
"float8_e5m2x16",
"float8_e5m2x32",
"float8_e5m2x64",
"float8_e5m2fnuz",
"float8_e5m2fnuzx2",
"float8_e5m2fnuzx4",
"float8_e5m2fnuzx8",
"float8_e5m2fnuzx16",
"float8_e5m2fnuzx32",
"float8_e5m2fnuzx64",
"float8_e8m0fnu",
"float8_e8m0fnux2",
"float8_e8m0fnux4",
"float8_e8m0fnux8",
"float8_e8m0fnux16",
"float8_e8m0fnux32",
"float8_e8m0fnux64",
"float6_e2m3fn",
"float6_e2m3fnx2",
"float6_e2m3fnx4",
"float6_e2m3fnx8",
"float6_e2m3fnx16",
"float6_e2m3fnx32",
"float6_e2m3fnx64",
"float6_e3m2fn",
"float6_e3m2fnx2",
"float6_e3m2fnx4",
"float6_e3m2fnx8",
"float6_e3m2fnx16",
"float6_e3m2fnx32",
"float6_e3m2fnx64",
"float4_e2m1fn",
"float4_e2m1fnx2",
"float4_e2m1fnx4",
"float4_e2m1fnx8",
"float4_e2m1fnx16",
"float4_e2m1fnx32",
"float4_e2m1fnx64",
"bfloat16",
}
__all__ = list(_all_dtypes) + [
'dtype',
'AnyDType',
'get_tvm_dtype',
"dtype",
"AnyDType",
"get_tvm_dtype",
]
......@@ -12,11 +12,12 @@ def disk_compile(source, name):
cache_dir = env.TILELANG_CACHE_DIR
if cache_dir is not None:
import os
save_dir = os.path.join(cache_dir, "py-cache")
os.makedirs(save_dir, exist_ok=True)
hash_sfx = sha256(source.encode('utf-8')).hexdigest()[:8]
hash_sfx = sha256(source.encode("utf-8")).hexdigest()[:8]
path = os.path.join(save_dir, f"{name}.{hash_sfx}.py")
with open(path, 'w') as f:
with open(path, "w") as f:
f.write(source)
linecache.cache[path] = (len(source), None, source.splitlines(), path)
return compile(source, path, "exec")
......@@ -59,29 +60,26 @@ def get_ast(func: Callable):
filename = inspect.getsourcefile(func) or inspect.getfile(func)
source = inspect.getsource(func)
source = _remove_leading_ident(source)
source = '\n' * (start - 1) + source
source = "\n" * (start - 1) + source
tree = ast.parse(source, filename=filename)
return tree
CompileMethod = Literal['direct', 'disk']
CompileMethod = Literal["direct", "disk"]
def get_compiled_object(source: str | ast.AST,
name: str,
filename: str = None,
globals: dict[str, Any] = None):
def get_compiled_object(source: str | ast.AST, name: str, filename: str = None, globals: dict[str, Any] = None):
if isinstance(source, ast.AST):
assert filename is not None, "filename must be provided when source is an AST"
try:
if isinstance(source, ast.AST):
ast.fix_missing_locations(source)
compiled = compile(source, filename, 'exec')
compiled = compile(source, filename, "exec")
else:
compiled = disk_compile(source, name)
except Exception as e:
source_str = source if isinstance(source, str) else ast.unparse(source)
raise RuntimeError(f'Failed to compile source for {name}, Error: {e}:\n{source_str}') from e
raise RuntimeError(f"Failed to compile source for {name}, Error: {e}:\n{source_str}") from e
locs = {}
exec(compiled, globals, locs)
return locs[name]
......@@ -95,7 +93,6 @@ def construct_strides(shape: tuple[Any, ...], allow_prim_expr: bool = True) -> t
strides.append(stride)
stride *= s
if not allow_prim_expr and isinstance(stride, tir.PrimExpr):
raise ValueError(
"Cannot construct strides with PrimExpr when allow_prim_expr is False.")
raise ValueError("Cannot construct strides with PrimExpr when allow_prim_expr is False.")
strides = tuple(reversed(strides))
return strides
"""The language interface for tl programs."""
from tvm.script.ir_builder.tir.frame import TIRFrame
from tvm.ffi import register_object
from tilelang import _ffi_api
......
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
import tvm
import tvm_ffi
......@@ -20,12 +21,7 @@ class Fragment(Layout):
# Disable the linter warning about not calling super().__init__()
# because this object is created via TVM's FFI constructor mechanism.
# pylint: disable=super-init-not-called
def __init__(self,
shape,
forward_fn=None,
forward_thread_fn=None,
replicate=1,
forward_index_fn=None):
def __init__(self, shape, forward_fn=None, forward_thread_fn=None, replicate=1, forward_index_fn=None):
"""
Initialize the Fragment with iteration variables and optional thread replication.
......@@ -119,10 +115,7 @@ class Fragment(Layout):
"""
return _ffi_api.Fragment_thread_size(self)
def repeat(self,
repeats,
repeat_on_thread: bool = False,
lower_dim_first: bool = True) -> 'Fragment':
def repeat(self, repeats, repeat_on_thread: bool = False, lower_dim_first: bool = True) -> "Fragment":
"""
Returns a new Fragment that repeats the iteration space a given number of times.
......@@ -142,7 +135,7 @@ class Fragment(Layout):
"""
return _ffi_api.Fragment_repeat(self, repeats, repeat_on_thread, lower_dim_first)
def replicate(self, replicate: int) -> 'Fragment':
def replicate(self, replicate: int) -> "Fragment":
"""
Replicate the Fragment across a new thread dimension.
......@@ -158,7 +151,7 @@ class Fragment(Layout):
"""
return _ffi_api.Fragment_replicate(self, replicate)
def condense_rep_var(self) -> 'Fragment':
def condense_rep_var(self) -> "Fragment":
"""
Condense or fold the replicate variable into the existing iteration space.
This operation may be used to reduce dimensionality if the replicate variable
......@@ -190,8 +183,7 @@ class Fragment(Layout):
# The thread dimension (IterVar) is accessed via the `thread` property
forward_thread = self.thread
# Construct an IndexMap to map the provided args into the final thread index
index_map = IndexMap(
initial_indices=forward_vars, final_indices=[forward_thread], inverse_index_map=None)
index_map = IndexMap(initial_indices=forward_vars, final_indices=[forward_thread], inverse_index_map=None)
return index_map.map_indices(indices)
def __repr__(self):
......@@ -206,7 +198,7 @@ class Fragment(Layout):
return self._DebugOutput()
# return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>"
def is_equal(self, other: 'Fragment') -> bool:
def is_equal(self, other: "Fragment") -> bool:
"""
Check if the current fragment is equal to another fragment.
"""
......
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations
import tvm
......@@ -114,8 +115,7 @@ def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str):
if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]:
raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}")
if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8"
] and buffer.dtype not in ["uint32", "int32"]:
if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8"] and buffer.dtype not in ["uint32", "int32"]:
raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}")
m, k = buffer.shape
......@@ -134,10 +134,7 @@ def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str):
return T.Layout(buffer.shape, ColumnMajorInterleaved)
def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer,
mma_dtype: str = "float16",
arch: str | None = None,
**extra_args):
def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer, mma_dtype: str = "float16", arch: str | None = None, **extra_args):
if arch is None:
arch = nvcc.get_target_compute_version()
......
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
import tvm_ffi
from tvm.ir import Node, Range
......@@ -9,7 +10,6 @@ from tilelang import _ffi_api
# Register the Layout class as a TVM object under the name "tl.Layout"
@tvm_ffi.register_object("tl.Layout")
class Layout(Node):
def __init__(self, shape, forward_fn):
"""
Initialize a Layout object.
......@@ -114,13 +114,13 @@ class Layout(Node):
index_map = IndexMap(
initial_indices=forward_vars, # The original iteration variables
final_indices=forward_indexes, # The computed forward indices
inverse_index_map=None # No inverse mapping provided at this stage
inverse_index_map=None, # No inverse mapping provided at this stage
)
# Map the provided indices using the constructed index mapping
return index_map.map_indices(indices)
def inverse(self) -> 'Layout':
def inverse(self) -> "Layout":
"""
Compute the inverse of the current layout transformation.
......@@ -131,7 +131,7 @@ class Layout(Node):
"""
return _ffi_api.Layout_inverse(self)
def is_equal(self, other: 'Layout') -> bool:
def is_equal(self, other: "Layout") -> bool:
"""
Check if the current layout is equal to another layout.
......
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations
......@@ -7,9 +8,7 @@ from tvm.tir import Buffer, BufferLoad, BufferRegion
from tilelang import _ffi_api
def _get_buffer_info(
buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion
) -> tuple[Buffer, list[int], str]:
def _get_buffer_info(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[Buffer, list[int], str]:
"""
Extract buffer, shape, and dtype from Buffer, BufferLoad, or BufferRegion.
......@@ -25,12 +24,10 @@ def _get_buffer_info(
buf = buffer_or_load_or_region.buffer
return buf, buf.shape, buf.dtype
else:
raise TypeError(
f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}")
raise TypeError(f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}")
def _get_stride_continuous(
buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[int, int]:
def _get_stride_continuous(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> tuple[int, int]:
"""
Get stride (last 2nd dimension) and continuous (last dimension) from Buffer, BufferLoad, or BufferRegion.
......@@ -62,9 +59,7 @@ def _get_element_size(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegi
# Use a stable swizzled layout to ensure consistent memory access patterns.
# Swizzling should be enabled or disabled based on whether TMA (Tensor Memory Access) is applied.
def make_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
k_major: bool = True,
allow_pad: bool = True):
def make_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, k_major: bool = True, allow_pad: bool = True):
stride, continuous = _get_stride_continuous(buffer)
element_size = _get_element_size(buffer)
return _ffi_api.make_swizzled_layout(
......@@ -77,9 +72,7 @@ def make_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
# for Volta Intrinsics
def make_volta_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
is_a: bool = True,
k_inner: bool = True):
def make_volta_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, is_a: bool = True, k_inner: bool = True):
stride, continuous = _get_stride_continuous(buffer)
return _ffi_api.make_volta_swizzled_layout(
stride,
......@@ -90,9 +83,7 @@ def make_volta_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
# for WGMMA Intrinsics
def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
continuity: int = None,
k_major: bool = True):
def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, continuity: int = None, k_major: bool = True):
stride, continuous = _get_stride_continuous(buffer)
element_size = _get_element_size(buffer)
if continuity is None:
......@@ -107,9 +98,7 @@ def make_wgmma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
# for TCGEN05MMA Intrinsics
def make_tcgen05mma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion,
continuity: int = None,
k_major: bool = True):
def make_tcgen05mma_swizzled_layout(buffer: Buffer | BufferLoad | BufferRegion, continuity: int = None, k_major: bool = True):
stride, continuous = _get_stride_continuous(buffer)
element_size = _get_element_size(buffer)
if continuity is None:
......
......@@ -31,6 +31,5 @@ def find_lib_path(name: str, py_ext=False):
if os.path.exists(lib_dll_path) and os.path.isfile(lib_dll_path):
return lib_dll_path
else:
message = (f"Cannot find libraries: {lib_name}\n" + "List of candidates:\n" +
"\n".join(TL_LIBS))
message = f"Cannot find libraries: {lib_name}\n" + "List of candidates:\n" + "\n".join(TL_LIBS)
raise RuntimeError(message)
""" bootstrap the primitives module via tile language """
"""bootstrap the primitives module via tile language"""
from .gemm import gemm # noqa: F401
......@@ -3,7 +3,8 @@ from tvm import tir
from tilelang.utils import is_local, is_fragment, is_shared
from tilelang.primitives.gemm.base import GemmWarpPolicy
from tilelang.primitives.gemm.gemm_mma import (
GemmPrimitiveMMA,)
GemmPrimitiveMMA,
)
def gemm(
......@@ -20,12 +21,9 @@ def gemm(
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
k_pack: int = 1,
):
assert is_local(A) or is_fragment(A) or is_shared(A), (
f"Expected A to be a local, fragment, or shared buffer, but got {A.scope()}")
assert is_local(B) or is_fragment(B) or is_shared(B), (
f"Expected B to be a local, fragment, or shared buffer, but got {B.scope()}")
assert is_local(C) or is_fragment(C), (
f"Expected C to be a local, fragment, but got {C.scope()}")
assert is_local(A) or is_fragment(A) or is_shared(A), f"Expected A to be a local, fragment, or shared buffer, but got {A.scope()}"
assert is_local(B) or is_fragment(B) or is_shared(B), f"Expected B to be a local, fragment, or shared buffer, but got {B.scope()}"
assert is_local(C) or is_fragment(C), f"Expected C to be a local, fragment, but got {C.scope()}"
# TODO(lei): Now we only support Nvidia GPUs
# Must enhance the design to implement runtime lowering
# for different targets (hip mfma for example)
......
......@@ -131,7 +131,7 @@ class GemmWarpPolicy(IntEnum):
# Try to find the best balanced partition
best_m = 1
best_n = 1
best_balance = float('inf')
best_balance = float("inf")
# Try all possible combinations that satisfy the constraints
for m in range(1, min(max_m_warps, num_warps) + 1):
......@@ -202,7 +202,7 @@ class GemmBaseParams:
warp_row_tiles: int | None = None
warp_col_tiles: int | None = None
chunk: int | None = None
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
policy: GemmWarpPolicy = (GemmWarpPolicy.Square,)
k_pack: int = 1
def get_warp_size(self) -> int:
......@@ -267,17 +267,17 @@ class GemmBaseParams:
# Determine whether block partition parameters need to be inferred
require_infer = (
block_row_warps is None or block_col_warps is None or warp_row_tiles is None or
warp_col_tiles is None or chunk is None)
block_row_warps is None or block_col_warps is None or warp_row_tiles is None or warp_col_tiles is None or chunk is None
)
A_shape, B_shape = A.shape, B.shape
if require_infer:
assert (threads is not None), "threads must be provided for auto inference"
assert threads is not None, "threads must be provided for auto inference"
# Auto-inference only supports 2D matrix multiplication
assert (
len(A_shape) == 2 and len(B_shape) == 2
), f"Only support 2D matrix multiplication, got {len(A_shape)}D and {len(B_shape)}D"
assert len(A_shape) == 2 and len(B_shape) == 2, (
f"Only support 2D matrix multiplication, got {len(A_shape)}D and {len(B_shape)}D"
)
# Analyze A/B shapes
AM = A_shape[1] if transpose_A else A_shape[0] # M dimension
......@@ -291,8 +291,7 @@ class GemmBaseParams:
num_warps = threads // warp_size
# Infer block partition using a user-specified policy
block_row_warps, block_col_warps = policy.compute_warp_partition(
block_M, block_N, num_warps)
block_row_warps, block_col_warps = policy.compute_warp_partition(block_M, block_N, num_warps)
warp_row_tiles = block_M // block_row_warps
warp_col_tiles = block_N // block_col_warps
chunk = int(AK)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment