Unverified commit 29051439, authored by Lei Wang and committed by GitHub
Browse files

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
"""The language interface for tl programs."""
from __future__ import annotations
from tvm import tir
from tilelang.language import copy, macro, alloc_shared, alloc_fragment
......@@ -30,15 +31,13 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
tir.Call: Handle to the reduction operation
"""
# input shape: [X, d, Y], expected output shape: [X, Y] or [X, 1, Y]
expected_shapes = [
buffer.shape[:dim] + buffer.shape[dim + 1:],
buffer.shape[:dim] + [1] + buffer.shape[dim + 1:]
]
expected_shapes = [buffer.shape[:dim] + buffer.shape[dim + 1 :], buffer.shape[:dim] + [1] + buffer.shape[dim + 1 :]]
if list(out.shape) not in expected_shapes:
expected_shapes_str = ' or '.join(map(str, expected_shapes))
expected_shapes_str = " or ".join(map(str, expected_shapes))
raise ValueError(
f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, "
f"output shape is {out.shape}, expected shapes are {expected_shapes_str}")
f"output shape is {out.shape}, expected shapes are {expected_shapes_str}"
)
@macro
def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
......
......@@ -7,9 +7,7 @@ from tvm.tir.function import PrimFunc
from tvm.script.parser._core import parse, scan_macro, utils
def prim_func(func: Callable | None = None,
private: bool = False,
check_well_formed: bool = False) -> PrimFunc | Callable:
def prim_func(func: Callable | None = None, private: bool = False, check_well_formed: bool = False) -> PrimFunc | Callable:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters
......@@ -113,8 +111,7 @@ def macro(*args, hygienic: bool = True) -> Callable:
if len(args) == 1 and inspect.isfunction(args[0]):
return _decorator(args[0])
raise ValueError(
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])")
raise ValueError("Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])")
setattr(macro, "dispatch_token", "tir") # noqa: B010
......@@ -6,10 +6,7 @@ import tilelang.language.tir.op as _tir_op
import functools
def serial(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: dict[str, Any] = None) -> frame.ForFrame:
def serial(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The serial For statement.
Parameters
......@@ -31,10 +28,7 @@ def serial(start: PrimExpr,
return _ir.serial(start=start, stop=stop, annotations=annotations)
def parallel(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: dict[str, Any] = None) -> frame.ForFrame:
def parallel(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The parallel For statement.
Parameters
......@@ -56,10 +50,7 @@ def parallel(start: PrimExpr,
return _ir.parallel(start=start, stop=stop, annotations=annotations)
def vectorized(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: dict[str, Any] = None) -> frame.ForFrame:
def vectorized(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The vectorized For statement.
Parameters
......@@ -81,10 +72,7 @@ def vectorized(start: PrimExpr,
return _ir.vectorized(start=start, stop=stop, annotations=annotations)
def unroll(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: dict[str, Any] = None) -> frame.ForFrame:
def unroll(start: PrimExpr, stop: PrimExpr = None, *, annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The unrolled For statement.
Parameters
......@@ -161,7 +149,6 @@ def grid(*extents: PrimExpr) -> frame.ForFrame:
def _dtype_forward(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if "dtype" in kwargs:
......@@ -172,7 +159,6 @@ def _dtype_forward(func):
def _op_wrapper(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if "dtype" in kwargs:
......
from typing import TypeVar, Literal
from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm
_T = TypeVar('_T')
_T = TypeVar("_T")
def abs(x: _T, span: Span | None=None) -> _T: ...
def abs(x: _T, span: Span | None = None) -> _T: ...
def acos(x: _T) -> _T: ...
def acosh(x: _T) -> _T: ...
def address_of(buffer_load: BufferLoad, span: Span | None=None) -> PrimExpr: ...
def address_of(buffer_load: BufferLoad, span: Span | None = None) -> PrimExpr: ...
def asin(x: _T) -> _T: ...
def asinh(x: _T) -> _T: ...
def atan(x: _T) -> _T: ...
def atan2(x1: _T, x2: _T) -> _T: ...
def atanh(x: _T) -> _T: ...
def bitwise_and(x: _T, y: _T, span: Span | None=None) -> _T: ...
def bitwise_not(x: _T, span: Span | None=None) -> _T: ...
def bitwise_or(x: _T, y: _T, span: Span | None=None) -> _T: ...
def bitwise_xor(x: _T, y: _T, span: Span | None=None) -> _T: ...
def ceil(x: _T, span: Span | None=None) -> _T: ...
def bitwise_and(x: _T, y: _T, span: Span | None = None) -> _T: ...
def bitwise_not(x: _T, span: Span | None = None) -> _T: ...
def bitwise_or(x: _T, y: _T, span: Span | None = None) -> _T: ...
def bitwise_xor(x: _T, y: _T, span: Span | None = None) -> _T: ...
def ceil(x: _T, span: Span | None = None) -> _T: ...
def clz(x: _T) -> _T: ...
def copysign(x1: _T, x2: _T) -> _T: ...
def cos(x: _T) -> _T: ...
......@@ -25,35 +25,37 @@ def erf(x: _T) -> _T: ...
def exp(x: _T) -> _T: ...
def exp2(x: _T) -> _T: ...
def exp10(x: _T) -> _T: ...
def floor(x: _T, span: Span | None=None) -> _T: ...
def ceildiv(lhs: _T, rhs: _T, span: Span | None=None) -> _T: ...
def floordiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
def floormod(a: _T, b: _T, span: Span | None=None) -> _T: ...
def floor(x: _T, span: Span | None = None) -> _T: ...
def ceildiv(lhs: _T, rhs: _T, span: Span | None = None) -> _T: ...
def floordiv(a: _T, b: _T, span: Span | None = None) -> _T: ...
def floormod(a: _T, b: _T, span: Span | None = None) -> _T: ...
def fmod(x: _T, y: _T) -> _T: ...
def hypot(x1: _T, x2: _T) -> _T: ...
def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None=None) -> _T: ...
def infinity(dtype: _T, span: Span | None=None) -> _T: ...
def isfinite(x: _T, span: Span | None=None) -> _T: ...
def isinf(x: _T, span: Span | None=None) -> _T: ...
def isnan(x: _T, span: Span | None=None) -> _T: ...
def isnullptr(x: _T, span: Span | None=None) -> _T: ...
def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None = None) -> _T: ...
def infinity(dtype: _T, span: Span | None = None) -> _T: ...
def isfinite(x: _T, span: Span | None = None) -> _T: ...
def isinf(x: _T, span: Span | None = None) -> _T: ...
def isnan(x: _T, span: Span | None = None) -> _T: ...
def isnullptr(x: _T, span: Span | None = None) -> _T: ...
def ldexp(x1: _T, x2: _T) -> _T: ...
def likely(cond: _T, span: Span | None=None) -> _T: ...
def likely(cond: _T, span: Span | None = None) -> _T: ...
def log(x: _T) -> _T: ...
def log1p(x: _T) -> _T: ...
def log2(x: _T) -> _T: ...
def log10(x: _T) -> _T: ...
def lookup_param(param_name: str, span: Span | None=None) -> PrimExpr: ...
def max_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
def min_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
def nearbyint(x: _T, span: Span | None=None) -> _T: ...
def lookup_param(param_name: str, span: Span | None = None) -> PrimExpr: ...
def max_value(dtype: str, span: Span | None = None) -> PrimExpr: ...
def min_value(dtype: str, span: Span | None = None) -> PrimExpr: ...
def nearbyint(x: _T, span: Span | None = None) -> _T: ...
def nextafter(x1: _T, x2: _T) -> _T: ...
def popcount(x: _T) -> _T: ...
def pow(x: _T, y: _T, span: Span | None=None) -> _T: ...
def pow(x: _T, y: _T, span: Span | None = None) -> _T: ...
def q_multiply_shift(x: _T, y: _T, q: _T, s: _T) -> _T: ...
def q_multiply_shift_per_axis(x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm) -> PrimExpr: ...
def q_multiply_shift_per_axis(
x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm
) -> PrimExpr: ...
def ret(val: _T) -> _T: ...
def round(x: _T, span: Span | None=None) -> _T: ...
def round(x: _T, span: Span | None = None) -> _T: ...
def rsqrt(x: _T) -> _T: ...
def shift_left(x: _T, y: _T, span=None) -> _T: ...
def shift_right(x: _T, y: _T, span=None) -> _T: ...
......@@ -63,14 +65,16 @@ def sinh(x: _T) -> _T: ...
def sqrt(x: _T) -> _T: ...
def tan(x: _T) -> _T: ...
def tanh(x: _T) -> _T: ...
def trunc(x: _T, span: Span | None=None) -> _T: ...
def truncdiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
def truncmod(a: _T, b: _T, span: Span | None=None) -> _T: ...
def trunc(x: _T, span: Span | None = None) -> _T: ...
def truncdiv(a: _T, b: _T, span: Span | None = None) -> _T: ...
def truncmod(a: _T, b: _T, span: Span | None = None) -> _T: ...
def tvm_access_ptr(ptype: PrimExpr, data, offset: int, extent: int, rw_mask: int) -> PrimExpr: ...
def tvm_throw_last_error() -> _T: ...
def tvm_stack_alloca(dtype_str: str, num: int) -> PrimExpr: ...
def tvm_stack_make_shape(*args) -> _T: ...
def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset) -> PrimExpr: ...
def tvm_stack_make_array(
data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset
) -> PrimExpr: ...
def tvm_check_return(expected: int, return_unexpected: int, nested_call: PrimExpr) -> PrimExpr: ...
def call_packed(*args, span=None) -> _T: ...
def call_cpacked(*args, span=None) -> _T: ...
......@@ -80,11 +84,47 @@ def tvm_tuple(*value) -> _T: ...
def tvm_struct_set(arr, index: int, field: int, value: PrimExpr) -> PrimExpr: ...
def tvm_thread_invariant(cond: _T) -> _T: ...
def tvm_thread_allreduce(*freduce_args) -> _T: ...
def tvm_load_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
def tvm_mma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
def tvm_bmma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
def tvm_load_matrix_sync(
fragment: Var,
m: IntImm,
n: IntImm,
k: IntImm,
index: PrimExpr,
buffer_ptr: PrimExpr,
stride: PrimExpr,
layout: Literal["row_major", "column_major"],
) -> PrimExpr: ...
def tvm_mma_sync(
fragment_d: Var,
index_d: PrimExpr,
fragment_a: Var,
index_a: PrimExpr,
fragment_b: Var,
index_b: PrimExpr,
fragment_c: Var,
index_c: PrimExpr,
) -> PrimExpr: ...
def tvm_bmma_sync(
fragment_d: Var,
index_d: PrimExpr,
fragment_a: Var,
index_a: PrimExpr,
fragment_b: Var,
index_b: PrimExpr,
fragment_c: Var,
index_c: PrimExpr,
) -> PrimExpr: ...
def tvm_fill_fragment(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, value: PrimExpr) -> PrimExpr: ...
def tvm_store_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
def tvm_store_matrix_sync(
fragment: Var,
m: IntImm,
n: IntImm,
k: IntImm,
index: PrimExpr,
buffer_ptr: PrimExpr,
stride: PrimExpr,
layout: Literal["row_major", "column_major"],
) -> PrimExpr: ...
def ptx_wait_group(num: int) -> PrimExpr: ...
def ptx_commit_group() -> _T: ...
def ptx_cp_async_barrier(barrier_id: int) -> PrimExpr: ...
......@@ -93,7 +133,7 @@ def ptx_arrive_barrier(barrier_id: int) -> PrimExpr: ...
def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ...
def ptx_wait_barrier(barrier_id: int) -> PrimExpr: ...
def create_barriers(barrier_count: int) -> PrimExpr: ...
def assume(cond: _T=None) -> _T: ...
def assume(cond: _T = None) -> _T: ...
def undef() -> _T: ...
def TVMBackendAllocWorkspace(device_type: int, device_id: int, nbytes: int, dtype_code_hint: int, dtype_bits_hint: int) -> PrimExpr: ...
def TVMBackendFreeWorkspace(device_type: int, device_id: int, ptr: Var) -> PrimExpr: ...
......
......@@ -724,8 +724,7 @@ def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
return _tvm_op.tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout)
def tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c,
index_c):
def tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c):
"""TVM intrinsic for tensor core mma_sync operators
Parameters
......@@ -759,12 +758,10 @@ def tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b,
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b,
fragment_c, index_c)
return _tvm_op.tvm_mma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)
def tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c,
index_c):
def tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c):
"""TVM intrinsic for tensor core bmma_sync operators
Parameters
......@@ -798,8 +795,7 @@ def tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b,
call : PrimExpr
The call expression.
"""
return _tvm_op.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b,
fragment_c, index_c)
return _tvm_op.tvm_bmma_sync(fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c)
def tvm_fill_fragment(fragment, m, n, k, index, value):
......@@ -1121,7 +1117,6 @@ def ptx_wgmma_rs(
scale_in_a,
scale_in_b,
):
return call_intrin(
dtype,
_tvm_op.Op.get("tl.ptx_wgmma_rs"),
......@@ -1345,8 +1340,7 @@ def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, sme
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr,
smem_offset)
return _tvm_op.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset)
def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
......@@ -1381,8 +1375,7 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, by
return _tvm_op.ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes)
def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes,
barrier_id):
def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id):
"""TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
......@@ -1414,8 +1407,7 @@ def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offse
call : PrimExpr
The call expression.
"""
return _tvm_op.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset,
bytes, barrier_id)
return _tvm_op.ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id)
def ptx_commit_group():
......@@ -2951,8 +2943,7 @@ def q_multiply_shift_per_axis(
z : PrimExpr
The result.
"""
return _tvm_op.q_multiply_shift_per_axis(x, y, ls, rs, q, is_lshift_required,
is_rshift_required)
return _tvm_op.q_multiply_shift_per_axis(x, y, ls, rs, q, is_lshift_required, is_rshift_required)
def shift_left(x, y, span=None):
......@@ -3302,8 +3293,7 @@ def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dt
call : PrimExpr
The call expression.
"""
return _tvm_op.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint,
dtype_bits_hint)
return _tvm_op.TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint)
def TVMBackendFreeWorkspace(device_type, device_id, ptr):
......
......@@ -14,23 +14,18 @@ def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: list
"""Convert a BufferLoad to a tl.region call with explicit extents."""
indices = list(load.indices)
if len(indices) > len(extents):
extents = [tir.IntImm("int32", 1) for _ in range(len(indices) - len(extents))
] + list(extents)
extents = [tir.IntImm("int32", 1) for _ in range(len(indices) - len(extents))] + list(extents)
assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
return region(load, access_type, *extents)
def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str,
extents: list[tir.PrimExpr]):
def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str, extents: list[tir.PrimExpr]):
"""Clamp extents and return a tl.region call."""
mins = [r.min for r in buffer_region.region]
region_extents = [r.extent for r in buffer_region.region]
assert len(region_extents) >= len(extents), (
f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
)
assert len(region_extents) >= len(extents), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
clamped_extents = [
tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i]
for i in range(len(region_extents))
tir.min(region_extents[i], extents[i]) if i < len(extents) else region_extents[i] for i in range(len(region_extents))
]
return region(tir.BufferLoad(buffer_region.buffer, mins), access_type, *clamped_extents)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment