Commit f14fb111 (unverified), authored by Yichen Yan, committed by GitHub
Browse files

[Lint] Enable pyupgrade linter in ruff (#963)

* update rules

* ruff check

* other fixes

* fmt

* do not touch examples

* fmt
parent 4f3523dc
"""The language interface for tl programs."""
from __future__ import annotations
from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier, evaluate
from tilelang.language.kernel import get_thread_bindings, get_block_extents
from tilelang.utils.target import check_hip_availability
from tvm import tir
from typing import Union, Any, Optional
from typing import Any
from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad
_IS_HIP_AVAILABLE = check_hip_availability()
def _normalize_index_arg(value: Optional[Union[int, PrimExpr]]) -> Optional[PrimExpr]:
def _normalize_index_arg(value: int | PrimExpr | None) -> PrimExpr | None:
"""
Normalize warp sizing arguments so both Python ints and PrimExpr values
are accepted uniformly.
......@@ -183,7 +184,7 @@ def disable_warp_group_reg_alloc():
return no_set_max_nreg()
def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union[int, Var]):
def mbarrier_wait_parity(mbarrier: int | PrimExpr | tir.Call, parity: int | Var):
"""Wait for memory barrier parity condition.
Args:
......@@ -233,7 +234,7 @@ def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union
return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), mbarrier, parity)
def mbarrier_arrive(mbarrier: Union[int, PrimExpr, tir.Call]):
def mbarrier_arrive(mbarrier: int | PrimExpr | tir.Call):
"""Arrive at memory barrier.
Args:
......@@ -294,7 +295,7 @@ def warpgroup_wait(num_mma: int):
return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma)
def get_lane_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
"""Return the logical lane index of the calling thread within a warp.
Parameters
......@@ -319,7 +320,7 @@ def get_lane_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr)
def get_warp_idx_sync(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
"""Return the canonical warp index, assuming the warp's threads are converged.
Parameters
......@@ -343,7 +344,7 @@ def get_warp_idx_sync(warp_size: Optional[Union[int, PrimExpr]] = None,) -> Prim
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr)
def get_warp_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
def get_warp_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
"""Return the canonical warp index without synchronizing the warp.
Parameters
......@@ -368,8 +369,8 @@ def get_warp_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
def get_warp_group_idx(
warp_size: Optional[Union[int, PrimExpr]] = None,
warps_per_group: Optional[Union[int, PrimExpr]] = None,
warp_size: int | PrimExpr | None = None,
warps_per_group: int | PrimExpr | None = None,
) -> PrimExpr:
"""Return the canonical warp group index for the calling thread.
......@@ -441,7 +442,7 @@ def wait_wgmma(id: int):
return tir.call_intrin("handle", tir.op.Op.get("tl.wait_wgmma"), id)
def barrier_wait(barrier_id: Union[int, PrimExpr, tir.Call], parity: Union[int, Var, None] = None):
def barrier_wait(barrier_id: int | PrimExpr | tir.Call, parity: int | Var | None = None):
"""Wait for a memory barrier to complete.
Args:
......@@ -456,7 +457,7 @@ def barrier_wait(barrier_id: Union[int, PrimExpr, tir.Call], parity: Union[int,
return mbarrier_wait_parity(barrier_id, parity)
def barrier_arrive(barrier_id: Union[int, PrimExpr, tir.Call]):
def barrier_arrive(barrier_id: int | PrimExpr | tir.Call):
"""Arrive at a memory barrier.
Args:
......@@ -466,7 +467,7 @@ def barrier_arrive(barrier_id: Union[int, PrimExpr, tir.Call]):
return mbarrier_arrive(barrier_id)
def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call):
"""Perform a shuffle operation with XOR offset.
Args:
......@@ -483,7 +484,7 @@ def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,
return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call):
"""Perform a shuffle operation with down offset.
Args:
......@@ -496,7 +497,7 @@ def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr
return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call):
"""Perform a shuffle operation with up offset.
Args:
......@@ -601,7 +602,7 @@ def loop_break():
return tir.call_intrin("handle", tir.op.Op.get("tl.loop_break"))
def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]):
def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call):
    """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.

    Args:
        barrier_id: The barrier to operate on; forwarded unchanged as the
            only argument of the intrinsic.

    Returns:
        The ``tir.call_intrin`` call of ``tl.ptx_cp_async_barrier_noinc``
        built with result dtype ``"handle"``.
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Union, Optional, Literal
from typing import Literal
from tilelang import language as T
from tilelang.utils.language import get_buffer_region_from_load
from tvm import ir, tir
from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region
def copy(src: Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion],
dst: Union[tir.Buffer, tir.BufferLoad],
coalesced_width: Optional[int] = None,
def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
dst: tir.Buffer | tir.BufferLoad,
coalesced_width: int | None = None,
disable_tma: bool = False,
eviction_policy: Optional[Literal["evict_normal", "evict_first", "evict_last"]] = None):
eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None):
"""Copy data between memory regions.
Args:
......@@ -94,8 +95,7 @@ def c2d_im2col(img: tir.Buffer,
stride: int,
dilation: int,
pad: int,
eviction_policy: Optional[Literal["evict_normal", "evict_first",
"evict_last"]] = None):
eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None):
"""Perform im2col transformation for 2D convolution.
Args:
......
"""The language interface for tl programs."""
from __future__ import annotations
import tilelang.language as T
from tvm.tir import PrimExpr, Buffer, op
from typing import List, Union
from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401
......@@ -36,7 +36,7 @@ def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr:
return dst
def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer:
"""Reshapes the input buffer to the specified shape.
Args:
......@@ -49,9 +49,7 @@ def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
return T.Tensor(shape, src.dtype, src.data)
def view(src: Buffer,
shape: Union[List[PrimExpr], None] = None,
dtype: Union[str, None] = None) -> Buffer:
def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = None) -> Buffer:
"""
Return a Tensor view of the input buffer with an optional new shape and dtype.
......
"""The language interface for tl programs."""
from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
from typing import Union
def gemm_sp(
A_sparse: Union[tir.Buffer, tir.Var],
E: Union[tir.Buffer, tir.Var],
B: Union[tir.Buffer, tir.Var],
C: Union[tir.Buffer, tir.Var],
A_sparse: tir.Buffer | tir.Var,
E: tir.Buffer | tir.Var,
B: tir.Buffer | tir.Var,
C: tir.Buffer | tir.Var,
transpose_A: bool = False,
transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
......@@ -42,7 +42,7 @@ def gemm_sp(
AssertionError: If the K dimensions of matrices A and B don't match
"""
def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
def legalize_arguments(arg: tir.Buffer | tir.Var):
"""Convert let-bound variables to their corresponding buffers.
Args:
......
"""The language interface for tl programs."""
from __future__ import annotations
from tvm import tir
from typing import Union
from tilelang.language import has_let_value, get_let_value
from tilelang.utils.language import get_buffer_region_from_load
def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr):
def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr):
"""Fill a buffer or buffer region with a specified value.
Args:
......@@ -21,7 +21,7 @@ def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr):
return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), buffer, value)
def clear(buffer: Union[tir.Buffer, tir.Var]):
def clear(buffer: tir.Buffer | tir.Var):
"""Clear a buffer by filling it with zeros.
Args:
......
"""Override the LetFrame to print a message when entering the frame."""
from __future__ import annotations
from tvm.ffi import register_object as _register_object
from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion
......@@ -6,7 +7,6 @@ from tvm.ir import Range
from tvm import DataType
from tvm.script.ir_builder.tir.frame import TIRFrame
from collections import deque
from typing import Optional
import threading
......@@ -150,7 +150,7 @@ class LetFrame(TIRFrame):
super().__exit__(ptype, value, trace)
@classmethod
def Current(cls) -> "LetFrame":
def Current(cls) -> LetFrame:
"""Get the current (topmost) let frame.
Returns:
......@@ -198,7 +198,7 @@ def has_let_value(var: Var) -> bool:
return _get_let_stack().has_value(var)
def get_let_value(var: Var) -> Optional[PrimExpr]:
def get_let_value(var: Var) -> PrimExpr | None:
"""Get the value bound to a variable in the current let frame stack.
Args:
......
"""The language interface for tl programs."""
from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
from typing import Union, List, Optional
from tilelang.utils.language import get_buffer_region_from_load
def gemm(
A: Union[tir.Buffer, tir.Var],
B: Union[tir.Buffer, tir.Var],
C: Union[tir.Buffer, tir.Var],
A: tir.Buffer | tir.Var,
B: tir.Buffer | tir.Var,
C: tir.Buffer | tir.Var,
transpose_A: bool = False,
transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
clear_accum: bool = False,
k_pack: int = 1,
wg_wait: int = 0,
mbar: Optional[tir.Buffer] = None,
mbar: tir.Buffer | None = None,
):
"""Perform a General Matrix Multiplication (GEMM) operation.
......@@ -45,7 +45,7 @@ def gemm(
AssertionError: If the K dimensions of matrices A and B don't match
"""
def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
def legalize_arguments(arg: tir.Buffer | tir.Var):
"""Convert let-bound variables to their corresponding buffers.
Args:
......@@ -63,7 +63,7 @@ def gemm(
C = legalize_arguments(C)
mbar = legalize_arguments(mbar) if mbar is not None else None
def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
def retrieve_shape(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer):
return object.shape
elif isinstance(object, tir.BufferRegion):
......@@ -82,7 +82,7 @@ def gemm(
raise ValueError(
f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
def retrieve_stride(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer):
strides = []
stride = 1
......@@ -137,8 +137,7 @@ def gemm(
stride_a = A_stride[-2]
stride_b = B_stride[-2]
def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
access_type: str = "r") -> tir.PrimExpr:
def retrieve_ptr(object: tir.Buffer | tir.BufferRegion, access_type: str = "r") -> tir.PrimExpr:
if isinstance(object, tir.Buffer):
return object.access_ptr(access_type)
elif isinstance(object, tir.BufferRegion):
......@@ -175,7 +174,7 @@ def gemm(
raise ValueError(
f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}")
def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr:
def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> tir.PrimExpr:
"""Retrieve the offset of the buffer or buffer region."""
if isinstance(object, tir.Buffer):
return [0] * len(object.shape)
......@@ -214,9 +213,9 @@ def gemm(
# experimental currently, for fast compilation
def gemm_v2(
A: Union[tir.Buffer, tir.Var],
B: Union[tir.Buffer, tir.Var],
C: Union[tir.Buffer, tir.Var],
A: tir.Buffer | tir.Var,
B: tir.Buffer | tir.Var,
C: tir.Buffer | tir.Var,
transpose_A: bool = False,
transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
......@@ -247,7 +246,7 @@ def gemm_v2(
AssertionError: If the K dimensions of matrices A and B don't match
"""
def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
def legalize_arguments(arg: tir.Buffer | tir.Var):
"""Convert let-bound variables to their corresponding buffers.
Args:
......@@ -264,7 +263,7 @@ def gemm_v2(
B = legalize_arguments(B)
C = legalize_arguments(C)
def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
def retrieve_shape(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer):
return object.shape
elif isinstance(object, tir.BufferRegion):
......@@ -283,7 +282,7 @@ def gemm_v2(
raise ValueError(
f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
def retrieve_stride(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer):
strides = []
stride = 1
......@@ -338,8 +337,7 @@ def gemm_v2(
stride_a = A_stride[-2]
stride_b = B_stride[-2]
def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
access_type: str = "r") -> tir.PrimExpr:
def retrieve_ptr(object: tir.Buffer | tir.BufferRegion, access_type: str = "r") -> tir.PrimExpr:
if isinstance(object, tir.Buffer):
return object.access_ptr(access_type)
elif isinstance(object, tir.BufferRegion):
......@@ -376,7 +374,7 @@ def gemm_v2(
raise ValueError(
f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}")
def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr:
def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> tir.PrimExpr:
"""Retrieve the offset of the buffer or buffer region."""
if isinstance(object, tir.Buffer):
return [0] * len(object.shape)
......
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Union, List, Tuple, Optional
from collections import deque
from tvm import tir
from tvm.tir import Var
......@@ -80,7 +80,7 @@ def _get_current_stack() -> FrameStack:
return _local.kernel_launch_frame_stack
def _normalize_bindings(bindings: List[Var]) -> Union[Var, List[Var]]:
def _normalize_bindings(bindings: list[Var]) -> Var | list[Var]:
"""
Return a bare Var when we only have a single binding so that users may write either
`with T.Kernel(...) as pid:` or `with T.Kernel(...) as (pid,)`.
......@@ -98,7 +98,7 @@ class KernelLaunchFrame(TIRFrame):
and handles the entry and exit of the kernel launch scope.
"""
def __enter__(self) -> Union[Var, List[Var]]:
def __enter__(self) -> Var | list[Var]:
"""
Enters the KernelLaunchFrame scope and pushes this frame onto the stack.
Returns one Var if we detect exactly 5 frames (meaning there is a single
......@@ -132,7 +132,7 @@ class KernelLaunchFrame(TIRFrame):
super().__exit__(ptype, value, trace)
@classmethod
def Current(cls) -> Optional["KernelLaunchFrame"]:
def Current(cls) -> KernelLaunchFrame | None:
"""
Returns the topmost (current) KernelLaunchFrame from the stack if it exists,
or None if the stack is empty.
......@@ -148,7 +148,7 @@ class KernelLaunchFrame(TIRFrame):
iter_var = self.frames[dim].iter_var
return int(iter_var.dom.extent)
def get_block_extents(self) -> List[int]:
def get_block_extents(self) -> list[int]:
"""
Returns the block extents for all three dimensions.
"""
......@@ -162,7 +162,7 @@ class KernelLaunchFrame(TIRFrame):
iter_var = self.frames[-4 + dim].iter_var
return int(iter_var.dom.extent)
def get_thread_extents(self) -> List[int]:
def get_thread_extents(self) -> list[int]:
"""
Returns the thread extents for all three dimensions.
"""
......@@ -175,7 +175,7 @@ class KernelLaunchFrame(TIRFrame):
"""
return self.frames[-4 + dim].iter_var.var
def get_thread_bindings(self) -> List[Var]:
def get_thread_bindings(self) -> list[Var]:
"""
Returns the thread binding for the given dimension.
dim=0 corresponds to threadIdx.x, dim=1 to threadIdx.y, and dim=2 to threadIdx.z.
......@@ -198,21 +198,21 @@ class KernelLaunchFrame(TIRFrame):
"""
return self.frames[dim].iter_var.var
def get_block_bindings(self) -> List[Var]:
def get_block_bindings(self) -> list[Var]:
    """Return the block binding variables held by this frame.

    Collects ``iter_var.var`` from every entry of ``self.frames`` except
    the last four, so the result has one element per leading frame.
    NOTE(review): the previous wording said "all three", which only holds
    when ``self.frames`` has exactly seven entries — confirm against the
    class.
    """
    return [frame.iter_var.var for frame in self.frames[0:-4]]
@property
def blocks(self) -> List[Var]:
def blocks(self) -> list[Var]:
"""
Returns the block indices from the topmost frame.
"""
return [frame.iter_var.var for frame in self.frames[0:-4]]
@property
def threads(self) -> List[Var]:
def threads(self) -> list[Var]:
"""
Returns the thread indices from the topmost frame.
"""
......@@ -227,10 +227,10 @@ class KernelLaunchFrame(TIRFrame):
def Kernel(
*blocks: List[tir.PrimExpr],
threads: Optional[Union[int, List[int], Tuple]] = None,
*blocks: list[tir.PrimExpr],
threads: int | list[int] | tuple | None = None,
is_cpu: bool = False,
prelude: Optional[str] = None,
prelude: str | None = None,
):
"""Tools to quickly construct a GPU kernel launch frame.
......@@ -310,7 +310,7 @@ def get_thread_binding(dim: int = 0) -> Var:
return KernelLaunchFrame.Current().get_thread_binding(dim)
def get_thread_bindings() -> List[Var]:
def get_thread_bindings() -> list[Var]:
"""Returns all three thread bindings.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
......@@ -324,7 +324,7 @@ def get_block_binding(dim: int = 0) -> Var:
return KernelLaunchFrame.Current().get_block_binding(dim)
def get_block_bindings() -> List[Var]:
def get_block_bindings() -> list[Var]:
"""Returns all three block bindings.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
......@@ -338,7 +338,7 @@ def get_thread_extent(dim: int = 0) -> int:
return KernelLaunchFrame.Current().get_thread_extent(dim)
def get_thread_extents() -> List[int]:
def get_thread_extents() -> list[int]:
"""Returns all three thread extents.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
......@@ -352,7 +352,7 @@ def get_block_extent(dim: int = 0) -> int:
return KernelLaunchFrame.Current().get_block_extent(dim)
def get_block_extents() -> List[int]:
def get_block_extents() -> list[int]:
"""Returns all three block extents.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
......
"""The language interface for tl programs."""
from __future__ import annotations
from tilelang import language as T
from tvm.tir import Buffer, BufferRegion, BufferLoad
from tvm import tir
from typing import Union
from tilelang.utils.language import get_buffer_elems
def any_of(buffer: Union[T.Tensor, BufferRegion]):
def any_of(buffer: T.Tensor | BufferRegion):
"""Check if any element in the buffer is true.
Args:
......@@ -42,7 +42,7 @@ def any_of(buffer: Union[T.Tensor, BufferRegion]):
raise ValueError(f"Invalid buffer type: {type(buffer)}")
def all_of(buffer: Union[T.Tensor, BufferRegion]):
def all_of(buffer: T.Tensor | BufferRegion):
"""Check if all elements in the buffer are true.
Args:
......
"""TVMScript parser overrides tailored for TileLang."""
from __future__ import annotations
from functools import partial
from typing import Tuple
from tvm.script.ir_builder import tir as T
from tvm.script.parser._core import dispatch, doc
......@@ -10,7 +10,7 @@ from tvm.tir import BufferLoad, Var
from tvm.script.parser.tir import parser as tvm_tir_parser
def _get_node_span(node: doc.AST) -> Tuple[int, int, int, int]:
def _get_node_span(node: doc.AST) -> tuple[int, int, int, int]:
"""Return the span (lineno, col, end_lineno, end_col) for a doc node."""
return (node.lineno, node.col_offset, node.end_lineno, node.end_col_offset)
......
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Optional, Dict, Any
from typing import Any
from tvm import tir
from tilelang import _ffi_api
def Parallel(*extents: tir.PrimExpr, coalesced_width: Optional[int] = None):
def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
"""Tools to construct nested parallel for loop.
This can be used to create element-wise tensor expression.
......@@ -22,7 +23,7 @@ def Parallel(*extents: tir.PrimExpr, coalesced_width: Optional[int] = None):
res : frame.ForFrame
The ForFrame.
"""
annotations: Dict[str, Any] = {}
annotations: dict[str, Any] = {}
if coalesced_width is not None:
annotations.update({"coalesced_width": coalesced_width})
return _ffi_api.Parallel(extents, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
......@@ -17,8 +17,7 @@
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""The tir expression operation registration"""
from typing import Type
from __future__ import annotations
from tvm import tir
from tvm.ffi.runtime_ctypes import DataType, DataTypeCode
......@@ -28,7 +27,7 @@ from tvm.tir.expr import FloatImm
from tvm.script.parser._core import OpMethod, doc, register_op
def _register_expr_op(ty: Type): # pylint: disable=invalid-name
def _register_expr_op(ty: type): # pylint: disable=invalid-name
ty._dispatch_type = ty # pylint: disable=protected-access
def _and(a, b):
......@@ -115,7 +114,7 @@ def _register_expr_op(ty: Type): # pylint: disable=invalid-name
def _ge(a, b):
return _auto_broadcast(a, b, tir.GE)
def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name
def r(op: type, i: int, m: OpMethod): # pylint: disable=invalid-name
register_op(ty, op, i)(m)
for i in [0, 1]:
......
"""The language interface for tl programs."""
from __future__ import annotations
from typing import List, Optional
from tvm import tir
from tilelang import _ffi_api
def Persistent(
domain: List[tir.PrimExpr],
domain: list[tir.PrimExpr],
wave_size: tir.PrimExpr,
index: tir.PrimExpr,
group_size: Optional[tir.PrimExpr] = 8,
group_size: tir.PrimExpr | None = 8,
):
"""Tools to construct persistent for loop.
......
"""The language interface for tl programs."""
from __future__ import annotations
from typing import List, Optional
from tvm import tir
from tvm.tir import IntImm
from tilelang import _ffi_api
......@@ -10,10 +10,10 @@ def Pipelined(
start: tir.PrimExpr,
stop: tir.PrimExpr = None,
num_stages: int = 0,
order: Optional[List[int]] = None,
stage: Optional[List[int]] = None,
sync: Optional[List[List[int]]] = None,
group: Optional[List[List[int]]] = None,
order: list[int] | None = None,
stage: list[int] | None = None,
sync: list[list[int]] | None = None,
group: list[list[int]] | None = None,
):
"""Tools to construct pipelined for loop.
......
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Any, Optional, Sequence, SupportsIndex, TYPE_CHECKING, Tuple, Union
from typing import Any, Sequence, SupportsIndex, TYPE_CHECKING
from typing_extensions import Self
from tvm import tir
......@@ -143,7 +143,7 @@ class TensorProxy(BaseTensorProxy):
"""
@staticmethod
def _construct_strides(shape: Tuple[Any]):
def _construct_strides(shape: tuple[Any]):
s, strides = 1, [1]
for dim in shape[:0:-1]:
s *= dim
......@@ -151,7 +151,7 @@ class TensorProxy(BaseTensorProxy):
return tuple(reversed(strides))
def __call__(self,
shape: Union[Tuple[Any], PrimExpr, int],
shape: tuple[Any] | PrimExpr | int,
dtype: str = "float32",
data=None,
scope=None) -> tir.Buffer:
......@@ -172,8 +172,8 @@ class StridedTensorProxy(BaseTensorProxy):
"""
def __call__(self,
shape: Tuple[Any],
strides: Tuple[Any],
shape: tuple[Any],
strides: tuple[Any],
dtype: str = "float32",
scope=None) -> tir.Buffer:
if len(shape) != len(strides):
......@@ -270,7 +270,7 @@ else:
LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name
def ptr(dtype: Optional[str] = None,
def ptr(dtype: str | None = None,
storage_scope: str = "global",
*,
is_size_var: bool = False) -> Var:
......
"""The language interface for tl programs."""
from __future__ import annotations
from tvm import tir
from typing import Optional
from tilelang.language import copy, macro, alloc_shared
......@@ -199,7 +199,7 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -
copy(cumsum_smem, dst)
def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reverse: bool = False):
def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse: bool = False):
"""
Compute the cumulative sum of `src` along `dim`, writing results to `dst`.
......
from __future__ import annotations
import inspect
from typing import Callable, Optional, Union
from typing import Callable
import tvm.script.parser.tir.entry as _tir_entry
from tvm.tir.function import PrimFunc
from tvm.script.parser._core import parse, scan_macro, utils
def prim_func(func: Optional[Callable] = None,
def prim_func(func: Callable | None = None,
private: bool = False,
check_well_formed: bool = False) -> Union[PrimFunc, Callable]:
check_well_formed: bool = False) -> PrimFunc | Callable:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters
......
from __future__ import annotations
import tvm.script.ir_builder.tir.ir as _ir
from tvm.script.ir_builder.tir import frame
from tvm.tir import PrimExpr
from typing import Any, Dict
from typing import Any
import tilelang.language.tir.op as _tir_op
import functools
......@@ -9,7 +10,7 @@ import functools
def serial(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The serial For statement.
Parameters
......@@ -34,7 +35,7 @@ def serial(start: PrimExpr,
def parallel(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The parallel For statement.
Parameters
......@@ -59,7 +60,7 @@ def parallel(start: PrimExpr,
def vectorized(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The vectorized For statement.
Parameters
......@@ -84,7 +85,7 @@ def vectorized(start: PrimExpr,
def unroll(start: PrimExpr,
stop: PrimExpr = None,
*,
annotations: Dict[str, Any] = None) -> frame.ForFrame:
annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The unrolled For statement.
Parameters
......@@ -111,7 +112,7 @@ def thread_binding(
stop: PrimExpr = None,
thread: str = None,
*,
annotations: Dict[str, Any] = None,
annotations: dict[str, Any] = None,
) -> frame.ForFrame:
"""The thread-binding For statement.
......
from typing import Any, Optional
from __future__ import annotations
from typing import Any
import tvm
from tvm.ir import PrimExpr
from tvm.ir.base import Span
......@@ -1857,7 +1858,7 @@ def min_value(dtype, span=None):
return _tvm_op.min_value(dtype, span)
def max_value(dtype: str, span: Optional[Span] = None) -> Any:
def max_value(dtype: str, span: Span | None = None) -> Any:
"""maximum value of dtype
Parameters
......@@ -1876,7 +1877,7 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any:
return _tvm_op.max_value(dtype, span)
def infinity(dtype: str, span: Optional[Span] = None) -> Any:
def infinity(dtype: str, span: Span | None = None) -> Any:
"""infinity value of dtype
Parameters
......@@ -1895,7 +1896,7 @@ def infinity(dtype: str, span: Optional[Span] = None) -> Any:
return _tvm_op.infinity(dtype, span)
def reinterpret(dtype, value, span: Optional[Span] = None) -> Any:
def reinterpret(dtype, value, span: Span | None = None) -> Any:
"""infinity value of dtype
Parameters
......
from __future__ import annotations
from tilelang import tvm as tvm
from typing import List
from tvm import tir
from tvm.tir import PrimExpr, Buffer, BufferLoad, op
from tilelang import language as T
......@@ -42,7 +42,7 @@ def buffer_to_tile_region(buffer: Buffer, access_type: str):
return region(T.BufferLoad(buffer, mins), access_type, *extents)
def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List[PrimExpr]):
def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: list[PrimExpr]):
"""Convert a buffer load operation to a tile region descriptor.
Args:
......@@ -69,7 +69,7 @@ def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List
def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str,
extents: List[tir.PrimExpr]):
extents: list[tir.PrimExpr]):
"""Convert a buffer region to a tile region descriptor.
Args:
......@@ -88,7 +88,7 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
def index_to_coordinates(index, shape) -> List[PrimExpr]:
def index_to_coordinates(index, shape) -> list[PrimExpr]:
"""
Convert a flat (linear) index into multi-dimensional coordinates for a given shape.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment