"...git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "e9f137f05d22362fc6430c8aaef9440872a85e8a"
Unverified Commit f14fb111 authored by Yichen Yan's avatar Yichen Yan Committed by GitHub
Browse files

[Lint] Enable pyupgrade linter in ruff (#963)

* update rules

* ruff check

* other fixes

* fmt

* do not touch examples

* fmt
parent 4f3523dc
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier, evaluate from tilelang.language import ptx_arrive_barrier, evaluate
from tilelang.language.kernel import get_thread_bindings, get_block_extents from tilelang.language.kernel import get_thread_bindings, get_block_extents
from tilelang.utils.target import check_hip_availability from tilelang.utils.target import check_hip_availability
from tvm import tir from tvm import tir
from typing import Union, Any, Optional from typing import Any
from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad
_IS_HIP_AVAILABLE = check_hip_availability() _IS_HIP_AVAILABLE = check_hip_availability()
def _normalize_index_arg(value: Optional[Union[int, PrimExpr]]) -> Optional[PrimExpr]: def _normalize_index_arg(value: int | PrimExpr | None) -> PrimExpr | None:
""" """
Normalize warp sizing arguments so both Python ints and PrimExpr values Normalize warp sizing arguments so both Python ints and PrimExpr values
are accepted uniformly. are accepted uniformly.
...@@ -183,7 +184,7 @@ def disable_warp_group_reg_alloc(): ...@@ -183,7 +184,7 @@ def disable_warp_group_reg_alloc():
return no_set_max_nreg() return no_set_max_nreg()
def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union[int, Var]): def mbarrier_wait_parity(mbarrier: int | PrimExpr | tir.Call, parity: int | Var):
"""Wait for memory barrier parity condition. """Wait for memory barrier parity condition.
Args: Args:
...@@ -233,7 +234,7 @@ def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union ...@@ -233,7 +234,7 @@ def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union
return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), mbarrier, parity) return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), mbarrier, parity)
def mbarrier_arrive(mbarrier: Union[int, PrimExpr, tir.Call]): def mbarrier_arrive(mbarrier: int | PrimExpr | tir.Call):
"""Arrive at memory barrier. """Arrive at memory barrier.
Args: Args:
...@@ -294,7 +295,7 @@ def warpgroup_wait(num_mma: int): ...@@ -294,7 +295,7 @@ def warpgroup_wait(num_mma: int):
return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma) return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma)
def get_lane_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr: def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
"""Return the logical lane index of the calling thread within a warp. """Return the logical lane index of the calling thread within a warp.
Parameters Parameters
...@@ -319,7 +320,7 @@ def get_lane_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr: ...@@ -319,7 +320,7 @@ def get_lane_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr) return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr)
def get_warp_idx_sync(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr: def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
"""Return the canonical warp index, assuming the warp's threads are converged. """Return the canonical warp index, assuming the warp's threads are converged.
Parameters Parameters
...@@ -343,7 +344,7 @@ def get_warp_idx_sync(warp_size: Optional[Union[int, PrimExpr]] = None,) -> Prim ...@@ -343,7 +344,7 @@ def get_warp_idx_sync(warp_size: Optional[Union[int, PrimExpr]] = None,) -> Prim
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr) return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr)
def get_warp_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr: def get_warp_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
"""Return the canonical warp index without synchronizing the warp. """Return the canonical warp index without synchronizing the warp.
Parameters Parameters
...@@ -368,8 +369,8 @@ def get_warp_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr: ...@@ -368,8 +369,8 @@ def get_warp_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
def get_warp_group_idx( def get_warp_group_idx(
warp_size: Optional[Union[int, PrimExpr]] = None, warp_size: int | PrimExpr | None = None,
warps_per_group: Optional[Union[int, PrimExpr]] = None, warps_per_group: int | PrimExpr | None = None,
) -> PrimExpr: ) -> PrimExpr:
"""Return the canonical warp group index for the calling thread. """Return the canonical warp group index for the calling thread.
...@@ -441,7 +442,7 @@ def wait_wgmma(id: int): ...@@ -441,7 +442,7 @@ def wait_wgmma(id: int):
return tir.call_intrin("handle", tir.op.Op.get("tl.wait_wgmma"), id) return tir.call_intrin("handle", tir.op.Op.get("tl.wait_wgmma"), id)
def barrier_wait(barrier_id: Union[int, PrimExpr, tir.Call], parity: Union[int, Var, None] = None): def barrier_wait(barrier_id: int | PrimExpr | tir.Call, parity: int | Var | None = None):
"""Wait for a memory barrier to complete. """Wait for a memory barrier to complete.
Args: Args:
...@@ -456,7 +457,7 @@ def barrier_wait(barrier_id: Union[int, PrimExpr, tir.Call], parity: Union[int, ...@@ -456,7 +457,7 @@ def barrier_wait(barrier_id: Union[int, PrimExpr, tir.Call], parity: Union[int,
return mbarrier_wait_parity(barrier_id, parity) return mbarrier_wait_parity(barrier_id, parity)
def barrier_arrive(barrier_id: Union[int, PrimExpr, tir.Call]): def barrier_arrive(barrier_id: int | PrimExpr | tir.Call):
"""Arrive at a memory barrier. """Arrive at a memory barrier.
Args: Args:
...@@ -466,7 +467,7 @@ def barrier_arrive(barrier_id: Union[int, PrimExpr, tir.Call]): ...@@ -466,7 +467,7 @@ def barrier_arrive(barrier_id: Union[int, PrimExpr, tir.Call]):
return mbarrier_arrive(barrier_id) return mbarrier_arrive(barrier_id)
def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]): def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call):
"""Perform a shuffle operation with XOR offset. """Perform a shuffle operation with XOR offset.
Args: Args:
...@@ -483,7 +484,7 @@ def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, ...@@ -483,7 +484,7 @@ def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,
return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset) return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]): def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call):
"""Perform a shuffle operation with down offset. """Perform a shuffle operation with down offset.
Args: Args:
...@@ -496,7 +497,7 @@ def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr ...@@ -496,7 +497,7 @@ def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr
return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset) return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]): def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call):
"""Perform a shuffle operation with up offset. """Perform a shuffle operation with up offset.
Args: Args:
...@@ -601,7 +602,7 @@ def loop_break(): ...@@ -601,7 +602,7 @@ def loop_break():
return tir.call_intrin("handle", tir.op.Op.get("tl.loop_break")) return tir.call_intrin("handle", tir.op.Op.get("tl.loop_break"))
def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]): def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call):
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
""" """
return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
from typing import Union, Optional, Literal from typing import Literal
from tilelang import language as T from tilelang import language as T
from tilelang.utils.language import get_buffer_region_from_load from tilelang.utils.language import get_buffer_region_from_load
from tvm import ir, tir from tvm import ir, tir
from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region
def copy(src: Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion], def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
dst: Union[tir.Buffer, tir.BufferLoad], dst: tir.Buffer | tir.BufferLoad,
coalesced_width: Optional[int] = None, coalesced_width: int | None = None,
disable_tma: bool = False, disable_tma: bool = False,
eviction_policy: Optional[Literal["evict_normal", "evict_first", "evict_last"]] = None): eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None):
"""Copy data between memory regions. """Copy data between memory regions.
Args: Args:
...@@ -94,8 +95,7 @@ def c2d_im2col(img: tir.Buffer, ...@@ -94,8 +95,7 @@ def c2d_im2col(img: tir.Buffer,
stride: int, stride: int,
dilation: int, dilation: int,
pad: int, pad: int,
eviction_policy: Optional[Literal["evict_normal", "evict_first", eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None):
"evict_last"]] = None):
"""Perform im2col transformation for 2D convolution. """Perform im2col transformation for 2D convolution.
Args: Args:
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
import tilelang.language as T import tilelang.language as T
from tvm.tir import PrimExpr, Buffer, op from tvm.tir import PrimExpr, Buffer, op
from typing import List, Union
from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401 from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401
...@@ -36,7 +36,7 @@ def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr: ...@@ -36,7 +36,7 @@ def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr:
return dst return dst
def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer: def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer:
"""Reshapes the input buffer to the specified shape. """Reshapes the input buffer to the specified shape.
Args: Args:
...@@ -49,9 +49,7 @@ def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer: ...@@ -49,9 +49,7 @@ def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
return T.Tensor(shape, src.dtype, src.data) return T.Tensor(shape, src.dtype, src.data)
def view(src: Buffer, def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = None) -> Buffer:
shape: Union[List[PrimExpr], None] = None,
dtype: Union[str, None] = None) -> Buffer:
""" """
Return a Tensor view of the input buffer with an optional new shape and dtype. Return a Tensor view of the input buffer with an optional new shape and dtype.
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T import tilelang.language as T
from tvm import tir from tvm import tir
from typing import Union
def gemm_sp( def gemm_sp(
A_sparse: Union[tir.Buffer, tir.Var], A_sparse: tir.Buffer | tir.Var,
E: Union[tir.Buffer, tir.Var], E: tir.Buffer | tir.Var,
B: Union[tir.Buffer, tir.Var], B: tir.Buffer | tir.Var,
C: Union[tir.Buffer, tir.Var], C: tir.Buffer | tir.Var,
transpose_A: bool = False, transpose_A: bool = False,
transpose_B: bool = False, transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square, policy: GemmWarpPolicy = GemmWarpPolicy.Square,
...@@ -42,7 +42,7 @@ def gemm_sp( ...@@ -42,7 +42,7 @@ def gemm_sp(
AssertionError: If the K dimensions of matrices A and B don't match AssertionError: If the K dimensions of matrices A and B don't match
""" """
def legalize_arguments(arg: Union[tir.Buffer, tir.Var]): def legalize_arguments(arg: tir.Buffer | tir.Var):
"""Convert let-bound variables to their corresponding buffers. """Convert let-bound variables to their corresponding buffers.
Args: Args:
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
from tvm import tir from tvm import tir
from typing import Union
from tilelang.language import has_let_value, get_let_value from tilelang.language import has_let_value, get_let_value
from tilelang.utils.language import get_buffer_region_from_load from tilelang.utils.language import get_buffer_region_from_load
def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr): def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr):
"""Fill a buffer or buffer region with a specified value. """Fill a buffer or buffer region with a specified value.
Args: Args:
...@@ -21,7 +21,7 @@ def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr): ...@@ -21,7 +21,7 @@ def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr):
return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), buffer, value) return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), buffer, value)
def clear(buffer: Union[tir.Buffer, tir.Var]): def clear(buffer: tir.Buffer | tir.Var):
"""Clear a buffer by filling it with zeros. """Clear a buffer by filling it with zeros.
Args: Args:
......
"""Override the LetFrame to print a message when entering the frame.""" """Override the LetFrame to print a message when entering the frame."""
from __future__ import annotations
from tvm.ffi import register_object as _register_object from tvm.ffi import register_object as _register_object
from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion
...@@ -6,7 +7,6 @@ from tvm.ir import Range ...@@ -6,7 +7,6 @@ from tvm.ir import Range
from tvm import DataType from tvm import DataType
from tvm.script.ir_builder.tir.frame import TIRFrame from tvm.script.ir_builder.tir.frame import TIRFrame
from collections import deque from collections import deque
from typing import Optional
import threading import threading
...@@ -150,7 +150,7 @@ class LetFrame(TIRFrame): ...@@ -150,7 +150,7 @@ class LetFrame(TIRFrame):
super().__exit__(ptype, value, trace) super().__exit__(ptype, value, trace)
@classmethod @classmethod
def Current(cls) -> "LetFrame": def Current(cls) -> LetFrame:
"""Get the current (topmost) let frame. """Get the current (topmost) let frame.
Returns: Returns:
...@@ -198,7 +198,7 @@ def has_let_value(var: Var) -> bool: ...@@ -198,7 +198,7 @@ def has_let_value(var: Var) -> bool:
return _get_let_stack().has_value(var) return _get_let_stack().has_value(var)
def get_let_value(var: Var) -> Optional[PrimExpr]: def get_let_value(var: Var) -> PrimExpr | None:
"""Get the value bound to a variable in the current let frame stack. """Get the value bound to a variable in the current let frame stack.
Args: Args:
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T import tilelang.language as T
from tvm import tir from tvm import tir
from typing import Union, List, Optional
from tilelang.utils.language import get_buffer_region_from_load from tilelang.utils.language import get_buffer_region_from_load
def gemm( def gemm(
A: Union[tir.Buffer, tir.Var], A: tir.Buffer | tir.Var,
B: Union[tir.Buffer, tir.Var], B: tir.Buffer | tir.Var,
C: Union[tir.Buffer, tir.Var], C: tir.Buffer | tir.Var,
transpose_A: bool = False, transpose_A: bool = False,
transpose_B: bool = False, transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square, policy: GemmWarpPolicy = GemmWarpPolicy.Square,
clear_accum: bool = False, clear_accum: bool = False,
k_pack: int = 1, k_pack: int = 1,
wg_wait: int = 0, wg_wait: int = 0,
mbar: Optional[tir.Buffer] = None, mbar: tir.Buffer | None = None,
): ):
"""Perform a General Matrix Multiplication (GEMM) operation. """Perform a General Matrix Multiplication (GEMM) operation.
...@@ -45,7 +45,7 @@ def gemm( ...@@ -45,7 +45,7 @@ def gemm(
AssertionError: If the K dimensions of matrices A and B don't match AssertionError: If the K dimensions of matrices A and B don't match
""" """
def legalize_arguments(arg: Union[tir.Buffer, tir.Var]): def legalize_arguments(arg: tir.Buffer | tir.Var):
"""Convert let-bound variables to their corresponding buffers. """Convert let-bound variables to their corresponding buffers.
Args: Args:
...@@ -63,7 +63,7 @@ def gemm( ...@@ -63,7 +63,7 @@ def gemm(
C = legalize_arguments(C) C = legalize_arguments(C)
mbar = legalize_arguments(mbar) if mbar is not None else None mbar = legalize_arguments(mbar) if mbar is not None else None
def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: def retrieve_shape(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer): if isinstance(object, tir.Buffer):
return object.shape return object.shape
elif isinstance(object, tir.BufferRegion): elif isinstance(object, tir.BufferRegion):
...@@ -82,7 +82,7 @@ def gemm( ...@@ -82,7 +82,7 @@ def gemm(
raise ValueError( raise ValueError(
f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}") f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: def retrieve_stride(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer): if isinstance(object, tir.Buffer):
strides = [] strides = []
stride = 1 stride = 1
...@@ -137,8 +137,7 @@ def gemm( ...@@ -137,8 +137,7 @@ def gemm(
stride_a = A_stride[-2] stride_a = A_stride[-2]
stride_b = B_stride[-2] stride_b = B_stride[-2]
def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion], def retrieve_ptr(object: tir.Buffer | tir.BufferRegion, access_type: str = "r") -> tir.PrimExpr:
access_type: str = "r") -> tir.PrimExpr:
if isinstance(object, tir.Buffer): if isinstance(object, tir.Buffer):
return object.access_ptr(access_type) return object.access_ptr(access_type)
elif isinstance(object, tir.BufferRegion): elif isinstance(object, tir.BufferRegion):
...@@ -175,7 +174,7 @@ def gemm( ...@@ -175,7 +174,7 @@ def gemm(
raise ValueError( raise ValueError(
f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}") f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}")
def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr: def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> tir.PrimExpr:
"""Retrieve the offset of the buffer or buffer region.""" """Retrieve the offset of the buffer or buffer region."""
if isinstance(object, tir.Buffer): if isinstance(object, tir.Buffer):
return [0] * len(object.shape) return [0] * len(object.shape)
...@@ -214,9 +213,9 @@ def gemm( ...@@ -214,9 +213,9 @@ def gemm(
# experimental currently, for fast compilation # experimental currently, for fast compilation
def gemm_v2( def gemm_v2(
A: Union[tir.Buffer, tir.Var], A: tir.Buffer | tir.Var,
B: Union[tir.Buffer, tir.Var], B: tir.Buffer | tir.Var,
C: Union[tir.Buffer, tir.Var], C: tir.Buffer | tir.Var,
transpose_A: bool = False, transpose_A: bool = False,
transpose_B: bool = False, transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square, policy: GemmWarpPolicy = GemmWarpPolicy.Square,
...@@ -247,7 +246,7 @@ def gemm_v2( ...@@ -247,7 +246,7 @@ def gemm_v2(
AssertionError: If the K dimensions of matrices A and B don't match AssertionError: If the K dimensions of matrices A and B don't match
""" """
def legalize_arguments(arg: Union[tir.Buffer, tir.Var]): def legalize_arguments(arg: tir.Buffer | tir.Var):
"""Convert let-bound variables to their corresponding buffers. """Convert let-bound variables to their corresponding buffers.
Args: Args:
...@@ -264,7 +263,7 @@ def gemm_v2( ...@@ -264,7 +263,7 @@ def gemm_v2(
B = legalize_arguments(B) B = legalize_arguments(B)
C = legalize_arguments(C) C = legalize_arguments(C)
def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: def retrieve_shape(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer): if isinstance(object, tir.Buffer):
return object.shape return object.shape
elif isinstance(object, tir.BufferRegion): elif isinstance(object, tir.BufferRegion):
...@@ -283,7 +282,7 @@ def gemm_v2( ...@@ -283,7 +282,7 @@ def gemm_v2(
raise ValueError( raise ValueError(
f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}") f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: def retrieve_stride(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer): if isinstance(object, tir.Buffer):
strides = [] strides = []
stride = 1 stride = 1
...@@ -338,8 +337,7 @@ def gemm_v2( ...@@ -338,8 +337,7 @@ def gemm_v2(
stride_a = A_stride[-2] stride_a = A_stride[-2]
stride_b = B_stride[-2] stride_b = B_stride[-2]
def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion], def retrieve_ptr(object: tir.Buffer | tir.BufferRegion, access_type: str = "r") -> tir.PrimExpr:
access_type: str = "r") -> tir.PrimExpr:
if isinstance(object, tir.Buffer): if isinstance(object, tir.Buffer):
return object.access_ptr(access_type) return object.access_ptr(access_type)
elif isinstance(object, tir.BufferRegion): elif isinstance(object, tir.BufferRegion):
...@@ -376,7 +374,7 @@ def gemm_v2( ...@@ -376,7 +374,7 @@ def gemm_v2(
raise ValueError( raise ValueError(
f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}") f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}")
def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr: def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> tir.PrimExpr:
"""Retrieve the offset of the buffer or buffer region.""" """Retrieve the offset of the buffer or buffer region."""
if isinstance(object, tir.Buffer): if isinstance(object, tir.Buffer):
return [0] * len(object.shape) return [0] * len(object.shape)
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
from typing import Union, List, Tuple, Optional
from collections import deque from collections import deque
from tvm import tir from tvm import tir
from tvm.tir import Var from tvm.tir import Var
...@@ -80,7 +80,7 @@ def _get_current_stack() -> FrameStack: ...@@ -80,7 +80,7 @@ def _get_current_stack() -> FrameStack:
return _local.kernel_launch_frame_stack return _local.kernel_launch_frame_stack
def _normalize_bindings(bindings: List[Var]) -> Union[Var, List[Var]]: def _normalize_bindings(bindings: list[Var]) -> Var | list[Var]:
""" """
Return a bare Var when we only have a single binding so that users may write either Return a bare Var when we only have a single binding so that users may write either
`with T.Kernel(...) as pid:` or `with T.Kernel(...) as (pid,)`. `with T.Kernel(...) as pid:` or `with T.Kernel(...) as (pid,)`.
...@@ -98,7 +98,7 @@ class KernelLaunchFrame(TIRFrame): ...@@ -98,7 +98,7 @@ class KernelLaunchFrame(TIRFrame):
and handles the entry and exit of the kernel launch scope. and handles the entry and exit of the kernel launch scope.
""" """
def __enter__(self) -> Union[Var, List[Var]]: def __enter__(self) -> Var | list[Var]:
""" """
Enters the KernelLaunchFrame scope and pushes this frame onto the stack. Enters the KernelLaunchFrame scope and pushes this frame onto the stack.
Returns one Var if we detect exactly 5 frames (meaning there is a single Returns one Var if we detect exactly 5 frames (meaning there is a single
...@@ -132,7 +132,7 @@ class KernelLaunchFrame(TIRFrame): ...@@ -132,7 +132,7 @@ class KernelLaunchFrame(TIRFrame):
super().__exit__(ptype, value, trace) super().__exit__(ptype, value, trace)
@classmethod @classmethod
def Current(cls) -> Optional["KernelLaunchFrame"]: def Current(cls) -> KernelLaunchFrame | None:
""" """
Returns the topmost (current) KernelLaunchFrame from the stack if it exists, Returns the topmost (current) KernelLaunchFrame from the stack if it exists,
or None if the stack is empty. or None if the stack is empty.
...@@ -148,7 +148,7 @@ class KernelLaunchFrame(TIRFrame): ...@@ -148,7 +148,7 @@ class KernelLaunchFrame(TIRFrame):
iter_var = self.frames[dim].iter_var iter_var = self.frames[dim].iter_var
return int(iter_var.dom.extent) return int(iter_var.dom.extent)
def get_block_extents(self) -> List[int]: def get_block_extents(self) -> list[int]:
""" """
Returns the block extents for all three dimensions. Returns the block extents for all three dimensions.
""" """
...@@ -162,7 +162,7 @@ class KernelLaunchFrame(TIRFrame): ...@@ -162,7 +162,7 @@ class KernelLaunchFrame(TIRFrame):
iter_var = self.frames[-4 + dim].iter_var iter_var = self.frames[-4 + dim].iter_var
return int(iter_var.dom.extent) return int(iter_var.dom.extent)
def get_thread_extents(self) -> List[int]: def get_thread_extents(self) -> list[int]:
""" """
Returns the thread extents for all three dimensions. Returns the thread extents for all three dimensions.
""" """
...@@ -175,7 +175,7 @@ class KernelLaunchFrame(TIRFrame): ...@@ -175,7 +175,7 @@ class KernelLaunchFrame(TIRFrame):
""" """
return self.frames[-4 + dim].iter_var.var return self.frames[-4 + dim].iter_var.var
def get_thread_bindings(self) -> List[Var]: def get_thread_bindings(self) -> list[Var]:
""" """
Returns the thread binding for the given dimension. Returns the thread binding for the given dimension.
dim=0 corresponds to threadIdx.x, dim=1 to threadIdx.y, and dim=2 to threadIdx.z. dim=0 corresponds to threadIdx.x, dim=1 to threadIdx.y, and dim=2 to threadIdx.z.
...@@ -198,21 +198,21 @@ class KernelLaunchFrame(TIRFrame): ...@@ -198,21 +198,21 @@ class KernelLaunchFrame(TIRFrame):
""" """
return self.frames[dim].iter_var.var return self.frames[dim].iter_var.var
def get_block_bindings(self) -> List[Var]: def get_block_bindings(self) -> list[Var]:
""" """
Returns all three block bindings. Returns all three block bindings.
""" """
return [frame.iter_var.var for frame in self.frames[0:-4]] return [frame.iter_var.var for frame in self.frames[0:-4]]
@property @property
def blocks(self) -> List[Var]: def blocks(self) -> list[Var]:
""" """
Returns the block indices from the topmost frame. Returns the block indices from the topmost frame.
""" """
return [frame.iter_var.var for frame in self.frames[0:-4]] return [frame.iter_var.var for frame in self.frames[0:-4]]
@property @property
def threads(self) -> List[Var]: def threads(self) -> list[Var]:
""" """
Returns the thread indices from the topmost frame. Returns the thread indices from the topmost frame.
""" """
...@@ -227,10 +227,10 @@ class KernelLaunchFrame(TIRFrame): ...@@ -227,10 +227,10 @@ class KernelLaunchFrame(TIRFrame):
def Kernel( def Kernel(
*blocks: List[tir.PrimExpr], *blocks: list[tir.PrimExpr],
threads: Optional[Union[int, List[int], Tuple]] = None, threads: int | list[int] | tuple | None = None,
is_cpu: bool = False, is_cpu: bool = False,
prelude: Optional[str] = None, prelude: str | None = None,
): ):
"""Tools to quickly construct a GPU kernel launch frame. """Tools to quickly construct a GPU kernel launch frame.
...@@ -310,7 +310,7 @@ def get_thread_binding(dim: int = 0) -> Var: ...@@ -310,7 +310,7 @@ def get_thread_binding(dim: int = 0) -> Var:
return KernelLaunchFrame.Current().get_thread_binding(dim) return KernelLaunchFrame.Current().get_thread_binding(dim)
def get_thread_bindings() -> List[Var]: def get_thread_bindings() -> list[Var]:
"""Returns all three thread bindings. """Returns all three thread bindings.
""" """
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
...@@ -324,7 +324,7 @@ def get_block_binding(dim: int = 0) -> Var: ...@@ -324,7 +324,7 @@ def get_block_binding(dim: int = 0) -> Var:
return KernelLaunchFrame.Current().get_block_binding(dim) return KernelLaunchFrame.Current().get_block_binding(dim)
def get_block_bindings() -> List[Var]: def get_block_bindings() -> list[Var]:
"""Returns all three block bindings. """Returns all three block bindings.
""" """
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
...@@ -338,7 +338,7 @@ def get_thread_extent(dim: int = 0) -> int: ...@@ -338,7 +338,7 @@ def get_thread_extent(dim: int = 0) -> int:
return KernelLaunchFrame.Current().get_thread_extent(dim) return KernelLaunchFrame.Current().get_thread_extent(dim)
def get_thread_extents() -> List[int]: def get_thread_extents() -> list[int]:
"""Returns all three thread extents. """Returns all three thread extents.
""" """
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
...@@ -352,7 +352,7 @@ def get_block_extent(dim: int = 0) -> int: ...@@ -352,7 +352,7 @@ def get_block_extent(dim: int = 0) -> int:
return KernelLaunchFrame.Current().get_block_extent(dim) return KernelLaunchFrame.Current().get_block_extent(dim)
def get_block_extents() -> List[int]: def get_block_extents() -> list[int]:
"""Returns all three block extents. """Returns all three block extents.
""" """
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
from tilelang import language as T from tilelang import language as T
from tvm.tir import Buffer, BufferRegion, BufferLoad from tvm.tir import Buffer, BufferRegion, BufferLoad
from tvm import tir from tvm import tir
from typing import Union
from tilelang.utils.language import get_buffer_elems from tilelang.utils.language import get_buffer_elems
def any_of(buffer: Union[T.Tensor, BufferRegion]): def any_of(buffer: T.Tensor | BufferRegion):
"""Check if any element in the buffer is true. """Check if any element in the buffer is true.
Args: Args:
...@@ -42,7 +42,7 @@ def any_of(buffer: Union[T.Tensor, BufferRegion]): ...@@ -42,7 +42,7 @@ def any_of(buffer: Union[T.Tensor, BufferRegion]):
raise ValueError(f"Invalid buffer type: {type(buffer)}") raise ValueError(f"Invalid buffer type: {type(buffer)}")
def all_of(buffer: Union[T.Tensor, BufferRegion]): def all_of(buffer: T.Tensor | BufferRegion):
"""Check if all elements in the buffer are true. """Check if all elements in the buffer are true.
Args: Args:
......
"""TVMScript parser overrides tailored for TileLang.""" """TVMScript parser overrides tailored for TileLang."""
from __future__ import annotations
from functools import partial from functools import partial
from typing import Tuple
from tvm.script.ir_builder import tir as T from tvm.script.ir_builder import tir as T
from tvm.script.parser._core import dispatch, doc from tvm.script.parser._core import dispatch, doc
...@@ -10,7 +10,7 @@ from tvm.tir import BufferLoad, Var ...@@ -10,7 +10,7 @@ from tvm.tir import BufferLoad, Var
from tvm.script.parser.tir import parser as tvm_tir_parser from tvm.script.parser.tir import parser as tvm_tir_parser
def _get_node_span(node: doc.AST) -> Tuple[int, int, int, int]: def _get_node_span(node: doc.AST) -> tuple[int, int, int, int]:
"""Return the span (lineno, col, end_lineno, end_col) for a doc node.""" """Return the span (lineno, col, end_lineno, end_col) for a doc node."""
return (node.lineno, node.col_offset, node.end_lineno, node.end_col_offset) return (node.lineno, node.col_offset, node.end_lineno, node.end_col_offset)
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
from typing import Optional, Dict, Any from typing import Any
from tvm import tir from tvm import tir
from tilelang import _ffi_api from tilelang import _ffi_api
def Parallel(*extents: tir.PrimExpr, coalesced_width: Optional[int] = None): def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
"""Tools to construct nested parallel for loop. """Tools to construct nested parallel for loop.
This can be used to create element-wise tensor expression. This can be used to create element-wise tensor expression.
...@@ -22,7 +23,7 @@ def Parallel(*extents: tir.PrimExpr, coalesced_width: Optional[int] = None): ...@@ -22,7 +23,7 @@ def Parallel(*extents: tir.PrimExpr, coalesced_width: Optional[int] = None):
res : frame.ForFrame res : frame.ForFrame
The ForFrame. The ForFrame.
""" """
annotations: Dict[str, Any] = {} annotations: dict[str, Any] = {}
if coalesced_width is not None: if coalesced_width is not None:
annotations.update({"coalesced_width": coalesced_width}) annotations.update({"coalesced_width": coalesced_width})
return _ffi_api.Parallel(extents, annotations) # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.Parallel(extents, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
# This file is modified from the original version, # This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/). # which is part of the TVM project (https://tvm.apache.org/).
"""The tir expression operation registration""" """The tir expression operation registration"""
from __future__ import annotations
from typing import Type
from tvm import tir from tvm import tir
from tvm.ffi.runtime_ctypes import DataType, DataTypeCode from tvm.ffi.runtime_ctypes import DataType, DataTypeCode
...@@ -28,7 +27,7 @@ from tvm.tir.expr import FloatImm ...@@ -28,7 +27,7 @@ from tvm.tir.expr import FloatImm
from tvm.script.parser._core import OpMethod, doc, register_op from tvm.script.parser._core import OpMethod, doc, register_op
def _register_expr_op(ty: Type): # pylint: disable=invalid-name def _register_expr_op(ty: type): # pylint: disable=invalid-name
ty._dispatch_type = ty # pylint: disable=protected-access ty._dispatch_type = ty # pylint: disable=protected-access
def _and(a, b): def _and(a, b):
...@@ -115,7 +114,7 @@ def _register_expr_op(ty: Type): # pylint: disable=invalid-name ...@@ -115,7 +114,7 @@ def _register_expr_op(ty: Type): # pylint: disable=invalid-name
def _ge(a, b): def _ge(a, b):
return _auto_broadcast(a, b, tir.GE) return _auto_broadcast(a, b, tir.GE)
def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name def r(op: type, i: int, m: OpMethod): # pylint: disable=invalid-name
register_op(ty, op, i)(m) register_op(ty, op, i)(m)
for i in [0, 1]: for i in [0, 1]:
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
from typing import List, Optional
from tvm import tir from tvm import tir
from tilelang import _ffi_api from tilelang import _ffi_api
def Persistent( def Persistent(
domain: List[tir.PrimExpr], domain: list[tir.PrimExpr],
wave_size: tir.PrimExpr, wave_size: tir.PrimExpr,
index: tir.PrimExpr, index: tir.PrimExpr,
group_size: Optional[tir.PrimExpr] = 8, group_size: tir.PrimExpr | None = 8,
): ):
"""Tools to construct persistent for loop. """Tools to construct persistent for loop.
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
from typing import List, Optional
from tvm import tir from tvm import tir
from tvm.tir import IntImm from tvm.tir import IntImm
from tilelang import _ffi_api from tilelang import _ffi_api
...@@ -10,10 +10,10 @@ def Pipelined( ...@@ -10,10 +10,10 @@ def Pipelined(
start: tir.PrimExpr, start: tir.PrimExpr,
stop: tir.PrimExpr = None, stop: tir.PrimExpr = None,
num_stages: int = 0, num_stages: int = 0,
order: Optional[List[int]] = None, order: list[int] | None = None,
stage: Optional[List[int]] = None, stage: list[int] | None = None,
sync: Optional[List[List[int]]] = None, sync: list[list[int]] | None = None,
group: Optional[List[List[int]]] = None, group: list[list[int]] | None = None,
): ):
"""Tools to construct pipelined for loop. """Tools to construct pipelined for loop.
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations from __future__ import annotations
from typing import Any, Optional, Sequence, SupportsIndex, TYPE_CHECKING, Tuple, Union from typing import Any, Sequence, SupportsIndex, TYPE_CHECKING
from typing_extensions import Self from typing_extensions import Self
from tvm import tir from tvm import tir
...@@ -143,7 +143,7 @@ class TensorProxy(BaseTensorProxy): ...@@ -143,7 +143,7 @@ class TensorProxy(BaseTensorProxy):
""" """
@staticmethod @staticmethod
def _construct_strides(shape: Tuple[Any]): def _construct_strides(shape: tuple[Any]):
s, strides = 1, [1] s, strides = 1, [1]
for dim in shape[:0:-1]: for dim in shape[:0:-1]:
s *= dim s *= dim
...@@ -151,7 +151,7 @@ class TensorProxy(BaseTensorProxy): ...@@ -151,7 +151,7 @@ class TensorProxy(BaseTensorProxy):
return tuple(reversed(strides)) return tuple(reversed(strides))
def __call__(self, def __call__(self,
shape: Union[Tuple[Any], PrimExpr, int], shape: tuple[Any] | PrimExpr | int,
dtype: str = "float32", dtype: str = "float32",
data=None, data=None,
scope=None) -> tir.Buffer: scope=None) -> tir.Buffer:
...@@ -172,8 +172,8 @@ class StridedTensorProxy(BaseTensorProxy): ...@@ -172,8 +172,8 @@ class StridedTensorProxy(BaseTensorProxy):
""" """
def __call__(self, def __call__(self,
shape: Tuple[Any], shape: tuple[Any],
strides: Tuple[Any], strides: tuple[Any],
dtype: str = "float32", dtype: str = "float32",
scope=None) -> tir.Buffer: scope=None) -> tir.Buffer:
if len(shape) != len(strides): if len(shape) != len(strides):
...@@ -270,7 +270,7 @@ else: ...@@ -270,7 +270,7 @@ else:
LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name
def ptr(dtype: Optional[str] = None, def ptr(dtype: str | None = None,
storage_scope: str = "global", storage_scope: str = "global",
*, *,
is_size_var: bool = False) -> Var: is_size_var: bool = False) -> Var:
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
from tvm import tir from tvm import tir
from typing import Optional
from tilelang.language import copy, macro, alloc_shared from tilelang.language import copy, macro, alloc_shared
...@@ -199,7 +199,7 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) - ...@@ -199,7 +199,7 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -
copy(cumsum_smem, dst) copy(cumsum_smem, dst)
def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reverse: bool = False): def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse: bool = False):
""" """
Compute the cumulative sum of `src` along `dim`, writing results to `dst`. Compute the cumulative sum of `src` along `dim`, writing results to `dst`.
......
from __future__ import annotations
import inspect import inspect
from typing import Callable, Optional, Union from typing import Callable
import tvm.script.parser.tir.entry as _tir_entry import tvm.script.parser.tir.entry as _tir_entry
from tvm.tir.function import PrimFunc from tvm.tir.function import PrimFunc
from tvm.script.parser._core import parse, scan_macro, utils from tvm.script.parser._core import parse, scan_macro, utils
def prim_func(func: Optional[Callable] = None, def prim_func(func: Callable | None = None,
private: bool = False, private: bool = False,
check_well_formed: bool = False) -> Union[PrimFunc, Callable]: check_well_formed: bool = False) -> PrimFunc | Callable:
"""The parsing method for tir prim func, by using `@prim_func` as decorator. """The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters Parameters
......
from __future__ import annotations
import tvm.script.ir_builder.tir.ir as _ir import tvm.script.ir_builder.tir.ir as _ir
from tvm.script.ir_builder.tir import frame from tvm.script.ir_builder.tir import frame
from tvm.tir import PrimExpr from tvm.tir import PrimExpr
from typing import Any, Dict from typing import Any
import tilelang.language.tir.op as _tir_op import tilelang.language.tir.op as _tir_op
import functools import functools
...@@ -9,7 +10,7 @@ import functools ...@@ -9,7 +10,7 @@ import functools
def serial(start: PrimExpr, def serial(start: PrimExpr,
stop: PrimExpr = None, stop: PrimExpr = None,
*, *,
annotations: Dict[str, Any] = None) -> frame.ForFrame: annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The serial For statement. """The serial For statement.
Parameters Parameters
...@@ -34,7 +35,7 @@ def serial(start: PrimExpr, ...@@ -34,7 +35,7 @@ def serial(start: PrimExpr,
def parallel(start: PrimExpr, def parallel(start: PrimExpr,
stop: PrimExpr = None, stop: PrimExpr = None,
*, *,
annotations: Dict[str, Any] = None) -> frame.ForFrame: annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The parallel For statement. """The parallel For statement.
Parameters Parameters
...@@ -59,7 +60,7 @@ def parallel(start: PrimExpr, ...@@ -59,7 +60,7 @@ def parallel(start: PrimExpr,
def vectorized(start: PrimExpr, def vectorized(start: PrimExpr,
stop: PrimExpr = None, stop: PrimExpr = None,
*, *,
annotations: Dict[str, Any] = None) -> frame.ForFrame: annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The vectorized For statement. """The vectorized For statement.
Parameters Parameters
...@@ -84,7 +85,7 @@ def vectorized(start: PrimExpr, ...@@ -84,7 +85,7 @@ def vectorized(start: PrimExpr,
def unroll(start: PrimExpr, def unroll(start: PrimExpr,
stop: PrimExpr = None, stop: PrimExpr = None,
*, *,
annotations: Dict[str, Any] = None) -> frame.ForFrame: annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The unrolled For statement. """The unrolled For statement.
Parameters Parameters
...@@ -111,7 +112,7 @@ def thread_binding( ...@@ -111,7 +112,7 @@ def thread_binding(
stop: PrimExpr = None, stop: PrimExpr = None,
thread: str = None, thread: str = None,
*, *,
annotations: Dict[str, Any] = None, annotations: dict[str, Any] = None,
) -> frame.ForFrame: ) -> frame.ForFrame:
"""The thread-binding For statement. """The thread-binding For statement.
......
from typing import Any, Optional from __future__ import annotations
from typing import Any
import tvm import tvm
from tvm.ir import PrimExpr from tvm.ir import PrimExpr
from tvm.ir.base import Span from tvm.ir.base import Span
...@@ -1857,7 +1858,7 @@ def min_value(dtype, span=None): ...@@ -1857,7 +1858,7 @@ def min_value(dtype, span=None):
return _tvm_op.min_value(dtype, span) return _tvm_op.min_value(dtype, span)
def max_value(dtype: str, span: Optional[Span] = None) -> Any: def max_value(dtype: str, span: Span | None = None) -> Any:
"""maximum value of dtype """maximum value of dtype
Parameters Parameters
...@@ -1876,7 +1877,7 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any: ...@@ -1876,7 +1877,7 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any:
return _tvm_op.max_value(dtype, span) return _tvm_op.max_value(dtype, span)
def infinity(dtype: str, span: Optional[Span] = None) -> Any: def infinity(dtype: str, span: Span | None = None) -> Any:
"""infinity value of dtype """infinity value of dtype
Parameters Parameters
...@@ -1895,7 +1896,7 @@ def infinity(dtype: str, span: Optional[Span] = None) -> Any: ...@@ -1895,7 +1896,7 @@ def infinity(dtype: str, span: Optional[Span] = None) -> Any:
return _tvm_op.infinity(dtype, span) return _tvm_op.infinity(dtype, span)
def reinterpret(dtype, value, span: Optional[Span] = None) -> Any: def reinterpret(dtype, value, span: Span | None = None) -> Any:
"""infinity value of dtype """infinity value of dtype
Parameters Parameters
......
from __future__ import annotations
from tilelang import tvm as tvm from tilelang import tvm as tvm
from typing import List
from tvm import tir from tvm import tir
from tvm.tir import PrimExpr, Buffer, BufferLoad, op from tvm.tir import PrimExpr, Buffer, BufferLoad, op
from tilelang import language as T from tilelang import language as T
...@@ -42,7 +42,7 @@ def buffer_to_tile_region(buffer: Buffer, access_type: str): ...@@ -42,7 +42,7 @@ def buffer_to_tile_region(buffer: Buffer, access_type: str):
return region(T.BufferLoad(buffer, mins), access_type, *extents) return region(T.BufferLoad(buffer, mins), access_type, *extents)
def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List[PrimExpr]): def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: list[PrimExpr]):
"""Convert a buffer load operation to a tile region descriptor. """Convert a buffer load operation to a tile region descriptor.
Args: Args:
...@@ -69,7 +69,7 @@ def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List ...@@ -69,7 +69,7 @@ def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List
def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str, def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str,
extents: List[tir.PrimExpr]): extents: list[tir.PrimExpr]):
"""Convert a buffer region to a tile region descriptor. """Convert a buffer region to a tile region descriptor.
Args: Args:
...@@ -88,7 +88,7 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s ...@@ -88,7 +88,7 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents) return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
def index_to_coordinates(index, shape) -> List[PrimExpr]: def index_to_coordinates(index, shape) -> list[PrimExpr]:
""" """
Convert a flat (linear) index into multi-dimensional coordinates for a given shape. Convert a flat (linear) index into multi-dimensional coordinates for a given shape.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment