Unverified commit f14fb111 authored by Yichen Yan, committed by GitHub

[Lint] Enable pyupgrade linter in ruff (#963)

* update rules

* ruff check

* other fixes

* fmt

* do not touch examples

* fmt
parent 4f3523dc
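
The rewrites in the diff below come from ruff's pyupgrade rule family (the UP-prefixed rules selected in the project's ruff lint configuration; the configuration change itself is not shown in this excerpt). A minimal sketch of the patterns being applied, using a hypothetical function and class rather than code from this repository:

from __future__ import annotations  # added to every touched module in this commit

# was: def lookup(table: Dict[str, List[int]], key: Optional[str] = None) -> Union[int, None]:
def lookup(table: dict[str, list[int]], key: str | None = None) -> int | None:
    if key is None:
        return None
    if key not in table:
        # was: raise KeyError("missing key: {}".format(key))  -- str.format calls become f-strings
        raise KeyError(f"missing key: {key}")
    return table[key][0]

# was: class Emitter(object):  -- the redundant object base class is dropped
class Emitter:
    pass

Because each touched module also gains the from __future__ import annotations line, the X | None and list[int] spellings remain import-safe on interpreters older than Python 3.10/3.9: annotations are then stored as strings rather than evaluated at import time.
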
from __future__ import annotations
from tvm import tir, IRModule
from tvm.target import Target
import tilelang
from tilelang.transform import PassContext
from tilelang.contrib.nvcc import have_tma, is_hopper
from typing import Optional
def allow_warp_specialized(pass_ctx: Optional[PassContext] = None,
target: Optional[Target] = None) -> bool:
def allow_warp_specialized(pass_ctx: PassContext | None = None,
target: Target | None = None) -> bool:
# avoid circular import
from tilelang.jit.adapter.utils import is_cuda_target
......@@ -19,8 +19,8 @@ def allow_warp_specialized(pass_ctx: Optional[PassContext] = None,
return not disable_warp_specialized
def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
target: Optional[Target] = None) -> bool:
def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None,
target: Target | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
if not have_tma(target):
......@@ -29,26 +29,26 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
return not disable_tma_lower and allow_warp_specialized(pass_ctx=pass_ctx, target=target)
def allow_fence_proxy(target: Optional[Target] = None) -> bool:
def allow_fence_proxy(target: Target | None = None) -> bool:
return have_tma(target)
def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
def allow_vectorize(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
disable_vectorize = pass_ctx.config.get("tir.disable_vectorize", False)
return not disable_vectorize
def allow_global_thread_synchronization(pass_ctx: Optional[PassContext] = None) -> bool:
def allow_global_thread_synchronization(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
enable_global_thread_sync = pass_ctx.config.get("tir.detect_global_barrier", False)
return enable_global_thread_sync
def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None,
target: Optional[Target] = None) -> bool:
def should_enable_aggressive_merge(pass_ctx: PassContext | None = None,
target: Target | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
enable_aggressive_merge = bool(
......@@ -61,7 +61,7 @@ def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None,
return enable_aggressive_merge
def should_force_let_inline(pass_ctx: Optional[PassContext] = None) -> bool:
def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False))
......
from __future__ import annotations
import sys
import os
import pathlib
......@@ -5,7 +6,6 @@ import logging
import shutil
import glob
from dataclasses import dataclass
from typing import Optional
logger = logging.getLogger(__name__)
......@@ -170,7 +170,7 @@ class EnvVar:
key: str # Environment variable name (e.g. "TILELANG_PRINT_ON_COMPILATION")
default: str # Default value if the environment variable is not set
_forced_value: Optional[str] = None # Temporary runtime override (mainly for tests/debugging)
_forced_value: str | None = None # Temporary runtime override (mainly for tests/debugging)
def get(self):
if self._forced_value is not None:
......
from __future__ import annotations
from tilelang import tvm as tvm
import tilelang.language as T
from typing import Tuple
from tvm import DataType
from tvm.tir import PrimExpr
from tvm.runtime import convert
from typing import Optional
from .utils import (
mfma_store_index_map,)
lift = convert
class MatrixCoreIntrinEmitter(object):
class MatrixCoreIntrinEmitter:
"""
To eliminate Python syntax within TIR Macro.
"""
......@@ -51,9 +50,9 @@ class MatrixCoreIntrinEmitter(object):
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
k_pack: Optional[int] = None,
is_m_first: Optional[bool] = False,
b_preshuffle: Optional[bool] = False,
k_pack: int | None = None,
is_m_first: bool | None = False,
b_preshuffle: bool | None = False,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
......@@ -135,15 +134,15 @@ class MatrixCoreIntrinEmitter(object):
self.micro_size_y = n_dim
self.micro_size_k = k_dim
def _initialize_k_pack(self, k_pack: Optional[int] = None):
def _initialize_k_pack(self, k_pack: int | None = None):
if k_pack is not None:
self.k_pack = k_pack
def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
def _initialize_is_m_first(self, is_m_first: bool | None = False):
if is_m_first is not None:
self.is_m_first = is_m_first
def _initialize_b_preshuffle(self, b_preshuffle: Optional[bool] = False):
def _initialize_b_preshuffle(self, b_preshuffle: bool | None = False):
if b_preshuffle is not None:
self.b_preshuffle = b_preshuffle
......@@ -203,7 +202,7 @@ class MatrixCoreIntrinEmitter(object):
def extract_thread_binding(self,
thread_id,
is_m_first=None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]:
is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
'''
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
......@@ -418,10 +417,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
k_pack: Optional[int] = None,
is_m_first: Optional[bool] = False,
a_preshuffle: Optional[bool] = False,
b_preshuffle: Optional[bool] = False,
k_pack: int | None = None,
is_m_first: bool | None = False,
a_preshuffle: bool | None = False,
b_preshuffle: bool | None = False,
):
self.a_dtype = a_dtype
......
from typing import Union
from __future__ import annotations
from tvm import arith, DataType
import tilelang.language as T
......@@ -163,7 +163,7 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j):
return (i * 2 + j // 16, j % 16)
def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str], swizzle_bytes=None):
def get_swizzle_layout(row_idx, col_idx, row_size, dtype: DataType | str, swizzle_bytes=None):
ana = arith.Analyzer()
if isinstance(dtype, str):
dtype = DataType(dtype)
......
from __future__ import annotations
import tilelang.language as T
from typing import Union, Tuple, Optional, Literal, Callable
from typing import Literal, Callable
from tilelang.common import TransformKind
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var
......@@ -25,7 +26,7 @@ from tilelang.intrinsics.mma_layout import (
lift = convert
class TensorCoreIntrinEmitter(object):
class TensorCoreIntrinEmitter:
"""
To eliminate Python syntax within TIR Macro.
"""
......@@ -62,8 +63,8 @@ class TensorCoreIntrinEmitter(object):
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
is_m_first: Optional[bool] = False,
thread_var: Optional[Var] = None,
is_m_first: bool | None = False,
thread_var: Var | None = None,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
......@@ -144,7 +145,7 @@ class TensorCoreIntrinEmitter(object):
self.micro_size_x = m_dim
self.micro_size_k = k_dim
def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
def _initialize_is_m_first(self, is_m_first: bool | None = False):
if is_m_first is not None:
self.is_m_first = is_m_first
......@@ -167,7 +168,7 @@ class TensorCoreIntrinEmitter(object):
def extract_thread_binding(
self,
thread_id: PrimExpr,
is_m_first: Optional[bool] = None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]:
is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
......@@ -200,7 +201,7 @@ class TensorCoreIntrinEmitter(object):
A_local_buf: Buffer,
A_shared_buf: Buffer,
ki: PrimExpr,
rk: Optional[PrimExpr] = 0):
rk: PrimExpr | None = 0):
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
......@@ -264,7 +265,7 @@ class TensorCoreIntrinEmitter(object):
B_local_buf: Buffer,
B_shared_buf: Buffer,
ki: PrimExpr,
rk: Optional[PrimExpr] = 0):
rk: PrimExpr | None = 0):
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
......@@ -336,7 +337,7 @@ class TensorCoreIntrinEmitter(object):
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
k_inner: Optional[PrimExpr] = 0):
k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
......@@ -518,8 +519,7 @@ class TensorCoreIntrinEmitter(object):
else:
raise ValueError(f"Unsupported matrix {matrix}")
assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format(
local_buf.scope())
assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
if matrix_is_a:
micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
......@@ -684,9 +684,9 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
is_m_first: Optional[bool] = False,
transform_kind_a: Union[int, TransformKind] = 0,
transform_kind_b: Union[int, TransformKind] = 0,
is_m_first: bool | None = False,
transform_kind_a: int | TransformKind = 0,
transform_kind_b: int | TransformKind = 0,
):
super().__init__(
a_dtype=a_dtype,
......
from __future__ import annotations
import tilelang.language as T
from enum import IntEnum
from typing import Optional, Callable
from typing import Callable
from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter
from tvm import DataType
from tvm.tir import PrimExpr, Buffer, Var, IndexMap
......@@ -86,8 +87,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
is_m_first: Optional[bool] = False,
thread_var: Optional[Var] = None,
is_m_first: bool | None = False,
thread_var: Var | None = None,
):
super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps,
block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k,
......@@ -409,8 +410,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format(
local_buf.scope())
assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
......
......@@ -3,17 +3,13 @@ This module provides an auto-tuning infrastructure for TileLang (tl) programs.
It includes functionality to JIT-compile TileLang programs into a runnable
kernel adapter using TVM.
"""
from __future__ import annotations
from typing import (
Any,
List,
Union,
Callable,
Tuple,
overload,
Literal,
Dict, # For type hinting dicts
Optional,
)
from tilelang import tvm as tvm
from tilelang.jit.adapter.utils import is_metal_target
......@@ -33,13 +29,13 @@ logger = getLogger(__name__)
def compile(
func: PrimFunc = None,
out_idx: Union[List[int], int, None] = None,
out_idx: list[int] | int | None = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
target: Union[str, Target] = "auto",
target_host: Union[str, Target, None] = None,
target: str | Target = "auto",
target_host: str | Target | None = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[Union[List[str], str]] = None,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | str | None = None,
) -> JITKernel:
"""
Compile the given TileLang PrimFunc with TVM and build a JITKernel.
......@@ -85,24 +81,24 @@ def compile(
class _JitImplementation:
out_idx: Optional[Union[List[int], int]]
target: Union[str, Target]
target_host: Union[str, Target]
out_idx: list[int] | int | None
target: str | Target
target_host: str | Target
execution_backend: Literal["dlpack", "ctypes", "cython"]
verbose: bool
pass_configs: Optional[Dict[str, Any]]
debug_root_path: Optional[str]
compile_flags: Optional[Union[List[str], str]]
pass_configs: dict[str, Any] | None
debug_root_path: str | None
compile_flags: list[str] | str | None
def __init__(self,
out_idx: Any = None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
compile_flags: Optional[Union[List[str], str]] = None):
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None):
"""
Initializes the JIT compiler decorator.
......@@ -155,12 +151,12 @@ class _JitImplementation:
except NameError:
self.debug_root_path = path.abspath(self.debug_root_path)
self._kernel_cache: Dict[tuple, Kernel] = {}
self._kernel_cache: dict[tuple, Kernel] = {}
# This tells the type checker what the *wrapper* function will return.
# this is for linting, please do not remove it.
@overload
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, Kernel]]:
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, Kernel]]:
...
@overload
......@@ -235,16 +231,16 @@ class _JitImplementation:
def jit( # This is the new public interface
func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
func: Callable[_P, _RProg] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only
out_idx: Any = None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
compile_flags: Optional[Union[List[str], str]] = None):
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None):
"""
Just-In-Time (JIT) compiler decorator for TileLang functions.
......
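
For orientation, here is a minimal usage sketch built only from the compile()/jit() signatures above; the helper name and the pass-config value are illustrative and not part of this commit, and tilelang.jit accepts the same options as a decorator around a function that returns a PrimFunc:

import tilelang

def compile_for_torch(prim_func, out_idx=None):
    # Hypothetical wrapper around tilelang.compile, using only options shown above.
    return tilelang.compile(
        prim_func,                   # tvm.tir.PrimFunc to build
        out_idx=out_idx,             # list[int] | int | None: which params are returned as outputs
        target="auto",               # resolve the TVM target automatically
        execution_backend="cython",  # default backend per the signature above
        pass_configs={"tir.disable_vectorize": False},
    )

The resulting JITKernel is callable: its __call__ (shown further down in this diff) forwards to the generated torch function, so inputs are passed as torch tensors and the parameters selected by out_idx are returned as outputs.
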
"""The profiler and convert to torch utils"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, List, Callable, Optional
from typing import Any, Callable
from tilelang.engine.param import KernelParam
class BaseKernelAdapter(ABC):
func: Optional[Callable] = None
func: Callable | None = None
def __init__(self, mod, params: List[KernelParam], result_idx: List[int]) -> None:
def __init__(self, mod, params: list[KernelParam], result_idx: list[int]) -> None:
self.mod = mod
self.params = params
self.result_idx = self._legalize_result_idx(result_idx)
self._post_init()
def _legalize_result_idx(self, result_idx: Optional[List[int]]) -> List[int]:
def _legalize_result_idx(self, result_idx: list[int] | None) -> list[int]:
params = self.params
# result_idx is a list of indices of the output tensors
if result_idx is None:
......
"""The profiler and convert to torch utils"""
from __future__ import annotations
import torch
from ..base import BaseKernelAdapter
import ctypes
from typing import List, Optional, Union, Callable, Dict, Tuple, Any
from typing import Callable, Any
from tilelang import tvm as tvm
from tvm.target import Target
from tvm.relax import TensorType
......@@ -25,32 +26,32 @@ class CtypesKernelAdapter(BaseKernelAdapter):
# Class attributes to store compiled kernel information
target = "cuda"
ir_module: Optional[tvm.IRModule] = None
ir_module: tvm.IRModule | None = None
# The global source code of the kernel -> global means the source code of the kernel
# that is not wrapped by the wrapper code
kernel_global_source: Optional[str] = None
lib: Optional[ctypes.CDLL] = None # Compiled library handle
wrapped_source: Optional[str] = None # Generated C++ wrapper code
kernel_global_source: str | None = None
lib: ctypes.CDLL | None = None # Compiled library handle
wrapped_source: str | None = None # Generated C++ wrapper code
# Maps symbolic variables to their corresponding buffer and shape indices
dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None
dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None
# Pass configs for the compiler
pass_configs: Optional[Dict[str, Any]] = None
pass_configs: dict[str, Any] | None = None
# Add new cache attributes
param_dtypes: Optional[List[torch.dtype]] = None # Cache for parameter dtypes
param_shapes: Optional[List[List]] = None # Cache for parameter shapes
param_dtypes: list[torch.dtype] | None = None # Cache for parameter dtypes
param_shapes: list[list] | None = None # Cache for parameter shapes
def __init__(self,
params: List[TensorType],
result_idx: List[int],
params: list[TensorType],
result_idx: list[int],
target: str,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
host_mod: Optional[tvm.IRModule] = None,
device_mod: Optional[tvm.IRModule] = None,
kernel_global_source: Optional[str] = None,
func_or_mod: tir.PrimFunc | tvm.IRModule,
host_mod: tvm.IRModule | None = None,
device_mod: tvm.IRModule | None = None,
kernel_global_source: str | None = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
"""Initialize the adapter with the given TIR function or module.
Args:
......@@ -107,15 +108,15 @@ class CtypesKernelAdapter(BaseKernelAdapter):
@classmethod
def from_database(cls,
params: List[TensorType],
result_idx: List[int],
params: list[TensorType],
result_idx: list[int],
target: str,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
func_or_mod: tir.PrimFunc | tvm.IRModule,
kernel_global_source: str,
kernel_lib_path: str,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
......@@ -155,7 +156,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
adapter._post_init()
return adapter
def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]:
def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int, int]]:
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (id, buffer_index, dimension)
......@@ -182,7 +183,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map
def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
"""Low-level function to call the compiled CUDA kernel.
Converts PyTorch tensor pointers to C void pointers for ctypes interface.
......@@ -193,9 +194,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
ctypes_args.append(ctypes.c_void_p(stream))
self.lib.call(*ctypes_args)
def _wrap_forward_from_prebuild_lib(self,
*ins: List[torch.Tensor],
stream: Optional[int] = None):
def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None):
"""High-level wrapper for kernel execution.
Handles:
......
"""The profiler and convert to torch utils"""
from __future__ import annotations
import ctypes
import logging
import torch
from typing import List, Optional, Union, Callable, Dict, Tuple, Any
from typing import Callable, Any
from tilelang import tvm as tvm
from tvm.target import Target
from tilelang.engine.param import KernelParam
......@@ -44,43 +45,43 @@ class CythonKernelAdapter(BaseKernelAdapter):
"""
# Class attributes to store compiled kernel information
target: Union[str, Target] = "cuda"
ir_module: Optional[tvm.IRModule] = None
target: str | Target = "cuda"
ir_module: tvm.IRModule | None = None
# The global source code of the kernel -> global means the source code of the kernel
# that is not wrapped by the wrapper code
kernel_global_source: Optional[str] = None
lib: Optional[ctypes.CDLL] = None # Compiled library handle
wrapped_source: Optional[str] = None # Generated C++ wrapper code
kernel_global_source: str | None = None
lib: ctypes.CDLL | None = None # Compiled library handle
wrapped_source: str | None = None # Generated C++ wrapper code
# Maps symbolic variables to their corresponding buffer and shape indices
dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None
dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None
# Maps pointer arguments to their corresponding (buffer_index, shape_dimension)
ptr_map: Optional[Dict[int, str]] = None
ptr_map: dict[int, str] | None = None
# Maps buffer variables to their corresponding dtypes
buffer_dtype_map: Optional[Dict[tir.Var, Tuple[int, torch.dtype]]] = None
buffer_dtype_map: dict[tir.Var, tuple[int, torch.dtype]] | None = None
# Maps buffer variables to their corresponding static shapes and strides,
# e.g., {
# "A": [(0, 16), (1, 16)] -> represents A.shape/strides = (16, 16)
# }
static_shape_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None
static_strides_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None
static_shape_map: dict[tir.Var, tuple[int, list[tuple[int, int]]]] | None = None
static_strides_map: dict[tir.Var, tuple[int, list[tuple[int, int]]]] | None = None
# Contains contiguous buffers
static_contiguous_list: Optional[List[tir.Var]] = None
static_contiguous_list: list[tir.Var] | None = None
# Maps buffer variables to their corresponding devices
buffer_device_map: Optional[Dict[tir.Var, Tuple[int, torch.device]]] = None
buffer_device_map: dict[tir.Var, tuple[int, torch.device]] | None = None
# Pass configs for the compiler
pass_configs: Optional[Dict[str, Any]] = None
pass_configs: dict[str, Any] | None = None
def __init__(self,
params: List[KernelParam],
result_idx: List[int],
target: Union[str, Target],
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
host_mod: Optional[tvm.IRModule] = None,
device_mod: Optional[tvm.IRModule] = None,
kernel_global_source: Optional[str] = None,
params: list[KernelParam],
result_idx: list[int],
target: str | Target,
func_or_mod: tir.PrimFunc | tvm.IRModule,
host_mod: tvm.IRModule | None = None,
device_mod: tvm.IRModule | None = None,
kernel_global_source: str | None = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
"""Initialize the adapter with the given TIR function or module.
Args:
......@@ -146,15 +147,15 @@ class CythonKernelAdapter(BaseKernelAdapter):
@classmethod
def from_database(cls,
params: List[TensorType],
result_idx: List[int],
params: list[TensorType],
result_idx: list[int],
target: str,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
func_or_mod: tir.PrimFunc | tvm.IRModule,
kernel_global_source: str,
kernel_lib_path: str,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
......@@ -205,7 +206,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
adapter._post_init()
return adapter
def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]:
def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int, int]]:
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (id, buffer_index, dimension)
......@@ -232,7 +233,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map
def _process_buffer_dtype(self) -> Dict[tir.Var, Tuple[int, torch.dtype]]:
def _process_buffer_dtype(self) -> dict[tir.Var, tuple[int, torch.dtype]]:
"""Extract information about buffer dtypes from the TIR function.
Maps buffer variables to their corresponding dtypes.
......@@ -248,7 +249,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
buffer_dtype_map[name] = (i, map_torch_type(dtype))
return buffer_dtype_map
def _process_ptr_map(self) -> Dict[int, str]:
def _process_ptr_map(self) -> dict[int, str]:
"""Extract information about pointer arguments from the TIR function.
Maps pointer arguments to their corresponding (buffer_index, shape_dimension)
......@@ -263,9 +264,9 @@ class CythonKernelAdapter(BaseKernelAdapter):
return ptr_map
def _process_static_buffer_infos(self) -> \
Tuple[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]],
Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]],
List[Tuple[tir.Var]]]:
tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]],
dict[tir.Var, tuple[int, list[tuple[int, int]]]],
list[tuple[tir.Var]]]:
"""Extract information about static shapes from the TIR function.
Maps buffer variables to their corresponding static shapes.
......@@ -300,7 +301,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
static_contiguous_list.append((i, buffer.name))
return static_shape_map, static_strides_map, static_contiguous_list
def _process_buffer_device(self) -> Dict[tir.Var, Tuple[int, torch.device]]:
def _process_buffer_device(self) -> dict[tir.Var, tuple[int, torch.device]]:
"""Extract information about buffer devices from the TIR function.
Maps buffer variables to their corresponding devices.
......@@ -326,7 +327,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
buffer_device_map[name] = (i, device)
return buffer_device_map
def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
"""Low-level function to call the compiled CUDA kernel.
Converts PyTorch tensor pointers to C void pointers for ctypes interface.
......
"""The profiler and convert to torch utils"""
from __future__ import annotations
import torch
from typing import List
from tilelang.contrib.dlpack import to_pytorch_func
from .base import BaseKernelAdapter
......@@ -11,7 +11,7 @@ class TorchDLPackKernelAdapter(BaseKernelAdapter):
def _convert_torch_func(self) -> callable:
torch_func = to_pytorch_func(self.mod)
def func(*ins: List[torch.Tensor]):
def func(*ins: list[torch.Tensor]):
if len(ins) + len(self.result_idx) != len(self.params):
raise ValueError(
f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs"
......
from __future__ import annotations
import ctypes
import importlib
import logging
......@@ -5,7 +6,7 @@ import os
import os.path as osp
import subprocess
import tempfile
from typing import Any, Dict, Optional, List
from typing import Any
from tvm.target import Target
......@@ -29,21 +30,21 @@ except ImportError:
is_nvrtc_available = False
class LibraryGenerator(object):
srcpath: Optional[str] = None
libpath: Optional[str] = None
lib_code: Optional[str] = None
pass_configs: Optional[Dict[str, Any]] = None
compile_flags: Optional[List[str]] = None
class LibraryGenerator:
srcpath: str | None = None
libpath: str | None = None
lib_code: str | None = None
pass_configs: dict[str, Any] | None = None
compile_flags: list[str] | None = None
def __init__(self, target: Target, verbose: bool = False):
self.target = target
self.verbose = verbose
def assign_pass_configs(self, pass_configs: Optional[Dict[str, Any]] = None):
def assign_pass_configs(self, pass_configs: dict[str, Any] | None = None):
self.pass_configs = pass_configs
def assign_compile_flags(self, compile_flags: Optional[List[str]] = None):
def assign_compile_flags(self, compile_flags: list[str] | None = None):
if compile_flags is None:
compile_flags = []
self.compile_flags = compile_flags
......@@ -52,7 +53,7 @@ class LibraryGenerator(object):
self.lib_code = lib_code
# Assume currently we only support CUDA compilation
def load_lib(self, lib_path: Optional[str] = None):
def load_lib(self, lib_path: str | None = None):
if lib_path is None:
lib_path = self.libpath
else:
......@@ -185,7 +186,7 @@ class LibraryGenerator(object):
class PyLibraryGenerator(LibraryGenerator):
host_func: Optional[str] = None
host_func: str | None = None
culib = None
pymodule = None
......@@ -206,7 +207,7 @@ class PyLibraryGenerator(LibraryGenerator):
def update_host_func(self, host_func: str):
self.host_func = host_func
def load_lib(self, lib_path: Optional[str] = None):
def load_lib(self, lib_path: str | None = None):
if lib_path is None:
lib_path = self.libpath
......
from __future__ import annotations
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable
import torch
from tvm import tir
......@@ -26,16 +27,16 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
kernels = {}
def __init__(self,
params: List[KernelParam],
result_idx: List[int],
target: Union[str, Target],
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
host_mod: Optional[tvm.IRModule] = None,
device_mod: Optional[tvm.IRModule] = None,
kernel_global_source: Optional[str] = None,
params: list[KernelParam],
result_idx: list[int],
target: str | Target,
func_or_mod: tir.PrimFunc | tvm.IRModule,
host_mod: tvm.IRModule | None = None,
device_mod: tvm.IRModule | None = None,
kernel_global_source: str | None = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
check_nvrtc_available()
......@@ -91,15 +92,15 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
@classmethod
def from_database(cls,
params: List[KernelParam],
result_idx: List[int],
params: list[KernelParam],
result_idx: list[int],
target: str,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
func_or_mod: tir.PrimFunc | tvm.IRModule,
kernel_global_source: str,
kernel_lib_path: str,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None):
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
......@@ -143,7 +144,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
adapter._post_init()
return adapter
def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]:
def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int]]:
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (buffer_index, shape_dimension)
......@@ -165,7 +166,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
dynamic_symbolic_map[shape] = (i, j)
return dynamic_symbolic_map
def get_kernel_source(self) -> Optional[str]:
def get_kernel_source(self) -> str | None:
"""Get the CUDA kernel source code.
Returns
......@@ -175,14 +176,12 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
"""
return self.kernel_global_source
def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
"""Low-level function to call the compiled CUDA kernel.
"""
return self.pymodule.call(self.kernels, *args, stream=stream)
def _wrap_forward_from_prebuild_lib(self,
*ins: List[torch.Tensor],
stream: Optional[int] = None):
def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None):
"""High-level wrapper for kernel execution.
Handles:
......@@ -242,7 +241,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
else:
return [args[i] for i in self.result_idx]
def _convert_torch_func(self) -> Callable[..., Union[torch.Tensor, List[torch.Tensor]]]:
def _convert_torch_func(self) -> Callable[..., torch.Tensor | list[torch.Tensor]]:
"""Convert to a PyTorch-compatible function.
Returns
......
from __future__ import annotations
from functools import wraps
from typing import Callable, Optional, Union, List
from typing import Callable
import torch
from tvm import tir
......@@ -14,13 +15,13 @@ class MetalKernelAdapter(BaseKernelAdapter):
def __init__(
self,
params: List[KernelParam],
result_idx: List[int],
params: list[KernelParam],
result_idx: list[int],
# target: Union[str, Target],
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
func_or_mod: tir.PrimFunc | tvm.IRModule,
# host_mod: Optional[tvm.IRModule] = None,
device_mod: Optional[tvm.IRModule] = None,
kernel_global_source: Optional[str] = None,
device_mod: tvm.IRModule | None = None,
kernel_global_source: str | None = None,
verbose: bool = False,
# pass_configs: Optional[Dict[str, Any]] = None,
# compile_flags: Optional[List[str]] = None
......
from __future__ import annotations
import re
from typing import Union, Optional, Literal, Dict
from typing import Literal
from tilelang import tvm as tvm
from tvm import IRModule, tir
from tvm.target import Target
......@@ -65,11 +65,11 @@ def is_metal_target(target: Target) -> bool:
def get_annotated_mod(
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
target: Union[str, Target] = "auto",
target_host: Optional[Union[str, Target]] = None,
func_or_mod: tir.PrimFunc | tvm.IRModule,
target: str | Target = "auto",
target_host: str | Target | None = None,
model_type: Literal["device", "host", "all"] = "all",
) -> Union[IRModule, tuple[IRModule, IRModule]]:
) -> IRModule | tuple[IRModule, IRModule]:
# Validate model_type early
if model_type not in {"device", "host", "all"}:
......@@ -107,7 +107,7 @@ def get_annotated_mod(
return dispatch[model_type](mod)
def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: Optional[Dict[str, str]] = None) -> str:
def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None) -> str:
"""
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
......
from __future__ import annotations
from abc import ABC, abstractmethod
from tilelang import tvm as tvm
from typing import Optional, List, Dict, Union, Any
from typing import Any
from tvm import IRModule
from tvm.target import Target
from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target,
......@@ -205,7 +206,7 @@ class BaseWrapper(ABC):
logger = logging.getLogger(__name__)
class TLCUDASourceWrapper(object):
class TLCUDASourceWrapper:
_TYPE_MAP = {
"float32": "float",
"float16": "half_t",
......@@ -225,33 +226,33 @@ class TLCUDASourceWrapper(object):
}
backend = "tl"
device_mod: Optional[IRModule] = None
host_mod: Optional[IRModule] = None
pass_configs: Optional[Dict[str, Any]] = None
device_mod: IRModule | None = None
host_mod: IRModule | None = None
pass_configs: dict[str, Any] | None = None
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: Optional[IRModule] = None,
host_mod: Optional[IRModule] = None,
pass_configs: Optional[Dict[str, Any]] = None):
device_mod: IRModule | None = None,
host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None):
self.mod = scheduled_ir_module
self.target = target
self.source = source
self.pass_configs = pass_configs
self.device_mod = device_mod
self.host_mod = host_mod
self.function_names: Optional[str] = None
self.dynamic_smem_buf: Optional[int] = None
self.block_info: Union[List[int], Dict] = [1, 1, 1]
self.grid_info: Union[List[int], Dict] = [1, 1, 1]
self.tma_descriptor_args: Optional[Dict] = None
self.l2_persistent_map: Optional[Dict[str, Dict]] = {}
self.function_names: str | None = None
self.dynamic_smem_buf: int | None = None
self.block_info: list[int] | dict = [1, 1, 1]
self.grid_info: list[int] | dict = [1, 1, 1]
self.tma_descriptor_args: dict | None = None
self.l2_persistent_map: dict[str, dict] | None = {}
self.parse_source_information()
self.srcpath: Optional[str] = None
self.libpath: Optional[str] = None
self.lib_code: Optional[str] = self.update_lib_code(source)
self.srcpath: str | None = None
self.libpath: str | None = None
self.lib_code: str | None = self.update_lib_code(source)
def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str:
return pythonic_expr(expr, self._TYPE_MAP)
......@@ -293,10 +294,10 @@ class TLCUDASourceWrapper(object):
def func_call_args(s,
function_args,
function_params,
desc_name_map: Optional[Dict[str, str]] = None,
desc_name_var_map: Optional[Dict[str, tvm.tir.Var]] = None):
desc_name_map: dict[str, str] | None = None,
desc_name_var_map: dict[str, tvm.tir.Var] | None = None):
# Extract the function call arguments matching the function definition
def maybe_desc(name: str, matches: List[str], i: int):
def maybe_desc(name: str, matches: list[str], i: int):
match = matches[i]
if not (match == name + "_desc" or match.startswith(name + "_desc_")):
return False
......@@ -334,8 +335,8 @@ class TLCUDASourceWrapper(object):
kernel_launch_code = """"""
if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE
desc_name_map: Dict[str, str] = {}
desc_name_var_map: Dict[str, tvm.tir.Var] = {}
desc_name_map: dict[str, str] = {}
desc_name_var_map: dict[str, tvm.tir.Var] = {}
for function_name, function_info in function_informations.items():
block_info = function_info["block_info"]
grid_info = function_info["grid_info"]
......@@ -351,14 +352,8 @@ class TLCUDASourceWrapper(object):
# Identify the start of the function body to insert arguments
index = code.index("{", index)
block_str = "dim3({}, {}, {})".format(
self._pythonic_expr(block_info[0]),
self._pythonic_expr(block_info[1]),
self._pythonic_expr(block_info[2]),
)
grid_str = "dim3({}, {}, {})".format(
self._pythonic_expr(grid_info[0]), self._pythonic_expr(grid_info[1]),
self._pythonic_expr(grid_info[2]))
block_str = f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})"
grid_str = f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})"
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
kernel_launch_code += init_l2_persistent_map
......@@ -382,9 +377,8 @@ class TLCUDASourceWrapper(object):
args_list
), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
call_args = ", ".join(args_list)
kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
function_name, grid_str, block_str, smem_str, call_args)
kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
kernel_launch_code += f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n"
kernel_launch_code += f"\tTILELANG_CHECK_LAST_ERROR(\"{function_name}\");\n"
if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE
......@@ -415,8 +409,8 @@ class TLCUDASourceWrapper(object):
return init_l2_persistent_map
def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str],
desc_name_var_map: Dict[str, tvm.tir.Var]) -> str:
def generate_tma_descriptor_args(self, desc_name_map: dict[str, str],
desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
tma_descripter_init = ""
if self.tma_descriptor_args is None:
return tma_descripter_init
......@@ -583,7 +577,7 @@ class TLCUDASourceWrapper(object):
def get_dynamic_symbolic_set(self, prim_func):
# Determine the set of dynamic symbols used in the function
dynamic_symbolic_set: List[str] = []
dynamic_symbolic_set: list[str] = []
def unique_push_back(name: str):
if name not in dynamic_symbolic_set:
......@@ -636,7 +630,7 @@ class TLCUDASourceWrapper(object):
assert function_name in self.device_mod, f"Function {function_name} not found in device module"
device_func = self.device_mod[function_name]
kernel_params_cnt = len(device_func.params)
function_params: List[str] = None
function_params: list[str] = None
def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
nonlocal function_params
......@@ -670,7 +664,7 @@ class TLCUDASourceWrapper(object):
lib_code = self.source + init_func + host_func
return lib_code
def get_stream_type(self) -> Dict[str, str]:
def get_stream_type(self) -> dict[str, str]:
return {"name": "stream=cudaStreamDefault", "type": "cudaStream_t"}
@property
......@@ -740,9 +734,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: Optional[IRModule] = None,
host_mod: Optional[IRModule] = None,
pass_configs: Optional[Dict[str, Any]] = None):
device_mod: IRModule | None = None,
host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None):
super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)
def create_dispatch_func(self, code, function_informations):
......@@ -772,9 +766,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
# Format the function arguments for declaration
def_args = ", ".join([f"{arg['name']}" for arg in function_args])
def func_call_args(s, function_args, desc_name_map: Optional[Dict[str, str]] = None):
def func_call_args(s, function_args, desc_name_map: dict[str, str] | None = None):
# Extract the function call arguments matching the function definition
def maybe_desc(name: str, matches: List[str], i: int):
def maybe_desc(name: str, matches: list[str], i: int):
match = matches[i]
if not (match == name + "_desc" or match.startswith(name + "_desc_")):
return False
......@@ -800,7 +794,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
call_args.append((match, "None"))
return call_args
desc_name_map: Dict[str, str] = {}
desc_name_map: dict[str, str] = {}
device_index = 0
kernel_launch_code = """"""
for function_name, function_info in function_informations.items():
......@@ -837,7 +831,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
repr(list(function_informations.keys())), def_args, kernel_launch_code)
return host_func
def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
def generate_tma_descriptor_args(self, desc_name_map: dict[str, str]) -> str:
tma_descripter_init = ""
if self.tma_descriptor_args is None:
return tma_descripter_init
......@@ -915,7 +909,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
self.host_func = self.create_dispatch_func(code, function_informations)
return self.lib_code
def get_stream_type(self) -> Dict[str, str]:
def get_stream_type(self) -> dict[str, str]:
return {"name": "stream=0", "type": "int"}
......@@ -948,9 +942,9 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: Optional[IRModule] = None,
host_mod: Optional[IRModule] = None,
pass_configs: Optional[Dict[str, Any]] = None):
device_mod: IRModule | None = None,
host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None):
super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)
def get_init_func(self):
......@@ -966,11 +960,11 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs
def get_stream_type(self) -> Dict[str, str]:
def get_stream_type(self) -> dict[str, str]:
return {"name": "stream=hipStreamDefault", "type": "hipStream_t"}
class TLCPUSourceWrapper(object):
class TLCPUSourceWrapper:
_TYPE_MAP = {
"float32": "float",
"float16": "half",
......@@ -996,29 +990,29 @@ class TLCPUSourceWrapper(object):
""")
backend = "tl"
device_mod: Optional[IRModule] = None
host_mod: Optional[IRModule] = None
pass_configs: Optional[Dict[str, Any]] = None
device_mod: IRModule | None = None
host_mod: IRModule | None = None
pass_configs: dict[str, Any] | None = None
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: Optional[IRModule] = None,
host_mod: Optional[IRModule] = None,
pass_configs: Optional[Dict[str, Any]] = None):
device_mod: IRModule | None = None,
host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None):
self.mod = scheduled_ir_module
self.target = target
self.source = source
self.device_mod = device_mod
self.host_mod = host_mod
self.pass_configs = pass_configs
self.function_names: Optional[str] = None
self.dynamic_smem_buf: Optional[int] = None
self.function_names: str | None = None
self.dynamic_smem_buf: int | None = None
self.parse_source_information()
self.srcpath: Optional[str] = None
self.libpath: Optional[str] = None
self.lib_code: Optional[str] = self.update_lib_code(source)
self.srcpath: str | None = None
self.libpath: str | None = None
self.lib_code: str | None = self.update_lib_code(source)
def create_call_func(self, code, function_informations):
# Extract the set of dynamic symbolic names used in the primary function
......@@ -1068,7 +1062,7 @@ class TLCPUSourceWrapper(object):
index = code.index("{", index)
call_args = ", ".join(func_call_args(declaration, function_args))
_call_str += "{}({})".format(function_name, call_args)
_call_str += f"{function_name}({call_args})"
# Wrap the kernel dispatch logic in an external C function
host_func = self.CALL_PREFIX.format(def_args, _call_str)
......@@ -1089,7 +1083,7 @@ class TLCPUSourceWrapper(object):
def get_dynamic_symbolic_set(self, prim_func):
# Determine the set of dynamic symbols used in the function
dynamic_symbolic_set: List[str] = []
dynamic_symbolic_set: list[str] = []
for param in prim_func.params:
if param in prim_func.buffer_map:
buffer = prim_func.buffer_map[param]
......@@ -1137,15 +1131,15 @@ class TLCPUSourceWrapper(object):
raise ValueError("Cannot find primary function in the module.")
class TLMetalSourceWrapper(object):
class TLMetalSourceWrapper:
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: Optional[IRModule] = None,
host_mod: Optional[IRModule] = None,
pass_configs: Optional[Dict[str, Any]] = None):
device_mod: IRModule | None = None,
host_mod: IRModule | None = None,
pass_configs: dict[str, Any] | None = None):
self.mod = scheduled_ir_module
self.target = target
self.source = source
......@@ -1163,11 +1157,11 @@ class TLWrapper(BaseWrapper):
"""
A wrapper class for the TileLang backend.
"""
device_mod: Optional[IRModule] = None
host_mod: Optional[IRModule] = None
pass_configs: Optional[Dict[str, Any]] = None
target: Optional[Target] = None
lib: Optional[object] = None
device_mod: IRModule | None = None
host_mod: IRModule | None = None
pass_configs: dict[str, Any] | None = None
target: Target | None = None
lib: object | None = None
def __init__(self, target: Target):
super().__init__()
......@@ -1179,7 +1173,7 @@ class TLWrapper(BaseWrapper):
def assign_optimized_module(self, scheduled_ir_module: IRModule):
self.scheduled_ir_module = scheduled_ir_module
def assign_pass_configs(self, pass_configs: Dict[str, Any]):
def assign_pass_configs(self, pass_configs: dict[str, Any]):
self.pass_configs = pass_configs
def assign_host_module(self, host_mod: IRModule):
......
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from __future__ import annotations
from typing import Any, Callable, Literal
from tilelang.jit.adapter.utils import is_metal_target
from tvm.target import Target
......@@ -17,7 +18,7 @@ import logging
logger = logging.getLogger(__name__)
class JITKernel(object):
class JITKernel:
"""
A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.
......@@ -37,20 +38,20 @@ class JITKernel(object):
# tuner result
latency: float = None
config: Dict[str, Any] = None
config: dict[str, Any] = None
ref_latency: float = None
def __init__(
self,
func: PrimFunc = None,
out_idx: Union[List[int], int] = None,
out_idx: list[int] | int = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
target: str | Target = "auto",
target_host: str | Target = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
pass_configs: dict[str, Any] | None = None,
from_database: bool = False,
compile_flags: Optional[List[str]] = None,
compile_flags: list[str] | None = None,
):
"""
Initializes a TorchFunction instance.
......@@ -134,13 +135,13 @@ class JITKernel(object):
func: PrimFunc,
kernel_global_source: str,
kernel_lib_path: str,
params: List[KernelParam],
target: Union[str, Target],
target_host: Union[str, Target],
out_idx: Union[List[int], int],
params: list[KernelParam],
target: str | Target,
target_host: str | Target,
out_idx: list[int] | int,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"],
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None,
):
"""
Alternative constructor to create a TorchFunction directly from a database.
......@@ -188,7 +189,7 @@ class JITKernel(object):
return self.torch_function(*args, **kwds)
def _compile_and_create_adapter(self, tilelang_func: PrimFunc,
out_idx: List[int]) -> BaseKernelAdapter:
out_idx: list[int]) -> BaseKernelAdapter:
"""
Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter.
......@@ -291,16 +292,15 @@ class JITKernel(object):
return adapter
def _create_adapter_from_database(
self,
params: List[KernelParam],
result_idx: Union[List[int], int],
target: Union[str, Target],
func_or_mod: Union[PrimFunc, tvm.runtime.Module],
kernel_global_source: str,
kernel_lib_path: str,
pass_configs: Optional[Dict[str, Any]] = None,
compile_flags: Optional[List[str]] = None) -> BaseKernelAdapter:
def _create_adapter_from_database(self,
params: list[KernelParam],
result_idx: list[int] | int,
target: str | Target,
func_or_mod: PrimFunc | tvm.runtime.Module,
kernel_global_source: str,
kernel_lib_path: str,
pass_configs: dict[str, Any] | None = None,
compile_flags: list[str] | None = None) -> BaseKernelAdapter:
target = self.target
execution_backend = self.execution_backend
......@@ -401,11 +401,11 @@ class JITKernel(object):
"""
return str(self.artifact.host_mod)
def run_once(self, func: Optional[Callable] = None) -> None:
def run_once(self, func: Callable | None = None) -> None:
return self.get_profiler().run_once(func)
def update_tuner_result(self, latency: float, config: Dict[str, Any],
ref_latency: float) -> "JITKernel":
def update_tuner_result(self, latency: float, config: dict[str, Any],
ref_latency: float) -> JITKernel:
"""
Updates the tuning results for this kernel.
......@@ -428,7 +428,7 @@ class JITKernel(object):
return self
def get_tuner_result(self) -> Dict[str, Any]:
def get_tuner_result(self) -> dict[str, Any]:
"""
Gets the tuning results for this kernel.
......@@ -450,11 +450,11 @@ class JITKernel(object):
}
@property
def out_idx(self) -> List[int]:
def out_idx(self) -> list[int]:
return self.adapter.result_idx
@property
def params(self) -> List[KernelParam]:
def params(self) -> list[KernelParam]:
return self.artifact.params if self.artifact else self.adapter.params
@property
......
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Optional
# from .parser import *
# now is fully compatible with the upstream
# tir script
......@@ -90,6 +90,6 @@ from .annotations import ( # noqa: F401
)
def import_source(source: Optional[str] = None):
def import_source(source: str | None = None):
# source is the source code to be imported
return block_attr({"pragma_import_c": source}) if source is not None else None
"""Annotation helpers exposed on the TileLang language surface."""
from __future__ import annotations
from typing import Callable, Dict
from typing import Callable
from tilelang.layout import Layout
from tvm.script.parser.tir import attr, block_attr
......@@ -21,7 +22,7 @@ def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
return attr(None, "threadblock_swizzle_pattern", f"tl::{device_func}<{panel_size}>")
def annotate_layout(layout_map: Dict):
def annotate_layout(layout_map: dict):
"""Annotate the layout of the buffer."""
_layout_map = {}
for buffer, layout in layout_map.items():
......@@ -35,7 +36,7 @@ def annotate_layout(layout_map: Dict):
return block_attr({"layout_map": _layout_map})
def annotate_safe_value(safe_value_map: Dict):
def annotate_safe_value(safe_value_map: dict):
"""Annotate the safe value of the buffer."""
_safe_value_map = {}
for buffer, safe_value in safe_value_map.items():
......@@ -43,7 +44,7 @@ def annotate_safe_value(safe_value_map: Dict):
return block_attr({"safe_value_map": _safe_value_map})
def annotate_l2_hit_ratio(l2_hit_ratio_map: Dict):
def annotate_l2_hit_ratio(l2_hit_ratio_map: dict):
"""Annotate the L2 hit ratio of the buffer."""
_l2_hit_ratio_map = {}
for buffer, hit_ratio in l2_hit_ratio_map.items():
......
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
"""Atomic operations for tilelang."""
from __future__ import annotations
import tilelang.language as T
from tvm import ir, tir
from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op
from typing import Optional
from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region
from tilelang.utils.language import get_buffer_region_from_load
......@@ -21,7 +21,7 @@ _MEMORY_ORDER_ID_MAP = {
def atomic_max(dst: Buffer,
value: PrimExpr,
memory_order: Optional[str] = None,
memory_order: str | None = None,
return_prev: bool = False) -> PrimExpr:
"""
Perform an atomic maximum on the value stored at dst with an optional memory-order.
......@@ -67,7 +67,7 @@ def atomic_max(dst: Buffer,
def atomic_min(dst: Buffer,
value: PrimExpr,
memory_order: Optional[str] = None,
memory_order: str | None = None,
return_prev: bool = False) -> PrimExpr:
"""
Atomically update the value at dst to the minimum of its current value and value.
......@@ -115,7 +115,7 @@ def atomic_min(dst: Buffer,
def atomic_add(dst: Buffer,
value: PrimExpr,
memory_order: Optional[str] = None,
memory_order: str | None = None,
return_prev: bool = False,
use_tma: bool = False) -> PrimExpr:
"""
......