Unverified commit f14fb111, authored by Yichen Yan, committed by GitHub

[Lint] Enable pyupgrade linter in ruff (#963)

* update rules

* ruff check

* other fixes

* fmt

* do not touch examples

* fmt
parent 4f3523dc
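
The diff that follows is a mechanical application of ruff's pyupgrade (`UP`) rules: `Optional[X]`/`Union[X, Y]` become PEP 604 unions (`X | None`, `X | Y`), `typing.List`/`Dict`/`Tuple` become the PEP 585 builtins, redundant `(object)` base classes are dropped, `str.format` calls become f-strings, and `from __future__ import annotations` is added so the new annotation syntax also parses on older interpreters. A minimal before/after sketch of the pattern on a hypothetical class (illustrative only, not code from this commit):

from __future__ import annotations  # postpones annotation evaluation, so `X | None` works on Python 3.8/3.9

# Before the pyupgrade-style rewrites the same class would read:
#   from typing import Dict, List, Optional
#   class KernelCache(object):
#       def lookup(self, key: Optional[str] = None) -> Optional[List[int]]:
#           raise KeyError("missing key: {}".format(key))

class KernelCache:  # useless `(object)` base removed (UP004)
    _entries: dict[str, list[int]] = {}  # PEP 585 builtin generics (UP006)

    def lookup(self, key: str | None = None) -> list[int] | None:  # PEP 604 unions (UP007)
        if key is None:
            return None
        if key not in self._entries:
            raise KeyError(f"missing key: {key}")  # str.format -> f-string (UP032)
        return self._entries[key]
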
+from __future__ import annotations
from tvm import tir, IRModule
from tvm.target import Target
import tilelang
from tilelang.transform import PassContext
from tilelang.contrib.nvcc import have_tma, is_hopper
-from typing import Optional
-def allow_warp_specialized(pass_ctx: Optional[PassContext] = None,
-target: Optional[Target] = None) -> bool:
+def allow_warp_specialized(pass_ctx: PassContext | None = None,
+target: Target | None = None) -> bool:
# avoid circular import
from tilelang.jit.adapter.utils import is_cuda_target
@@ -19,8 +19,8 @@ def allow_warp_specialized(pass_ctx: Optional[PassContext] = None,
return not disable_warp_specialized
-def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
-target: Optional[Target] = None) -> bool:
+def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None,
+target: Target | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
if not have_tma(target):
@@ -29,26 +29,26 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
return not disable_tma_lower and allow_warp_specialized(pass_ctx=pass_ctx, target=target)
-def allow_fence_proxy(target: Optional[Target] = None) -> bool:
+def allow_fence_proxy(target: Target | None = None) -> bool:
return have_tma(target)
-def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
+def allow_vectorize(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
disable_vectorize = pass_ctx.config.get("tir.disable_vectorize", False)
return not disable_vectorize
-def allow_global_thread_synchronization(pass_ctx: Optional[PassContext] = None) -> bool:
+def allow_global_thread_synchronization(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
enable_global_thread_sync = pass_ctx.config.get("tir.detect_global_barrier", False)
return enable_global_thread_sync
-def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None,
-target: Optional[Target] = None) -> bool:
+def should_enable_aggressive_merge(pass_ctx: PassContext | None = None,
+target: Target | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
enable_aggressive_merge = bool(
@@ -61,7 +61,7 @@ def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None,
return enable_aggressive_merge
-def should_force_let_inline(pass_ctx: Optional[PassContext] = None) -> bool:
+def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False))
...
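
A recurring change across these files is the added `from __future__ import annotations`. A short, self-contained sketch (illustrative only, not taken from the commit) of why it is needed for the `X | None` spellings on interpreters older than 3.10:

from __future__ import annotations  # PEP 563: annotations are stored as strings, never evaluated at definition time

class Config:  # hypothetical stand-in for classes like PassContext or Target
    pass

# Without the future import, the next definition raises on Python 3.8/3.9:
#   TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'
def lookup(cfg: Config | None = None) -> bool:
    return cfg is not None

print(lookup())                # False
print(lookup.__annotations__)  # {'cfg': 'Config | None', 'return': 'bool'}
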
+from __future__ import annotations
import sys
import os
import pathlib
@@ -5,7 +6,6 @@ import logging
import shutil
import glob
from dataclasses import dataclass
-from typing import Optional
logger = logging.getLogger(__name__)
@@ -170,7 +170,7 @@ class EnvVar:
key: str # Environment variable name (e.g. "TILELANG_PRINT_ON_COMPILATION")
default: str # Default value if the environment variable is not set
-_forced_value: Optional[str] = None # Temporary runtime override (mainly for tests/debugging)
+_forced_value: str | None = None # Temporary runtime override (mainly for tests/debugging)
def get(self):
if self._forced_value is not None:
...
+from __future__ import annotations
from tilelang import tvm as tvm
import tilelang.language as T
-from typing import Tuple
from tvm import DataType
from tvm.tir import PrimExpr
from tvm.runtime import convert
-from typing import Optional
from .utils import (
mfma_store_index_map,)
lift = convert
-class MatrixCoreIntrinEmitter(object):
+class MatrixCoreIntrinEmitter:
"""
To eliminate Python syntax within TIR Macro.
"""
@@ -51,9 +50,9 @@ class MatrixCoreIntrinEmitter(object):
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
-k_pack: Optional[int] = None,
-is_m_first: Optional[bool] = False,
-b_preshuffle: Optional[bool] = False,
+k_pack: int | None = None,
+is_m_first: bool | None = False,
+b_preshuffle: bool | None = False,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
@@ -135,15 +134,15 @@ class MatrixCoreIntrinEmitter(object):
self.micro_size_y = n_dim
self.micro_size_k = k_dim
-def _initialize_k_pack(self, k_pack: Optional[int] = None):
+def _initialize_k_pack(self, k_pack: int | None = None):
if k_pack is not None:
self.k_pack = k_pack
-def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
+def _initialize_is_m_first(self, is_m_first: bool | None = False):
if is_m_first is not None:
self.is_m_first = is_m_first
-def _initialize_b_preshuffle(self, b_preshuffle: Optional[bool] = False):
+def _initialize_b_preshuffle(self, b_preshuffle: bool | None = False):
if b_preshuffle is not None:
self.b_preshuffle = b_preshuffle
@@ -203,7 +202,7 @@ class MatrixCoreIntrinEmitter(object):
def extract_thread_binding(self,
thread_id,
-is_m_first=None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]:
+is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
'''
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
@@ -418,10 +417,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
-k_pack: Optional[int] = None,
-is_m_first: Optional[bool] = False,
-a_preshuffle: Optional[bool] = False,
-b_preshuffle: Optional[bool] = False,
+k_pack: int | None = None,
+is_m_first: bool | None = False,
+a_preshuffle: bool | None = False,
+b_preshuffle: bool | None = False,
):
self.a_dtype = a_dtype
...
-from typing import Union
+from __future__ import annotations
from tvm import arith, DataType
import tilelang.language as T
@@ -163,7 +163,7 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j):
return (i * 2 + j // 16, j % 16)
-def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str], swizzle_bytes=None):
+def get_swizzle_layout(row_idx, col_idx, row_size, dtype: DataType | str, swizzle_bytes=None):
ana = arith.Analyzer()
if isinstance(dtype, str):
dtype = DataType(dtype)
...
+from __future__ import annotations
import tilelang.language as T
-from typing import Union, Tuple, Optional, Literal, Callable
+from typing import Literal, Callable
from tilelang.common import TransformKind
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var
@@ -25,7 +26,7 @@ from tilelang.intrinsics.mma_layout import (
lift = convert
-class TensorCoreIntrinEmitter(object):
+class TensorCoreIntrinEmitter:
"""
To eliminate Python syntax within TIR Macro.
"""
@@ -62,8 +63,8 @@ class TensorCoreIntrinEmitter(object):
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
-is_m_first: Optional[bool] = False,
-thread_var: Optional[Var] = None,
+is_m_first: bool | None = False,
+thread_var: Var | None = None,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
@@ -144,7 +145,7 @@ class TensorCoreIntrinEmitter(object):
self.micro_size_x = m_dim
self.micro_size_k = k_dim
-def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
+def _initialize_is_m_first(self, is_m_first: bool | None = False):
if is_m_first is not None:
self.is_m_first = is_m_first
@@ -167,7 +168,7 @@ class TensorCoreIntrinEmitter(object):
def extract_thread_binding(
self,
thread_id: PrimExpr,
-is_m_first: Optional[bool] = None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]:
+is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
@@ -200,7 +201,7 @@ class TensorCoreIntrinEmitter(object):
A_local_buf: Buffer,
A_shared_buf: Buffer,
ki: PrimExpr,
-rk: Optional[PrimExpr] = 0):
+rk: PrimExpr | None = 0):
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
@@ -264,7 +265,7 @@ class TensorCoreIntrinEmitter(object):
B_local_buf: Buffer,
B_shared_buf: Buffer,
ki: PrimExpr,
-rk: Optional[PrimExpr] = 0):
+rk: PrimExpr | None = 0):
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
@@ -336,7 +337,7 @@ class TensorCoreIntrinEmitter(object):
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
-k_inner: Optional[PrimExpr] = 0):
+k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
@@ -518,8 +519,7 @@ class TensorCoreIntrinEmitter(object):
else:
raise ValueError(f"Unsupported matrix {matrix}")
-assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format(
-local_buf.scope())
+assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
if matrix_is_a:
micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
@@ -684,9 +684,9 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
-is_m_first: Optional[bool] = False,
-transform_kind_a: Union[int, TransformKind] = 0,
-transform_kind_b: Union[int, TransformKind] = 0,
+is_m_first: bool | None = False,
+transform_kind_a: int | TransformKind = 0,
+transform_kind_b: int | TransformKind = 0,
):
super().__init__(
a_dtype=a_dtype,
...
+from __future__ import annotations
import tilelang.language as T
from enum import IntEnum
-from typing import Optional, Callable
+from typing import Callable
from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter
from tvm import DataType
from tvm.tir import PrimExpr, Buffer, Var, IndexMap
@@ -86,8 +87,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
-is_m_first: Optional[bool] = False,
-thread_var: Optional[Var] = None,
+is_m_first: bool | None = False,
+thread_var: Var | None = None,
):
super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps,
block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k,
@@ -409,8 +410,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
-assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format(
-local_buf.scope())
+assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
...
@@ -3,17 +3,13 @@ This module provides an auto-tuning infrastructure for TileLang (tl) programs.
It includes functionality to JIT-compile TileLang programs into a runnable
kernel adapter using TVM.
"""
+from __future__ import annotations
from typing import (
Any,
-List,
-Union,
Callable,
-Tuple,
overload,
Literal,
-Dict, # For type hinting dicts
-Optional,
)
from tilelang import tvm as tvm
from tilelang.jit.adapter.utils import is_metal_target
@@ -33,13 +29,13 @@ logger = getLogger(__name__)
def compile(
func: PrimFunc = None,
-out_idx: Union[List[int], int, None] = None,
+out_idx: list[int] | int | None = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
-target: Union[str, Target] = "auto",
-target_host: Union[str, Target, None] = None,
+target: str | Target = "auto",
+target_host: str | Target | None = None,
verbose: bool = False,
-pass_configs: Optional[Dict[str, Any]] = None,
-compile_flags: Optional[Union[List[str], str]] = None,
+pass_configs: dict[str, Any] | None = None,
+compile_flags: list[str] | str | None = None,
) -> JITKernel:
"""
Compile the given TileLang PrimFunc with TVM and build a JITKernel.
@@ -85,24 +81,24 @@ def compile(
class _JitImplementation:
-out_idx: Optional[Union[List[int], int]]
-target: Union[str, Target]
-target_host: Union[str, Target]
+out_idx: list[int] | int | None
+target: str | Target
+target_host: str | Target
execution_backend: Literal["dlpack", "ctypes", "cython"]
verbose: bool
-pass_configs: Optional[Dict[str, Any]]
-debug_root_path: Optional[str]
-compile_flags: Optional[Union[List[str], str]]
+pass_configs: dict[str, Any] | None
+debug_root_path: str | None
+compile_flags: list[str] | str | None
def __init__(self,
out_idx: Any = None,
-target: Union[str, Target] = "auto",
-target_host: Union[str, Target] = None,
+target: str | Target = "auto",
+target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
-pass_configs: Optional[Dict[str, Any]] = None,
-debug_root_path: Optional[str] = None,
-compile_flags: Optional[Union[List[str], str]] = None):
+pass_configs: dict[str, Any] | None = None,
+debug_root_path: str | None = None,
+compile_flags: list[str] | str | None = None):
"""
Initializes the JIT compiler decorator.
@@ -155,12 +151,12 @@ class _JitImplementation:
except NameError:
self.debug_root_path = path.abspath(self.debug_root_path)
-self._kernel_cache: Dict[tuple, Kernel] = {}
+self._kernel_cache: dict[tuple, Kernel] = {}
# This tells the type checker what the *wrapper* function will return.
# this is for linting, please do not remove it.
@overload
-def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, Kernel]]:
+def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, Kernel]]:
...
@overload
@@ -235,16 +231,16 @@ class _JitImplementation:
def jit( # This is the new public interface
-func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
+func: Callable[_P, _RProg] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only
out_idx: Any = None,
-target: Union[str, Target] = "auto",
-target_host: Union[str, Target] = None,
+target: str | Target = "auto",
+target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False,
-pass_configs: Optional[Dict[str, Any]] = None,
-debug_root_path: Optional[str] = None,
-compile_flags: Optional[Union[List[str], str]] = None):
+pass_configs: dict[str, Any] | None = None,
+debug_root_path: str | None = None,
+compile_flags: list[str] | str | None = None):
"""
Just-In-Time (JIT) compiler decorator for TileLang functions.
...
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, List, Callable, Optional from typing import Any, Callable
from tilelang.engine.param import KernelParam from tilelang.engine.param import KernelParam
class BaseKernelAdapter(ABC): class BaseKernelAdapter(ABC):
func: Optional[Callable] = None func: Callable | None = None
def __init__(self, mod, params: List[KernelParam], result_idx: List[int]) -> None: def __init__(self, mod, params: list[KernelParam], result_idx: list[int]) -> None:
self.mod = mod self.mod = mod
self.params = params self.params = params
self.result_idx = self._legalize_result_idx(result_idx) self.result_idx = self._legalize_result_idx(result_idx)
self._post_init() self._post_init()
def _legalize_result_idx(self, result_idx: Optional[List[int]]) -> List[int]: def _legalize_result_idx(self, result_idx: list[int] | None) -> list[int]:
params = self.params params = self.params
# result_idx is a list of indices of the output tensors # result_idx is a list of indices of the output tensors
if result_idx is None: if result_idx is None:
......
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from __future__ import annotations
import torch import torch
from ..base import BaseKernelAdapter from ..base import BaseKernelAdapter
import ctypes import ctypes
from typing import List, Optional, Union, Callable, Dict, Tuple, Any from typing import Callable, Any
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm.target import Target from tvm.target import Target
from tvm.relax import TensorType from tvm.relax import TensorType
...@@ -25,32 +26,32 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -25,32 +26,32 @@ class CtypesKernelAdapter(BaseKernelAdapter):
# Class attributes to store compiled kernel information # Class attributes to store compiled kernel information
target = "cuda" target = "cuda"
ir_module: Optional[tvm.IRModule] = None ir_module: tvm.IRModule | None = None
# The global source code of the kernel -> global means the source code of the kernel # The global source code of the kernel -> global means the source code of the kernel
# that is not wrapped by the wrapper code # that is not wrapped by the wrapper code
kernel_global_source: Optional[str] = None kernel_global_source: str | None = None
lib: Optional[ctypes.CDLL] = None # Compiled library handle lib: ctypes.CDLL | None = None # Compiled library handle
wrapped_source: Optional[str] = None # Generated C++ wrapper code wrapped_source: str | None = None # Generated C++ wrapper code
# Maps symbolic variables to their corresponding buffer and shape indices # Maps symbolic variables to their corresponding buffer and shape indices
dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None
# Pass configs for the compiler # Pass configs for the compiler
pass_configs: Optional[Dict[str, Any]] = None pass_configs: dict[str, Any] | None = None
# Add new cache attributes # Add new cache attributes
param_dtypes: Optional[List[torch.dtype]] = None # Cache for parameter dtypes param_dtypes: list[torch.dtype] | None = None # Cache for parameter dtypes
param_shapes: Optional[List[List]] = None # Cache for parameter shapes param_shapes: list[list] | None = None # Cache for parameter shapes
def __init__(self, def __init__(self,
params: List[TensorType], params: list[TensorType],
result_idx: List[int], result_idx: list[int],
target: str, target: str,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule], func_or_mod: tir.PrimFunc | tvm.IRModule,
host_mod: Optional[tvm.IRModule] = None, host_mod: tvm.IRModule | None = None,
device_mod: Optional[tvm.IRModule] = None, device_mod: tvm.IRModule | None = None,
kernel_global_source: Optional[str] = None, kernel_global_source: str | None = None,
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: dict[str, Any] | None = None,
compile_flags: Optional[List[str]] = None): compile_flags: list[str] | None = None):
"""Initialize the adapter with the given TIR function or module. """Initialize the adapter with the given TIR function or module.
Args: Args:
...@@ -107,15 +108,15 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -107,15 +108,15 @@ class CtypesKernelAdapter(BaseKernelAdapter):
@classmethod @classmethod
def from_database(cls, def from_database(cls,
params: List[TensorType], params: list[TensorType],
result_idx: List[int], result_idx: list[int],
target: str, target: str,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule], func_or_mod: tir.PrimFunc | tvm.IRModule,
kernel_global_source: str, kernel_global_source: str,
kernel_lib_path: str, kernel_lib_path: str,
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: dict[str, Any] | None = None,
compile_flags: Optional[List[str]] = None): compile_flags: list[str] | None = None):
adapter = cls.__new__(cls) adapter = cls.__new__(cls)
adapter.params = params adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx) adapter.result_idx = adapter._legalize_result_idx(result_idx)
...@@ -155,7 +156,7 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -155,7 +156,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
adapter._post_init() adapter._post_init()
return adapter return adapter
def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]: def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int, int]]:
"""Extract information about dynamic shapes from the TIR function. """Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (id, buffer_index, dimension) Maps symbolic variables to their corresponding (id, buffer_index, dimension)
...@@ -182,7 +183,7 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -182,7 +183,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
dynamic_symbolic_map[stride] = (1, i, j) dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map return dynamic_symbolic_map
def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None): def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
"""Low-level function to call the compiled CUDA kernel. """Low-level function to call the compiled CUDA kernel.
Converts PyTorch tensor pointers to C void pointers for ctypes interface. Converts PyTorch tensor pointers to C void pointers for ctypes interface.
...@@ -193,9 +194,7 @@ class CtypesKernelAdapter(BaseKernelAdapter): ...@@ -193,9 +194,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
ctypes_args.append(ctypes.c_void_p(stream)) ctypes_args.append(ctypes.c_void_p(stream))
self.lib.call(*ctypes_args) self.lib.call(*ctypes_args)
def _wrap_forward_from_prebuild_lib(self, def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None):
*ins: List[torch.Tensor],
stream: Optional[int] = None):
"""High-level wrapper for kernel execution. """High-level wrapper for kernel execution.
Handles: Handles:
......
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from __future__ import annotations
import ctypes import ctypes
import logging import logging
import torch import torch
from typing import List, Optional, Union, Callable, Dict, Tuple, Any from typing import Callable, Any
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm.target import Target from tvm.target import Target
from tilelang.engine.param import KernelParam from tilelang.engine.param import KernelParam
...@@ -44,43 +45,43 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -44,43 +45,43 @@ class CythonKernelAdapter(BaseKernelAdapter):
""" """
# Class attributes to store compiled kernel information # Class attributes to store compiled kernel information
target: Union[str, Target] = "cuda" target: str | Target = "cuda"
ir_module: Optional[tvm.IRModule] = None ir_module: tvm.IRModule | None = None
# The global source code of the kernel -> global means the source code of the kernel # The global source code of the kernel -> global means the source code of the kernel
# that is not wrapped by the wrapper code # that is not wrapped by the wrapper code
kernel_global_source: Optional[str] = None kernel_global_source: str | None = None
lib: Optional[ctypes.CDLL] = None # Compiled library handle lib: ctypes.CDLL | None = None # Compiled library handle
wrapped_source: Optional[str] = None # Generated C++ wrapper code wrapped_source: str | None = None # Generated C++ wrapper code
# Maps symbolic variables to their corresponding buffer and shape indices # Maps symbolic variables to their corresponding buffer and shape indices
dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None
# Maps pointer arguments to their corresponding (buffer_index, shape_dimension) # Maps pointer arguments to their corresponding (buffer_index, shape_dimension)
ptr_map: Optional[Dict[int, str]] = None ptr_map: dict[int, str] | None = None
# Maps buffer variables to their corresponding dtypes # Maps buffer variables to their corresponding dtypes
buffer_dtype_map: Optional[Dict[tir.Var, Tuple[int, torch.dtype]]] = None buffer_dtype_map: dict[tir.Var, tuple[int, torch.dtype]] | None = None
# Maps buffer variables to their corresponding static shapes and strides, # Maps buffer variables to their corresponding static shapes and strides,
# e.g., { # e.g., {
# "A": [(0, 16), (1, 16)] -> represents A.shape/strides = (16, 16) # "A": [(0, 16), (1, 16)] -> represents A.shape/strides = (16, 16)
# } # }
static_shape_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None static_shape_map: dict[tir.Var, tuple[int, list[tuple[int, int]]]] | None = None
static_strides_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None static_strides_map: dict[tir.Var, tuple[int, list[tuple[int, int]]]] | None = None
# Contains contiguous buffers # Contains contiguous buffers
static_contiguous_list: Optional[List[tir.Var]] = None static_contiguous_list: list[tir.Var] | None = None
# Maps buffer variables to their corresponding devices # Maps buffer variables to their corresponding devices
buffer_device_map: Optional[Dict[tir.Var, Tuple[int, torch.device]]] = None buffer_device_map: dict[tir.Var, tuple[int, torch.device]] | None = None
# Pass configs for the compiler # Pass configs for the compiler
pass_configs: Optional[Dict[str, Any]] = None pass_configs: dict[str, Any] | None = None
def __init__(self, def __init__(self,
params: List[KernelParam], params: list[KernelParam],
result_idx: List[int], result_idx: list[int],
target: Union[str, Target], target: str | Target,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule], func_or_mod: tir.PrimFunc | tvm.IRModule,
host_mod: Optional[tvm.IRModule] = None, host_mod: tvm.IRModule | None = None,
device_mod: Optional[tvm.IRModule] = None, device_mod: tvm.IRModule | None = None,
kernel_global_source: Optional[str] = None, kernel_global_source: str | None = None,
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: dict[str, Any] | None = None,
compile_flags: Optional[List[str]] = None): compile_flags: list[str] | None = None):
"""Initialize the adapter with the given TIR function or module. """Initialize the adapter with the given TIR function or module.
Args: Args:
...@@ -146,15 +147,15 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -146,15 +147,15 @@ class CythonKernelAdapter(BaseKernelAdapter):
@classmethod @classmethod
def from_database(cls, def from_database(cls,
params: List[TensorType], params: list[TensorType],
result_idx: List[int], result_idx: list[int],
target: str, target: str,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule], func_or_mod: tir.PrimFunc | tvm.IRModule,
kernel_global_source: str, kernel_global_source: str,
kernel_lib_path: str, kernel_lib_path: str,
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: dict[str, Any] | None = None,
compile_flags: Optional[List[str]] = None): compile_flags: list[str] | None = None):
adapter = cls.__new__(cls) adapter = cls.__new__(cls)
adapter.params = params adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx) adapter.result_idx = adapter._legalize_result_idx(result_idx)
...@@ -205,7 +206,7 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -205,7 +206,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
adapter._post_init() adapter._post_init()
return adapter return adapter
def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]: def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int, int]]:
"""Extract information about dynamic shapes from the TIR function. """Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (id, buffer_index, dimension) Maps symbolic variables to their corresponding (id, buffer_index, dimension)
...@@ -232,7 +233,7 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -232,7 +233,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
dynamic_symbolic_map[stride] = (1, i, j) dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map return dynamic_symbolic_map
def _process_buffer_dtype(self) -> Dict[tir.Var, Tuple[int, torch.dtype]]: def _process_buffer_dtype(self) -> dict[tir.Var, tuple[int, torch.dtype]]:
"""Extract information about buffer dtypes from the TIR function. """Extract information about buffer dtypes from the TIR function.
Maps buffer variables to their corresponding dtypes. Maps buffer variables to their corresponding dtypes.
...@@ -248,7 +249,7 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -248,7 +249,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
buffer_dtype_map[name] = (i, map_torch_type(dtype)) buffer_dtype_map[name] = (i, map_torch_type(dtype))
return buffer_dtype_map return buffer_dtype_map
def _process_ptr_map(self) -> Dict[int, str]: def _process_ptr_map(self) -> dict[int, str]:
"""Extract information about pointer arguments from the TIR function. """Extract information about pointer arguments from the TIR function.
Maps pointer arguments to their corresponding (buffer_index, shape_dimension) Maps pointer arguments to their corresponding (buffer_index, shape_dimension)
...@@ -263,9 +264,9 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -263,9 +264,9 @@ class CythonKernelAdapter(BaseKernelAdapter):
return ptr_map return ptr_map
def _process_static_buffer_infos(self) -> \ def _process_static_buffer_infos(self) -> \
Tuple[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]], tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]],
Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]], dict[tir.Var, tuple[int, list[tuple[int, int]]]],
List[Tuple[tir.Var]]]: list[tuple[tir.Var]]]:
"""Extract information about static shapes from the TIR function. """Extract information about static shapes from the TIR function.
Maps buffer variables to their corresponding static shapes. Maps buffer variables to their corresponding static shapes.
...@@ -300,7 +301,7 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -300,7 +301,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
static_contiguous_list.append((i, buffer.name)) static_contiguous_list.append((i, buffer.name))
return static_shape_map, static_strides_map, static_contiguous_list return static_shape_map, static_strides_map, static_contiguous_list
def _process_buffer_device(self) -> Dict[tir.Var, Tuple[int, torch.device]]: def _process_buffer_device(self) -> dict[tir.Var, tuple[int, torch.device]]:
"""Extract information about buffer devices from the TIR function. """Extract information about buffer devices from the TIR function.
Maps buffer variables to their corresponding devices. Maps buffer variables to their corresponding devices.
...@@ -326,7 +327,7 @@ class CythonKernelAdapter(BaseKernelAdapter): ...@@ -326,7 +327,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
buffer_device_map[name] = (i, device) buffer_device_map[name] = (i, device)
return buffer_device_map return buffer_device_map
def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None): def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
"""Low-level function to call the compiled CUDA kernel. """Low-level function to call the compiled CUDA kernel.
Converts PyTorch tensor pointers to C void pointers for ctypes interface. Converts PyTorch tensor pointers to C void pointers for ctypes interface.
......
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from __future__ import annotations
import torch import torch
from typing import List
from tilelang.contrib.dlpack import to_pytorch_func from tilelang.contrib.dlpack import to_pytorch_func
from .base import BaseKernelAdapter from .base import BaseKernelAdapter
...@@ -11,7 +11,7 @@ class TorchDLPackKernelAdapter(BaseKernelAdapter): ...@@ -11,7 +11,7 @@ class TorchDLPackKernelAdapter(BaseKernelAdapter):
def _convert_torch_func(self) -> callable: def _convert_torch_func(self) -> callable:
torch_func = to_pytorch_func(self.mod) torch_func = to_pytorch_func(self.mod)
def func(*ins: List[torch.Tensor]): def func(*ins: list[torch.Tensor]):
if len(ins) + len(self.result_idx) != len(self.params): if len(ins) + len(self.result_idx) != len(self.params):
raise ValueError( raise ValueError(
f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs" f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs"
......
+from __future__ import annotations
import ctypes
import importlib
import logging
@@ -5,7 +6,7 @@ import os
import os.path as osp
import subprocess
import tempfile
-from typing import Any, Dict, Optional, List
+from typing import Any
from tvm.target import Target
@@ -29,21 +30,21 @@ except ImportError:
is_nvrtc_available = False
-class LibraryGenerator(object):
+class LibraryGenerator:
-srcpath: Optional[str] = None
-libpath: Optional[str] = None
-lib_code: Optional[str] = None
-pass_configs: Optional[Dict[str, Any]] = None
-compile_flags: Optional[List[str]] = None
+srcpath: str | None = None
+libpath: str | None = None
+lib_code: str | None = None
+pass_configs: dict[str, Any] | None = None
+compile_flags: list[str] | None = None
def __init__(self, target: Target, verbose: bool = False):
self.target = target
self.verbose = verbose
-def assign_pass_configs(self, pass_configs: Optional[Dict[str, Any]] = None):
+def assign_pass_configs(self, pass_configs: dict[str, Any] | None = None):
self.pass_configs = pass_configs
-def assign_compile_flags(self, compile_flags: Optional[List[str]] = None):
+def assign_compile_flags(self, compile_flags: list[str] | None = None):
if compile_flags is None:
compile_flags = []
self.compile_flags = compile_flags
@@ -52,7 +53,7 @@ class LibraryGenerator(object):
self.lib_code = lib_code
# Assume currently we only support CUDA compilation
-def load_lib(self, lib_path: Optional[str] = None):
+def load_lib(self, lib_path: str | None = None):
if lib_path is None:
lib_path = self.libpath
else:
@@ -185,7 +186,7 @@ class LibraryGenerator(object):
class PyLibraryGenerator(LibraryGenerator):
-host_func: Optional[str] = None
+host_func: str | None = None
culib = None
pymodule = None
@@ -206,7 +207,7 @@ class PyLibraryGenerator(LibraryGenerator):
def update_host_func(self, host_func: str):
self.host_func = host_func
-def load_lib(self, lib_path: Optional[str] = None):
+def load_lib(self, lib_path: str | None = None):
if lib_path is None:
lib_path = self.libpath
...
+from __future__ import annotations
import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable
import torch
from tvm import tir
@@ -26,16 +27,16 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
kernels = {}
def __init__(self,
-params: List[KernelParam],
-result_idx: List[int],
-target: Union[str, Target],
-func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
-host_mod: Optional[tvm.IRModule] = None,
-device_mod: Optional[tvm.IRModule] = None,
-kernel_global_source: Optional[str] = None,
+params: list[KernelParam],
+result_idx: list[int],
+target: str | Target,
+func_or_mod: tir.PrimFunc | tvm.IRModule,
+host_mod: tvm.IRModule | None = None,
+device_mod: tvm.IRModule | None = None,
+kernel_global_source: str | None = None,
verbose: bool = False,
-pass_configs: Optional[Dict[str, Any]] = None,
-compile_flags: Optional[List[str]] = None):
+pass_configs: dict[str, Any] | None = None,
+compile_flags: list[str] | None = None):
check_nvrtc_available()
@@ -91,15 +92,15 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
@classmethod
def from_database(cls,
-params: List[KernelParam],
-result_idx: List[int],
+params: list[KernelParam],
+result_idx: list[int],
target: str,
-func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
+func_or_mod: tir.PrimFunc | tvm.IRModule,
kernel_global_source: str,
kernel_lib_path: str,
verbose: bool = False,
-pass_configs: Optional[Dict[str, Any]] = None,
-compile_flags: Optional[List[str]] = None):
+pass_configs: dict[str, Any] | None = None,
+compile_flags: list[str] | None = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
@@ -143,7 +144,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
adapter._post_init()
return adapter
-def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]:
+def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int]]:
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (buffer_index, shape_dimension)
@@ -165,7 +166,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
dynamic_symbolic_map[shape] = (i, j)
return dynamic_symbolic_map
-def get_kernel_source(self) -> Optional[str]:
+def get_kernel_source(self) -> str | None:
"""Get the CUDA kernel source code.
Returns
@@ -175,14 +176,12 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
"""
return self.kernel_global_source
-def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
+def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
"""Low-level function to call the compiled CUDA kernel.
"""
return self.pymodule.call(self.kernels, *args, stream=stream)
-def _wrap_forward_from_prebuild_lib(self,
-*ins: List[torch.Tensor],
-stream: Optional[int] = None):
+def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None):
"""High-level wrapper for kernel execution.
Handles:
@@ -242,7 +241,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
else:
return [args[i] for i in self.result_idx]
-def _convert_torch_func(self) -> Callable[..., Union[torch.Tensor, List[torch.Tensor]]]:
+def _convert_torch_func(self) -> Callable[..., torch.Tensor | list[torch.Tensor]]:
"""Convert to a PyTorch-compatible function.
Returns
...
+from __future__ import annotations
from functools import wraps
-from typing import Callable, Optional, Union, List
+from typing import Callable
import torch
from tvm import tir
@@ -14,13 +15,13 @@ class MetalKernelAdapter(BaseKernelAdapter):
def __init__(
self,
-params: List[KernelParam],
-result_idx: List[int],
+params: list[KernelParam],
+result_idx: list[int],
# target: Union[str, Target],
-func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
+func_or_mod: tir.PrimFunc | tvm.IRModule,
# host_mod: Optional[tvm.IRModule] = None,
-device_mod: Optional[tvm.IRModule] = None,
-kernel_global_source: Optional[str] = None,
+device_mod: tvm.IRModule | None = None,
+kernel_global_source: str | None = None,
verbose: bool = False,
# pass_configs: Optional[Dict[str, Any]] = None,
# compile_flags: Optional[List[str]] = None
...
from __future__ import annotations
import re
-from typing import Union, Optional, Literal, Dict
+from typing import Literal
from tilelang import tvm as tvm
from tvm import IRModule, tir
from tvm.target import Target
@@ -65,11 +65,11 @@ def is_metal_target(target: Target) -> bool:
def get_annotated_mod(
-func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
-target: Union[str, Target] = "auto",
-target_host: Optional[Union[str, Target]] = None,
+func_or_mod: tir.PrimFunc | tvm.IRModule,
+target: str | Target = "auto",
+target_host: str | Target | None = None,
model_type: Literal["device", "host", "all"] = "all",
-) -> Union[IRModule, tuple[IRModule, IRModule]]:
+) -> IRModule | tuple[IRModule, IRModule]:
# Validate model_type early
if model_type not in {"device", "host", "all"}:
@@ -107,7 +107,7 @@ def get_annotated_mod(
return dispatch[model_type](mod)
-def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: Optional[Dict[str, str]] = None) -> str:
+def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None) -> str:
"""
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
...
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from tilelang import tvm as tvm from tilelang import tvm as tvm
from typing import Optional, List, Dict, Union, Any from typing import Any
from tvm import IRModule from tvm import IRModule
from tvm.target import Target from tvm.target import Target
from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target,
...@@ -205,7 +206,7 @@ class BaseWrapper(ABC): ...@@ -205,7 +206,7 @@ class BaseWrapper(ABC):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TLCUDASourceWrapper(object): class TLCUDASourceWrapper:
_TYPE_MAP = { _TYPE_MAP = {
"float32": "float", "float32": "float",
"float16": "half_t", "float16": "half_t",
...@@ -225,33 +226,33 @@ class TLCUDASourceWrapper(object): ...@@ -225,33 +226,33 @@ class TLCUDASourceWrapper(object):
} }
backend = "tl" backend = "tl"
device_mod: Optional[IRModule] = None device_mod: IRModule | None = None
host_mod: Optional[IRModule] = None host_mod: IRModule | None = None
pass_configs: Optional[Dict[str, Any]] = None pass_configs: dict[str, Any] | None = None
def __init__(self, def __init__(self,
scheduled_ir_module: IRModule, scheduled_ir_module: IRModule,
source: str, source: str,
target: Target, target: Target,
device_mod: Optional[IRModule] = None, device_mod: IRModule | None = None,
host_mod: Optional[IRModule] = None, host_mod: IRModule | None = None,
pass_configs: Optional[Dict[str, Any]] = None): pass_configs: dict[str, Any] | None = None):
self.mod = scheduled_ir_module self.mod = scheduled_ir_module
self.target = target self.target = target
self.source = source self.source = source
self.pass_configs = pass_configs self.pass_configs = pass_configs
self.device_mod = device_mod self.device_mod = device_mod
self.host_mod = host_mod self.host_mod = host_mod
self.function_names: Optional[str] = None self.function_names: str | None = None
self.dynamic_smem_buf: Optional[int] = None self.dynamic_smem_buf: int | None = None
self.block_info: Union[List[int], Dict] = [1, 1, 1] self.block_info: list[int] | dict = [1, 1, 1]
self.grid_info: Union[List[int], Dict] = [1, 1, 1] self.grid_info: list[int] | dict = [1, 1, 1]
self.tma_descriptor_args: Optional[Dict] = None self.tma_descriptor_args: dict | None = None
self.l2_persistent_map: Optional[Dict[str, Dict]] = {} self.l2_persistent_map: dict[str, dict] | None = {}
self.parse_source_information() self.parse_source_information()
self.srcpath: Optional[str] = None self.srcpath: str | None = None
self.libpath: Optional[str] = None self.libpath: str | None = None
self.lib_code: Optional[str] = self.update_lib_code(source) self.lib_code: str | None = self.update_lib_code(source)
def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str: def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str:
return pythonic_expr(expr, self._TYPE_MAP) return pythonic_expr(expr, self._TYPE_MAP)
...@@ -293,10 +294,10 @@ class TLCUDASourceWrapper(object): ...@@ -293,10 +294,10 @@ class TLCUDASourceWrapper(object):
def func_call_args(s, def func_call_args(s,
function_args, function_args,
function_params, function_params,
desc_name_map: Optional[Dict[str, str]] = None, desc_name_map: dict[str, str] | None = None,
desc_name_var_map: Optional[Dict[str, tvm.tir.Var]] = None): desc_name_var_map: dict[str, tvm.tir.Var] | None = None):
# Extract the function call arguments matching the function definition # Extract the function call arguments matching the function definition
def maybe_desc(name: str, matches: List[str], i: int): def maybe_desc(name: str, matches: list[str], i: int):
match = matches[i] match = matches[i]
if not (match == name + "_desc" or match.startswith(name + "_desc_")): if not (match == name + "_desc" or match.startswith(name + "_desc_")):
return False return False
...@@ -334,8 +335,8 @@ class TLCUDASourceWrapper(object): ...@@ -334,8 +335,8 @@ class TLCUDASourceWrapper(object):
kernel_launch_code = """""" kernel_launch_code = """"""
if has_l2_persistent_map: if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE
desc_name_map: Dict[str, str] = {} desc_name_map: dict[str, str] = {}
desc_name_var_map: Dict[str, tvm.tir.Var] = {} desc_name_var_map: dict[str, tvm.tir.Var] = {}
for function_name, function_info in function_informations.items(): for function_name, function_info in function_informations.items():
block_info = function_info["block_info"] block_info = function_info["block_info"]
grid_info = function_info["grid_info"] grid_info = function_info["grid_info"]
...@@ -351,14 +352,8 @@ class TLCUDASourceWrapper(object): ...@@ -351,14 +352,8 @@ class TLCUDASourceWrapper(object):
# Identify the start of the function body to insert arguments # Identify the start of the function body to insert arguments
index = code.index("{", index) index = code.index("{", index)
block_str = "dim3({}, {}, {})".format( block_str = f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})"
self._pythonic_expr(block_info[0]), grid_str = f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})"
self._pythonic_expr(block_info[1]),
self._pythonic_expr(block_info[2]),
)
grid_str = "dim3({}, {}, {})".format(
self._pythonic_expr(grid_info[0]), self._pythonic_expr(grid_info[1]),
self._pythonic_expr(grid_info[2]))
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
init_l2_persistent_map = self.generate_l2_persistent_map(function_name) init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
kernel_launch_code += init_l2_persistent_map kernel_launch_code += init_l2_persistent_map
...@@ -382,9 +377,8 @@ class TLCUDASourceWrapper(object): ...@@ -382,9 +377,8 @@ class TLCUDASourceWrapper(object):
args_list args_list
), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
call_args = ", ".join(args_list) call_args = ", ".join(args_list)
kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format( kernel_launch_code += f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n"
function_name, grid_str, block_str, smem_str, call_args) kernel_launch_code += f"\tTILELANG_CHECK_LAST_ERROR(\"{function_name}\");\n"
kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
if has_l2_persistent_map: if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE
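For readers skimming the generated host code, here is a minimal standalone sketch of the f-string launch formatting this hunk switches to; every value below is made up for illustration and is not taken from the repo.

# Hypothetical inputs standing in for the per-kernel info gathered above.
function_name = "gemm_kernel"
grid_str = "dim3(1024, 1, 1)"
block_str = "dim3(128, 1, 1)"
smem_str = 0                      # 0 when no dynamic shared memory is requested
call_args = "A, B, C"
# Same f-string shape as the new kernel_launch_code line in this hunk.
launch = f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n"
print(launch)  # \tgemm_kernel<<<dim3(1024, 1, 1), dim3(128, 1, 1), 0, stream>>>(A, B, C);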
...@@ -415,8 +409,8 @@ class TLCUDASourceWrapper(object): ...@@ -415,8 +409,8 @@ class TLCUDASourceWrapper(object):
return init_l2_persistent_map return init_l2_persistent_map
def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str], def generate_tma_descriptor_args(self, desc_name_map: dict[str, str],
desc_name_var_map: Dict[str, tvm.tir.Var]) -> str: desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
tma_descripter_init = "" tma_descripter_init = ""
if self.tma_descriptor_args is None: if self.tma_descriptor_args is None:
return tma_descripter_init return tma_descripter_init
...@@ -583,7 +577,7 @@ class TLCUDASourceWrapper(object): ...@@ -583,7 +577,7 @@ class TLCUDASourceWrapper(object):
def get_dynamic_symbolic_set(self, prim_func): def get_dynamic_symbolic_set(self, prim_func):
# Determine the set of dynamic symbols used in the function # Determine the set of dynamic symbols used in the function
dynamic_symbolic_set: List[str] = [] dynamic_symbolic_set: list[str] = []
def unique_push_back(name: str): def unique_push_back(name: str):
if name not in dynamic_symbolic_set: if name not in dynamic_symbolic_set:
...@@ -636,7 +630,7 @@ class TLCUDASourceWrapper(object): ...@@ -636,7 +630,7 @@ class TLCUDASourceWrapper(object):
assert function_name in self.device_mod, f"Function {function_name} not found in device module" assert function_name in self.device_mod, f"Function {function_name} not found in device module"
device_func = self.device_mod[function_name] device_func = self.device_mod[function_name]
kernel_params_cnt = len(device_func.params) kernel_params_cnt = len(device_func.params)
function_params: List[str] = None function_params: list[str] = None
def visitor(node, fn=function_name, param_cnt=kernel_params_cnt): def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
nonlocal function_params nonlocal function_params
...@@ -670,7 +664,7 @@ class TLCUDASourceWrapper(object): ...@@ -670,7 +664,7 @@ class TLCUDASourceWrapper(object):
lib_code = self.source + init_func + host_func lib_code = self.source + init_func + host_func
return lib_code return lib_code
def get_stream_type(self) -> Dict[str, str]: def get_stream_type(self) -> dict[str, str]:
return {"name": "stream=cudaStreamDefault", "type": "cudaStream_t"} return {"name": "stream=cudaStreamDefault", "type": "cudaStream_t"}
@property @property
...@@ -740,9 +734,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -740,9 +734,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
scheduled_ir_module: IRModule, scheduled_ir_module: IRModule,
source: str, source: str,
target: Target, target: Target,
device_mod: Optional[IRModule] = None, device_mod: IRModule | None = None,
host_mod: Optional[IRModule] = None, host_mod: IRModule | None = None,
pass_configs: Optional[Dict[str, Any]] = None): pass_configs: dict[str, Any] | None = None):
super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs) super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)
def create_dispatch_func(self, code, function_informations): def create_dispatch_func(self, code, function_informations):
...@@ -772,9 +766,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -772,9 +766,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
# Format the function arguments for declaration # Format the function arguments for declaration
def_args = ", ".join([f"{arg['name']}" for arg in function_args]) def_args = ", ".join([f"{arg['name']}" for arg in function_args])
def func_call_args(s, function_args, desc_name_map: Optional[Dict[str, str]] = None): def func_call_args(s, function_args, desc_name_map: dict[str, str] | None = None):
# Extract the function call arguments matching the function definition # Extract the function call arguments matching the function definition
def maybe_desc(name: str, matches: List[str], i: int): def maybe_desc(name: str, matches: list[str], i: int):
match = matches[i] match = matches[i]
if not (match == name + "_desc" or match.startswith(name + "_desc_")): if not (match == name + "_desc" or match.startswith(name + "_desc_")):
return False return False
...@@ -800,7 +794,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -800,7 +794,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
call_args.append((match, "None")) call_args.append((match, "None"))
return call_args return call_args
desc_name_map: Dict[str, str] = {} desc_name_map: dict[str, str] = {}
device_index = 0 device_index = 0
kernel_launch_code = """""" kernel_launch_code = """"""
for function_name, function_info in function_informations.items(): for function_name, function_info in function_informations.items():
...@@ -837,7 +831,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -837,7 +831,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
repr(list(function_informations.keys())), def_args, kernel_launch_code) repr(list(function_informations.keys())), def_args, kernel_launch_code)
return host_func return host_func
def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str: def generate_tma_descriptor_args(self, desc_name_map: dict[str, str]) -> str:
tma_descripter_init = "" tma_descripter_init = ""
if self.tma_descriptor_args is None: if self.tma_descriptor_args is None:
return tma_descripter_init return tma_descripter_init
...@@ -915,7 +909,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -915,7 +909,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
self.host_func = self.create_dispatch_func(code, function_informations) self.host_func = self.create_dispatch_func(code, function_informations)
return self.lib_code return self.lib_code
def get_stream_type(self) -> Dict[str, str]: def get_stream_type(self) -> dict[str, str]:
return {"name": "stream=0", "type": "int"} return {"name": "stream=0", "type": "int"}
...@@ -948,9 +942,9 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper): ...@@ -948,9 +942,9 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
scheduled_ir_module: IRModule, scheduled_ir_module: IRModule,
source: str, source: str,
target: Target, target: Target,
device_mod: Optional[IRModule] = None, device_mod: IRModule | None = None,
host_mod: Optional[IRModule] = None, host_mod: IRModule | None = None,
pass_configs: Optional[Dict[str, Any]] = None): pass_configs: dict[str, Any] | None = None):
super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs) super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)
def get_init_func(self): def get_init_func(self):
...@@ -966,11 +960,11 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper): ...@@ -966,11 +960,11 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
init_funcs = PREDEF_INIT_FUNC.format(call_str) init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs return init_funcs
def get_stream_type(self) -> Dict[str, str]: def get_stream_type(self) -> dict[str, str]:
return {"name": "stream=hipStreamDefault", "type": "hipStream_t"} return {"name": "stream=hipStreamDefault", "type": "hipStream_t"}
class TLCPUSourceWrapper(object): class TLCPUSourceWrapper:
_TYPE_MAP = { _TYPE_MAP = {
"float32": "float", "float32": "float",
"float16": "half", "float16": "half",
...@@ -996,29 +990,29 @@ class TLCPUSourceWrapper(object): ...@@ -996,29 +990,29 @@ class TLCPUSourceWrapper(object):
""") """)
backend = "tl" backend = "tl"
device_mod: Optional[IRModule] = None device_mod: IRModule | None = None
host_mod: Optional[IRModule] = None host_mod: IRModule | None = None
pass_configs: Optional[Dict[str, Any]] = None pass_configs: dict[str, Any] | None = None
def __init__(self, def __init__(self,
scheduled_ir_module: IRModule, scheduled_ir_module: IRModule,
source: str, source: str,
target: Target, target: Target,
device_mod: Optional[IRModule] = None, device_mod: IRModule | None = None,
host_mod: Optional[IRModule] = None, host_mod: IRModule | None = None,
pass_configs: Optional[Dict[str, Any]] = None): pass_configs: dict[str, Any] | None = None):
self.mod = scheduled_ir_module self.mod = scheduled_ir_module
self.target = target self.target = target
self.source = source self.source = source
self.device_mod = device_mod self.device_mod = device_mod
self.host_mod = host_mod self.host_mod = host_mod
self.pass_configs = pass_configs self.pass_configs = pass_configs
self.function_names: Optional[str] = None self.function_names: str | None = None
self.dynamic_smem_buf: Optional[int] = None self.dynamic_smem_buf: int | None = None
self.parse_source_information() self.parse_source_information()
self.srcpath: Optional[str] = None self.srcpath: str | None = None
self.libpath: Optional[str] = None self.libpath: str | None = None
self.lib_code: Optional[str] = self.update_lib_code(source) self.lib_code: str | None = self.update_lib_code(source)
def create_call_func(self, code, function_informations): def create_call_func(self, code, function_informations):
# Extract the set of dynamic symbolic names used in the primary function # Extract the set of dynamic symbolic names used in the primary function
...@@ -1068,7 +1062,7 @@ class TLCPUSourceWrapper(object): ...@@ -1068,7 +1062,7 @@ class TLCPUSourceWrapper(object):
index = code.index("{", index) index = code.index("{", index)
call_args = ", ".join(func_call_args(declaration, function_args)) call_args = ", ".join(func_call_args(declaration, function_args))
_call_str += "{}({})".format(function_name, call_args) _call_str += f"{function_name}({call_args})"
# Wrap the kernel dispatch logic in an external C function # Wrap the kernel dispatch logic in an external C function
host_func = self.CALL_PREFIX.format(def_args, _call_str) host_func = self.CALL_PREFIX.format(def_args, _call_str)
...@@ -1089,7 +1083,7 @@ class TLCPUSourceWrapper(object): ...@@ -1089,7 +1083,7 @@ class TLCPUSourceWrapper(object):
def get_dynamic_symbolic_set(self, prim_func): def get_dynamic_symbolic_set(self, prim_func):
# Determine the set of dynamic symbols used in the function # Determine the set of dynamic symbols used in the function
dynamic_symbolic_set: List[str] = [] dynamic_symbolic_set: list[str] = []
for param in prim_func.params: for param in prim_func.params:
if param in prim_func.buffer_map: if param in prim_func.buffer_map:
buffer = prim_func.buffer_map[param] buffer = prim_func.buffer_map[param]
...@@ -1137,15 +1131,15 @@ class TLCPUSourceWrapper(object): ...@@ -1137,15 +1131,15 @@ class TLCPUSourceWrapper(object):
raise ValueError("Cannot find primary function in the module.") raise ValueError("Cannot find primary function in the module.")
class TLMetalSourceWrapper(object): class TLMetalSourceWrapper:
def __init__(self, def __init__(self,
scheduled_ir_module: IRModule, scheduled_ir_module: IRModule,
source: str, source: str,
target: Target, target: Target,
device_mod: Optional[IRModule] = None, device_mod: IRModule | None = None,
host_mod: Optional[IRModule] = None, host_mod: IRModule | None = None,
pass_configs: Optional[Dict[str, Any]] = None): pass_configs: dict[str, Any] | None = None):
self.mod = scheduled_ir_module self.mod = scheduled_ir_module
self.target = target self.target = target
self.source = source self.source = source
...@@ -1163,11 +1157,11 @@ class TLWrapper(BaseWrapper): ...@@ -1163,11 +1157,11 @@ class TLWrapper(BaseWrapper):
""" """
A wrapper class for the TileLang backend. A wrapper class for the TileLang backend.
""" """
device_mod: Optional[IRModule] = None device_mod: IRModule | None = None
host_mod: Optional[IRModule] = None host_mod: IRModule | None = None
pass_configs: Optional[Dict[str, Any]] = None pass_configs: dict[str, Any] | None = None
target: Optional[Target] = None target: Target | None = None
lib: Optional[object] = None lib: object | None = None
def __init__(self, target: Target): def __init__(self, target: Target):
super().__init__() super().__init__()
...@@ -1179,7 +1173,7 @@ class TLWrapper(BaseWrapper): ...@@ -1179,7 +1173,7 @@ class TLWrapper(BaseWrapper):
def assign_optimized_module(self, scheduled_ir_module: IRModule): def assign_optimized_module(self, scheduled_ir_module: IRModule):
self.scheduled_ir_module = scheduled_ir_module self.scheduled_ir_module = scheduled_ir_module
def assign_pass_configs(self, pass_configs: Dict[str, Any]): def assign_pass_configs(self, pass_configs: dict[str, Any]):
self.pass_configs = pass_configs self.pass_configs = pass_configs
def assign_host_module(self, host_mod: IRModule): def assign_host_module(self, host_mod: IRModule):
......
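The rewrites in this file all follow the same pyupgrade pattern: typing.Optional/Dict/List spellings become PEP 604/585 builtins, which stays compatible with older interpreters because each touched module also gains `from __future__ import annotations`. A minimal, self-contained sketch of the before/after (the function itself is invented for illustration, not repo code):

from __future__ import annotations  # annotations stay unevaluated, so the new syntax works pre-3.10

# old style:  def lookup(table: Optional[Dict[str, int]] = None, keys: Optional[List[str]] = None) -> List[int]
def lookup(table: dict[str, int] | None = None, keys: list[str] | None = None) -> list[int]:
    table = table or {}
    return [table.get(k, 0) for k in (keys or [])]

print(lookup({"a": 1}, ["a", "b"]))  # [1, 0]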
from typing import Any, Callable, Dict, List, Literal, Optional, Union from __future__ import annotations
from typing import Any, Callable, Literal
from tilelang.jit.adapter.utils import is_metal_target from tilelang.jit.adapter.utils import is_metal_target
from tvm.target import Target from tvm.target import Target
...@@ -17,7 +18,7 @@ import logging ...@@ -17,7 +18,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class JITKernel(object): class JITKernel:
""" """
A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions. A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.
...@@ -37,20 +38,20 @@ class JITKernel(object): ...@@ -37,20 +38,20 @@ class JITKernel(object):
# tuner result # tuner result
latency: float = None latency: float = None
config: Dict[str, Any] = None config: dict[str, Any] = None
ref_latency: float = None ref_latency: float = None
def __init__( def __init__(
self, self,
func: PrimFunc = None, func: PrimFunc = None,
out_idx: Union[List[int], int] = None, out_idx: list[int] | int = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
target: Union[str, Target] = "auto", target: str | Target = "auto",
target_host: Union[str, Target] = None, target_host: str | Target = None,
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: dict[str, Any] | None = None,
from_database: bool = False, from_database: bool = False,
compile_flags: Optional[List[str]] = None, compile_flags: list[str] | None = None,
): ):
""" """
Initializes a TorchFunction instance. Initializes a TorchFunction instance.
...@@ -134,13 +135,13 @@ class JITKernel(object): ...@@ -134,13 +135,13 @@ class JITKernel(object):
func: PrimFunc, func: PrimFunc,
kernel_global_source: str, kernel_global_source: str,
kernel_lib_path: str, kernel_lib_path: str,
params: List[KernelParam], params: list[KernelParam],
target: Union[str, Target], target: str | Target,
target_host: Union[str, Target], target_host: str | Target,
out_idx: Union[List[int], int], out_idx: list[int] | int,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"], execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"],
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: dict[str, Any] | None = None,
compile_flags: Optional[List[str]] = None, compile_flags: list[str] | None = None,
): ):
""" """
Alternative constructor to create a TorchFunction directly from a database. Alternative constructor to create a TorchFunction directly from a database.
...@@ -188,7 +189,7 @@ class JITKernel(object): ...@@ -188,7 +189,7 @@ class JITKernel(object):
return self.torch_function(*args, **kwds) return self.torch_function(*args, **kwds)
def _compile_and_create_adapter(self, tilelang_func: PrimFunc, def _compile_and_create_adapter(self, tilelang_func: PrimFunc,
out_idx: List[int]) -> BaseKernelAdapter: out_idx: list[int]) -> BaseKernelAdapter:
""" """
Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter. Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter.
...@@ -291,16 +292,15 @@ class JITKernel(object): ...@@ -291,16 +292,15 @@ class JITKernel(object):
return adapter return adapter
def _create_adapter_from_database( def _create_adapter_from_database(self,
self, params: list[KernelParam],
params: List[KernelParam], result_idx: list[int] | int,
result_idx: Union[List[int], int], target: str | Target,
target: Union[str, Target], func_or_mod: PrimFunc | tvm.runtime.Module,
func_or_mod: Union[PrimFunc, tvm.runtime.Module], kernel_global_source: str,
kernel_global_source: str, kernel_lib_path: str,
kernel_lib_path: str, pass_configs: dict[str, Any] | None = None,
pass_configs: Optional[Dict[str, Any]] = None, compile_flags: list[str] | None = None) -> BaseKernelAdapter:
compile_flags: Optional[List[str]] = None) -> BaseKernelAdapter:
target = self.target target = self.target
execution_backend = self.execution_backend execution_backend = self.execution_backend
...@@ -401,11 +401,11 @@ class JITKernel(object): ...@@ -401,11 +401,11 @@ class JITKernel(object):
""" """
return str(self.artifact.host_mod) return str(self.artifact.host_mod)
def run_once(self, func: Optional[Callable] = None) -> None: def run_once(self, func: Callable | None = None) -> None:
return self.get_profiler().run_once(func) return self.get_profiler().run_once(func)
def update_tuner_result(self, latency: float, config: Dict[str, Any], def update_tuner_result(self, latency: float, config: dict[str, Any],
ref_latency: float) -> "JITKernel": ref_latency: float) -> JITKernel:
""" """
Updates the tuning results for this kernel. Updates the tuning results for this kernel.
...@@ -428,7 +428,7 @@ class JITKernel(object): ...@@ -428,7 +428,7 @@ class JITKernel(object):
return self return self
def get_tuner_result(self) -> Dict[str, Any]: def get_tuner_result(self) -> dict[str, Any]:
""" """
Gets the tuning results for this kernel. Gets the tuning results for this kernel.
...@@ -450,11 +450,11 @@ class JITKernel(object): ...@@ -450,11 +450,11 @@ class JITKernel(object):
} }
@property @property
def out_idx(self) -> List[int]: def out_idx(self) -> list[int]:
return self.adapter.result_idx return self.adapter.result_idx
@property @property
def params(self) -> List[KernelParam]: def params(self) -> list[KernelParam]:
return self.artifact.params if self.artifact else self.adapter.params return self.artifact.params if self.artifact else self.adapter.params
@property @property
......
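One of the unions touched above is `out_idx: list[int] | int`; a small, self-contained sketch of how such a parameter is typically normalized (this helper is illustrative only, not repo code):

from __future__ import annotations

def normalize_out_idx(out_idx: list[int] | int | None) -> list[int]:
    # Accept a single index, a list of indices, or None, mirroring the union in the signature above.
    if out_idx is None:
        return []
    return [out_idx] if isinstance(out_idx, int) else list(out_idx)

assert normalize_out_idx(2) == [2]
assert normalize_out_idx([0, 1]) == [0, 1]
assert normalize_out_idx(None) == []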
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
from typing import Optional
# from .parser import * # from .parser import *
# now is fully compatible with the upstream # now is fully compatible with the upstream
# tir script # tir script
...@@ -90,6 +90,6 @@ from .annotations import ( # noqa: F401 ...@@ -90,6 +90,6 @@ from .annotations import ( # noqa: F401
) )
def import_source(source: Optional[str] = None): def import_source(source: str | None = None):
# source is the source code to be imported # source is the source code to be imported
return block_attr({"pragma_import_c": source}) if source is not None else None return block_attr({"pragma_import_c": source}) if source is not None else None
"""Annotation helpers exposed on the TileLang language surface.""" """Annotation helpers exposed on the TileLang language surface."""
from __future__ import annotations
from typing import Callable, Dict from typing import Callable
from tilelang.layout import Layout from tilelang.layout import Layout
from tvm.script.parser.tir import attr, block_attr from tvm.script.parser.tir import attr, block_attr
...@@ -21,7 +22,7 @@ def use_swizzle(panel_size: int, order: str = "row", enable: bool = True): ...@@ -21,7 +22,7 @@ def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
return attr(None, "threadblock_swizzle_pattern", f"tl::{device_func}<{panel_size}>") return attr(None, "threadblock_swizzle_pattern", f"tl::{device_func}<{panel_size}>")
def annotate_layout(layout_map: Dict): def annotate_layout(layout_map: dict):
"""Annotate the layout of the buffer.""" """Annotate the layout of the buffer."""
_layout_map = {} _layout_map = {}
for buffer, layout in layout_map.items(): for buffer, layout in layout_map.items():
...@@ -35,7 +36,7 @@ def annotate_layout(layout_map: Dict): ...@@ -35,7 +36,7 @@ def annotate_layout(layout_map: Dict):
return block_attr({"layout_map": _layout_map}) return block_attr({"layout_map": _layout_map})
def annotate_safe_value(safe_value_map: Dict): def annotate_safe_value(safe_value_map: dict):
"""Annotate the safe value of the buffer.""" """Annotate the safe value of the buffer."""
_safe_value_map = {} _safe_value_map = {}
for buffer, safe_value in safe_value_map.items(): for buffer, safe_value in safe_value_map.items():
...@@ -43,7 +44,7 @@ def annotate_safe_value(safe_value_map: Dict): ...@@ -43,7 +44,7 @@ def annotate_safe_value(safe_value_map: Dict):
return block_attr({"safe_value_map": _safe_value_map}) return block_attr({"safe_value_map": _safe_value_map})
def annotate_l2_hit_ratio(l2_hit_ratio_map: Dict): def annotate_l2_hit_ratio(l2_hit_ratio_map: dict):
"""Annotate the L2 hit ratio of the buffer.""" """Annotate the L2 hit ratio of the buffer."""
_l2_hit_ratio_map = {} _l2_hit_ratio_map = {}
for buffer, hit_ratio in l2_hit_ratio_map.items(): for buffer, hit_ratio in l2_hit_ratio_map.items():
......
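For context on `use_swizzle` in the hunk above, a tiny sketch of the attribute string its f-string produces; the functor name below is a placeholder, not the actual tl:: symbol used by the repo.

panel_size = 8
device_func = "SwizzlePattern"   # placeholder for the real device-side functor name
print(f"tl::{device_func}<{panel_size}>")   # -> tl::SwizzlePattern<8>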
# Copyright (c) Tile-AI Corporation. # Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License. # Licensed under the MIT License.
"""Atomic operations for tilelang.""" """Atomic operations for tilelang."""
from __future__ import annotations
import tilelang.language as T import tilelang.language as T
from tvm import ir, tir from tvm import ir, tir
from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op
from typing import Optional
from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region
from tilelang.utils.language import get_buffer_region_from_load from tilelang.utils.language import get_buffer_region_from_load
...@@ -21,7 +21,7 @@ _MEMORY_ORDER_ID_MAP = { ...@@ -21,7 +21,7 @@ _MEMORY_ORDER_ID_MAP = {
def atomic_max(dst: Buffer, def atomic_max(dst: Buffer,
value: PrimExpr, value: PrimExpr,
memory_order: Optional[str] = None, memory_order: str | None = None,
return_prev: bool = False) -> PrimExpr: return_prev: bool = False) -> PrimExpr:
""" """
Perform an atomic maximum on the value stored at dst with an optional memory-order. Perform an atomic maximum on the value stored at dst with an optional memory-order.
...@@ -67,7 +67,7 @@ def atomic_max(dst: Buffer, ...@@ -67,7 +67,7 @@ def atomic_max(dst: Buffer,
def atomic_min(dst: Buffer, def atomic_min(dst: Buffer,
value: PrimExpr, value: PrimExpr,
memory_order: Optional[str] = None, memory_order: str | None = None,
return_prev: bool = False) -> PrimExpr: return_prev: bool = False) -> PrimExpr:
""" """
Atomically update the value at dst to the minimum of its current value and value. Atomically update the value at dst to the minimum of its current value and value.
...@@ -115,7 +115,7 @@ def atomic_min(dst: Buffer, ...@@ -115,7 +115,7 @@ def atomic_min(dst: Buffer,
def atomic_add(dst: Buffer, def atomic_add(dst: Buffer,
value: PrimExpr, value: PrimExpr,
memory_order: Optional[str] = None, memory_order: str | None = None,
return_prev: bool = False, return_prev: bool = False,
use_tma: bool = False) -> PrimExpr: use_tma: bool = False) -> PrimExpr:
""" """
......
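The atomic_* signatures above all take `memory_order: str | None = None`; a self-contained stand-in for the lookup implied by `_MEMORY_ORDER_ID_MAP` (the mapping keys and IDs here are placeholders, not the repo's actual values):

from __future__ import annotations

_FAKE_MEMORY_ORDER_IDS: dict[str, int] = {"relaxed": 0, "acquire": 1, "release": 2}

def resolve_memory_order(memory_order: str | None = None) -> int:
    # None selects the default ordering, matching the `str | None = None` default above.
    if memory_order is None:
        return _FAKE_MEMORY_ORDER_IDS["relaxed"]
    return _FAKE_MEMORY_ORDER_IDS[memory_order]

print(resolve_memory_order(), resolve_memory_order("release"))   # 0 2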