"docs/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "c1712ee290e4280a4563d95ca60bebc53858b055"
Unverified Commit f14fb111 authored by Yichen Yan, committed by GitHub

[Lint] Enable pyupgrade linter in ruff (#963)

* update rules

* ruff check

* other fixes

* fmt

* do not touch examples

* fmt
parent 4f3523dc
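Most of the hunks below are mechanical rewrites produced by the pyupgrade ("UP") rules. As a hedged illustration (these exact lines are written for this note, not taken from the repository), the commit applies fixes of roughly this shape:

# Illustrative only: the kinds of rewrite pyupgrade performs in this commit.
author = "Tile Lang Contributors"

# printf-style formatting and str.format() become f-strings (UP031/UP032)
old_copyright = "2025-2025, %s" % author
new_copyright = f"2025-2025, {author}"
assert old_copyright == new_copyright

# open(path, "r") drops the redundant mode, since "r" is the default (UP015):
#   with open("../VERSION", "r") as f:  ->  with open("../VERSION") as f:

# class C(object): drops the implicit base class (UP004), and typing.List/Dict/Optional
# annotations become list/dict and "X | None" (UP006/UP007), as the later hunks show.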
-# -*- coding: utf-8 -*-
 # General information about the project.
 project = "Tile Language <br>"
 author = "Tile Lang Contributors"
-copyright = "2025-2025, %s" % author
+copyright = f"2025-2025, {author}"
 # Version information.
-with open("../VERSION", "r") as f:
+with open("../VERSION") as f:
 version = f.read().strip()
 release = version
......
@@ -87,6 +87,17 @@ target-version = "py38"
 line-length = 100
 output-format = "full"
+exclude = [
+"3rdparty",
+"examples/deepseek_v32/inference",
+]
+[tool.ruff.lint.per-file-ignores]
+# Do not upgrade type hint in testing and examples.
+# See https://github.com/tile-ai/tilelang/issues/1079 for more information.
+"testing/**.py" = ["UP", "FA"]
+"examples/**.py" = ["UP", "FA"]
 [tool.ruff.lint]
 select = [
 # pycodestyle
@@ -94,7 +105,7 @@ select = [
 # Pyflakes
 "F",
 # pyupgrade
-# "UP",
+"UP", "FA",
 # flake8-bugbear
 "B",
 # flake8-simplify
@@ -115,6 +126,8 @@ ignore = [
 "SIM108",
 # key in dict.keys()
 "SIM118",
+# open file w.o. ctx manager
+"SIM115",
 # memory leaks
 "B019",
 # zip without explicit strict
@@ -122,9 +135,6 @@ ignore = [
 # No such file or directory
 "E902",
 ]
-[tool.ruff.lint.per-file-ignores]
-"3rdparty/**/*" = ["ALL"]
-"examples/deepseek_v32/inference/**/*" = ["ALL"]
 [tool.pytest.ini_options]
 verbosity_assertions = 3
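A note on the configuration above: target-version stays at "py38", so the modernized annotations that "UP" introduces are only safe because "FA" (flake8-future-annotations) forces each touched module to defer annotation evaluation. A hedged sketch of the distinction, not taken from the repository:

# Why "UP" and "FA" are enabled together while targeting Python 3.8.
from __future__ import annotations  # PEP 563: annotations are stored as strings

def pick(values: list[int] | None = None) -> int | None:
    # The annotation above is fine on 3.8 because it is never evaluated at runtime.
    # Evaluated eagerly, list[int] fails before 3.9 and int | None fails before 3.10.
    return max(values) if values else None

print(pick([1, 2, 3]))  # 3

The per-file-ignores block exempts testing/ and examples/ from "UP" and "FA" so that user-facing sample code keeps the spelling that works without the future import (see the linked issue #1079).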
......
+from __future__ import annotations
 import threading
-from typing import List, Any, Optional
+from typing import Any
 # Use thread local to store the stack
 # This is to avoid the cross-thread interference
@@ -87,7 +88,7 @@ class AutotuneInputsCapture:
 __slots__ = ("tensors")
-def __init__(self, tensors: List[Any]):
+def __init__(self, tensors: list[Any]):
 self.tensors = tensors
 def __enter__(self) -> None:
@@ -118,7 +119,7 @@ def set_autotune_inputs(*args) -> AutotuneInputsCapture:
 return AutotuneInputsCapture(tensors)
-def get_autotune_inputs() -> Optional[List[Any]]:
+def get_autotune_inputs() -> list[Any] | None:
 """
 Get the current autotune inputs from the stack.
 """
......
"""The auto-tune parameters. """The auto-tune parameters.
""" """
from __future__ import annotations
import tilelang import tilelang
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm.tir import PrimFunc from tvm.tir import PrimFunc
from tvm.target import Target from tvm.target import Target
from typing import Callable, List, Literal, Any, Optional, Union, Dict from typing import Callable, Literal, Any
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
...@@ -40,12 +41,12 @@ class CompileArgs: ...@@ -40,12 +41,12 @@ class CompileArgs:
Refer to `tilelang.PassConfigKey` for supported options. Refer to `tilelang.PassConfigKey` for supported options.
""" """
out_idx: Optional[Union[List[int], int]] = None out_idx: list[int] | int | None = None
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython" execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython"
target: Literal['auto', 'cuda', 'hip'] = 'auto' target: Literal['auto', 'cuda', 'hip'] = 'auto'
target_host: Union[str, Target] = None target_host: str | Target = None
verbose: bool = False verbose: bool = False
pass_configs: Optional[Dict[str, Any]] = None pass_configs: dict[str, Any] | None = None
def compile_program(self, program: PrimFunc): def compile_program(self, program: PrimFunc):
return tilelang.compile( return tilelang.compile(
...@@ -135,12 +136,12 @@ class AutotuneResult: ...@@ -135,12 +136,12 @@ class AutotuneResult:
func: Optimized function. func: Optimized function.
kernel: Compiled kernel function. kernel: Compiled kernel function.
""" """
latency: Optional[float] = None latency: float | None = None
config: Optional[dict] = None config: dict | None = None
ref_latency: Optional[float] = None ref_latency: float | None = None
libcode: Optional[str] = None libcode: str | None = None
func: Optional[Callable] = None func: Callable | None = None
kernel: Optional[Callable] = None kernel: Callable | None = None
def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: bool = False): def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: bool = False):
""" """
...@@ -204,9 +205,9 @@ class AutotuneResult: ...@@ -204,9 +205,9 @@ class AutotuneResult:
def _load_kernel_from_disk( def _load_kernel_from_disk(
self, self,
cache_path: Path, cache_path: Path,
target: Union[str, Target] = "auto", target: str | Target = "auto",
target_host: Union[str, Target] = None, target_host: str | Target = None,
out_idx: Optional[Union[List[int], int]] = None, out_idx: list[int] | int | None = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
pass_configs: dict = None, pass_configs: dict = None,
func: Callable = None, func: Callable = None,
...@@ -232,14 +233,14 @@ class AutotuneResult: ...@@ -232,14 +233,14 @@ class AutotuneResult:
if not os.path.exists(cache_path): if not os.path.exists(cache_path):
return None return None
kernel_global_source: Optional[str] = None kernel_global_source: str | None = None
kernel_params: Optional[List[KernelParam]] = None kernel_params: list[KernelParam] | None = None
try: try:
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH) wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
if verbose: if verbose:
logger.debug(f"Loading wrapped kernel source code from file: {wrapped_kernel_path}") logger.debug(f"Loading wrapped kernel source code from file: {wrapped_kernel_path}")
with open(wrapped_kernel_path, "r") as f: with open(wrapped_kernel_path) as f:
kernel_global_source = f.read() kernel_global_source = f.read()
except Exception as e: except Exception as e:
logger.error(f"Error loading wrapped kernel source code from disk: {e}") logger.error(f"Error loading wrapped kernel source code from disk: {e}")
...@@ -300,7 +301,7 @@ class AutotuneResult: ...@@ -300,7 +301,7 @@ class AutotuneResult:
self._save_kernel_to_disk(path, self.kernel) self._save_kernel_to_disk(path, self.kernel)
@classmethod @classmethod
def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> 'AutotuneResult': def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> AutotuneResult:
if not os.path.exists(path): if not os.path.exists(path):
return None return None
...@@ -308,7 +309,7 @@ class AutotuneResult: ...@@ -308,7 +309,7 @@ class AutotuneResult:
# load best config # load best config
if verbose: if verbose:
logger.debug(f"Loading best config from file: {path / BEST_CONFIG_PATH}") logger.debug(f"Loading best config from file: {path / BEST_CONFIG_PATH}")
with open(path / BEST_CONFIG_PATH, "r") as f: with open(path / BEST_CONFIG_PATH) as f:
config = json.load(f) config = json.load(f)
# load function # load function
...@@ -320,7 +321,7 @@ class AutotuneResult: ...@@ -320,7 +321,7 @@ class AutotuneResult:
# load latency # load latency
if verbose: if verbose:
logger.debug(f"Loading latency from file: {path / LATENCY_PATH}") logger.debug(f"Loading latency from file: {path / LATENCY_PATH}")
with open(path / LATENCY_PATH, "r") as f: with open(path / LATENCY_PATH) as f:
latency = json.load(f) latency = json.load(f)
latency, ref_latency = latency["latency"], latency["ref_latency"] latency, ref_latency = latency["latency"], latency["ref_latency"]
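The dataclass fields above are a good example of how far the deferred-annotation approach reaches. A hedged, self-contained sketch (hypothetical names, not tilelang code):

# Under "from __future__ import annotations", @dataclass only records the annotation
# strings, so PEP 604 unions in field declarations work on Python 3.8.
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable

@dataclass
class Result:
    latency: float | None = None
    config: dict | None = None
    func: Callable | None = None

print(Result(latency=1.5))  # Result(latency=1.5, config=None, func=None)
# Caveat: typing.get_type_hints(Result) would still evaluate those strings, so code that
# introspects annotations at runtime keeps the usual Python 3.8 restrictions.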
......
@@ -3,6 +3,7 @@
 This module provides functionality for auto-tuning tilelang programs, including JIT compilation
 and performance optimization through configuration search.
 """
+from __future__ import annotations
 import tilelang
 from tilelang import tvm as tvm
@@ -10,7 +11,7 @@ from tvm.tir import PrimFunc, Var
 from tvm.target import Target
 import inspect
 from functools import partial
-from typing import (Callable, List, Literal, Any, Optional, Union, Dict, overload, Tuple)
+from typing import (Callable, Literal, Any, overload)
 from tqdm import tqdm
 import logging
 import functools
@@ -103,8 +104,8 @@ class AutoTuner:
 compile_args = CompileArgs()
 profile_args = ProfileArgs()
-_kernel_parameters: Optional[Tuple[str, ...]] = None
-_function_parameters: Optional[Dict[str, Any]] = None
+_kernel_parameters: tuple[str, ...] | None = None
+_function_parameters: dict[str, Any] | None = None
 _lock = threading.Lock() # For thread safety
 _memory_cache = {} # In-memory cache dictionary
 cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"
@@ -131,12 +132,12 @@ class AutoTuner:
 return cls(kernel, configs)
 def set_compile_args(self,
-out_idx: Union[List[int], int, None] = None,
+out_idx: list[int] | int | None = None,
 target: Literal['auto', 'cuda', 'hip'] = 'auto',
 execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
-target_host: Union[str, Target] = None,
+target_host: str | Target = None,
 verbose: bool = False,
-pass_configs: Optional[Dict[str, Any]] = None):
+pass_configs: dict[str, Any] | None = None):
 """Set compilation arguments for the auto-tuner.
 Args:
@@ -223,12 +224,12 @@ class AutoTuner:
 return self
-def set_kernel_parameters(self, k_parameters: Tuple[str, ...], f_parameters: Dict[str, Any]):
+def set_kernel_parameters(self, k_parameters: tuple[str, ...], f_parameters: dict[str, Any]):
 # for cache key generation
 self._kernel_parameters = k_parameters
 self._function_parameters = f_parameters
-def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
+def generate_cache_key(self, parameters: dict[str, Any]) -> AutotuneResult | None:
 """Generate a cache key for the auto-tuning process.
 """
@@ -307,8 +308,8 @@ class AutoTuner:
 return result
 best_latency: float = 1e8
-best_config: Optional[Dict[str, Any]] = None
-best_kernel: Optional[tilelang.JITKernel] = None
+best_config: dict[str, Any] | None = None
+best_kernel: tilelang.JITKernel | None = None
 def _compile(**config_arg) -> tilelang.JITKernel:
 compile_args = self.compile_args
@@ -591,7 +592,7 @@ class _AutoTunerImplementation:
 warmup: int = 25
 rep: int = 100
 timeout: int = 100
-configs: Union[Dict, Callable] = None
+configs: dict | Callable = None
 supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
 ref_prog: Callable = None
 supply_prog: Callable = None
@@ -603,7 +604,7 @@ class _AutoTunerImplementation:
 cache_input_tensors: bool = False
 def __init__(self,
-configs: Union[Dict, Callable],
+configs: dict | Callable,
 warmup: int = 25,
 rep: int = 100,
 timeout: int = 100,
@@ -653,12 +654,12 @@ class _AutoTunerImplementation:
 self.cache_input_tensors = cache_input_tensors # Reuse inputs
 # Cache for storing tuned kernel implementations
-self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
+self._tuner_cache: dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
 # This tells the type checker what the *wrapper* function will return.
 # this is for linting, please do not remove it.
 @overload
-def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, AutotuneResult]]:
+def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, AutotuneResult]]:
 ...
 @overload
@@ -720,9 +721,9 @@ class _AutoTunerImplementation:
 def autotune( # This is the new public interface
-func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
+func: Callable[_P, _RProg] | PrimFunc | None = None,
 *, # Indicates subsequent arguments are keyword-only
-configs: Union[Dict, Callable],
+configs: dict | Callable,
 # profile arguments
 warmup: int = 25,
 rep: int = 100,
......
"""The cache utils with class and database persistence - Init file""" """The cache utils with class and database persistence - Init file"""
from __future__ import annotations
from typing import List, Union, Literal, Optional from typing import Literal
from tvm.target import Target from tvm.target import Target
from tvm.tir import PrimFunc from tvm.tir import PrimFunc
from tilelang.jit import JITKernel from tilelang.jit import JITKernel
...@@ -13,14 +14,14 @@ _kernel_cache_instance = KernelCache() ...@@ -13,14 +14,14 @@ _kernel_cache_instance = KernelCache()
def cached( def cached(
func: PrimFunc = None, func: PrimFunc = None,
out_idx: List[int] = None, out_idx: list[int] = None,
*args, *args,
target: Union[str, Target] = "auto", target: str | Target = "auto",
target_host: Union[str, Target] = None, target_host: str | Target = None,
execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] | None = "cython",
verbose: Optional[bool] = False, verbose: bool | None = False,
pass_configs: Optional[dict] = None, pass_configs: dict | None = None,
compile_flags: Optional[Union[List[str], str]] = None, compile_flags: list[str] | str | None = None,
) -> JITKernel: ) -> JITKernel:
""" """
Caches and reuses compiled kernels (using KernelCache class). Caches and reuses compiled kernels (using KernelCache class).
......
"""The cache utils with class and database persistence - KernelCache Class""" """The cache utils with class and database persistence - KernelCache Class"""
from __future__ import annotations
import json import json
import logging import logging
...@@ -7,7 +8,7 @@ import shutil ...@@ -7,7 +8,7 @@ import shutil
import threading import threading
import uuid import uuid
from hashlib import sha256 from hashlib import sha256
from typing import Callable, List, Literal, Optional, Union from typing import Callable, Literal
import cloudpickle import cloudpickle
from tvm.target import Target from tvm.target import Target
...@@ -67,13 +68,13 @@ class KernelCache: ...@@ -67,13 +68,13 @@ class KernelCache:
def _generate_key( def _generate_key(
self, self,
func: Callable, func: Callable,
out_idx: List[int], out_idx: list[int],
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
args=None, args=None,
target: Union[str, Target] = "auto", target: str | Target = "auto",
target_host: Union[str, Target] = None, target_host: str | Target = None,
pass_configs: dict = None, pass_configs: dict = None,
compile_flags: Optional[Union[List[str], str]] = None, compile_flags: list[str] | str | None = None,
) -> str: ) -> str:
""" """
Generates a unique hash key for caching compiled kernels. Generates a unique hash key for caching compiled kernels.
...@@ -112,14 +113,14 @@ class KernelCache: ...@@ -112,14 +113,14 @@ class KernelCache:
def cached( def cached(
self, self,
func: PrimFunc = None, func: PrimFunc = None,
out_idx: List[int] = None, out_idx: list[int] = None,
*args, *args,
target: Union[str, Target] = "auto", target: str | Target = "auto",
target_host: Union[str, Target] = None, target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False, verbose: bool = False,
pass_configs: dict = None, pass_configs: dict = None,
compile_flags: Optional[Union[List[str], str]] = None, compile_flags: list[str] | str | None = None,
) -> JITKernel: ) -> JITKernel:
""" """
Caches and reuses compiled kernels to avoid redundant compilation. Caches and reuses compiled kernels to avoid redundant compilation.
...@@ -322,15 +323,15 @@ class KernelCache: ...@@ -322,15 +323,15 @@ class KernelCache:
def _load_kernel_from_disk( def _load_kernel_from_disk(
self, self,
key: str, key: str,
target: Union[str, Target] = "auto", target: str | Target = "auto",
target_host: Union[str, Target] = None, target_host: str | Target = None,
out_idx: List[int] = None, out_idx: list[int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
pass_configs: dict = None, pass_configs: dict = None,
compile_flags: Optional[Union[List[str], str]] = None, compile_flags: list[str] | str | None = None,
func: Callable = None, func: Callable = None,
verbose: bool = False, verbose: bool = False,
) -> Optional[JITKernel]: ) -> JITKernel | None:
""" """
Loads a previously compiled kernel from disk cache. Loads a previously compiled kernel from disk cache.
...@@ -355,15 +356,15 @@ class KernelCache: ...@@ -355,15 +356,15 @@ class KernelCache:
if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]): if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]):
return None return None
kernel_global_source: Optional[str] = None kernel_global_source: str | None = None
kernel_params: Optional[List[KernelParam]] = None kernel_params: list[KernelParam] | None = None
# Load the kernel source file (optional) # Load the kernel source file (optional)
try: try:
if verbose: if verbose:
self.logger.debug( self.logger.debug(
f"Loading wrapped kernel source code from file: {wrapped_kernel_path}") f"Loading wrapped kernel source code from file: {wrapped_kernel_path}")
with open(wrapped_kernel_path, "r") as f: with open(wrapped_kernel_path) as f:
kernel_global_source = f.read() kernel_global_source = f.read()
except Exception as e: except Exception as e:
self.logger.error(f"Error loading wrapped kernel source code from disk: {e}") self.logger.error(f"Error loading wrapped kernel source code from disk: {e}")
......
"""Analysis on TIR blocks, loops and functions.""" """Analysis on TIR blocks, loops and functions."""
from typing import List, Optional, Set, Union from __future__ import annotations
from typing_extensions import Literal from typing_extensions import Literal
from tvm import ir, tir, DataType from tvm import ir, tir, DataType
...@@ -31,7 +31,7 @@ class IterInfo: ...@@ -31,7 +31,7 @@ class IterInfo:
self.loop_rv = loop_rv self.loop_rv = loop_rv
@property @property
def dom(self) -> Union[int, tir.PrimExpr]: def dom(self) -> int | tir.PrimExpr:
"""The iteration domain of the loop.""" """The iteration domain of the loop."""
return int(self._dom) if isinstance(self._dom, tir.IntImm) else self._dom return int(self._dom) if isinstance(self._dom, tir.IntImm) else self._dom
...@@ -46,14 +46,14 @@ class BlockInfo: ...@@ -46,14 +46,14 @@ class BlockInfo:
"""Information about a TIR block.""" """Information about a TIR block."""
name: str name: str
iters: List[IterInfo] iters: list[IterInfo]
block_rv: tir.schedule.BlockRV block_rv: tir.schedule.BlockRV
_reduction_block: bool _reduction_block: bool
def __init__( def __init__(
self, self,
name: str, name: str,
iters: List[IterInfo], iters: list[IterInfo],
block_rv: tir.schedule.BlockRV, block_rv: tir.schedule.BlockRV,
reduction_block: bool = False, reduction_block: bool = False,
): ):
...@@ -63,7 +63,7 @@ class BlockInfo: ...@@ -63,7 +63,7 @@ class BlockInfo:
self.iters = iters self.iters = iters
self._reduction_block = reduction_block self._reduction_block = reduction_block
def dom(self) -> List[Union[int, tir.PrimExpr]]: def dom(self) -> list[int | tir.PrimExpr]:
"""The iteration domain of the block.""" """The iteration domain of the block."""
return [i.dom for i in self.iters] return [i.dom for i in self.iters]
...@@ -118,7 +118,7 @@ class BlockInfo: ...@@ -118,7 +118,7 @@ class BlockInfo:
_normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc") _normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc")
def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]: def normalize_prim_func(sch: tir.Schedule) -> list[BlockInfo] | None:
"""Normalize the primfunc to normal form""" """Normalize the primfunc to normal form"""
try: try:
result = _normalize_prim_func(sch) result = _normalize_prim_func(sch)
...@@ -133,7 +133,7 @@ def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]: ...@@ -133,7 +133,7 @@ def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]:
tir.IterVar.CommReduce: "R", tir.IterVar.CommReduce: "R",
}.get(i.iter_type, "O") }.get(i.iter_type, "O")
blocks: List[BlockInfo] = [] blocks: list[BlockInfo] = []
for block, loops, iters, is_reduction in zip(*result): for block, loops, iters, is_reduction in zip(*result):
blocks.append( blocks.append(
BlockInfo( BlockInfo(
...@@ -203,7 +203,7 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV: ...@@ -203,7 +203,7 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
def collect_block_iter_vars_used_in_access_region(block: tir.Block, def collect_block_iter_vars_used_in_access_region(block: tir.Block,
region: List[ir.Range]) -> Set[tir.Var]: region: list[ir.Range]) -> set[tir.Var]:
"""Collect the block iter variables used in the access region of a buffer region.""" """Collect the block iter variables used in the access region of a buffer region."""
tir_vars = set() tir_vars = set()
for expr in region: for expr in region:
...@@ -214,7 +214,7 @@ def collect_block_iter_vars_used_in_access_region(block: tir.Block, ...@@ -214,7 +214,7 @@ def collect_block_iter_vars_used_in_access_region(block: tir.Block,
return tir_vars return tir_vars
def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> Set[tir.Var]: def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> set[tir.Var]:
"""Collect the variables used in the PrimExpr.""" """Collect the variables used in the PrimExpr."""
tir_vars = set() tir_vars = set()
...@@ -259,7 +259,7 @@ def is_broadcast_epilogue( ...@@ -259,7 +259,7 @@ def is_broadcast_epilogue(
def get_reduction_blocks(sch: tir.Schedule, def get_reduction_blocks(sch: tir.Schedule,
blocks: List[tir.schedule.BlockRV]) -> List[tir.schedule.BlockRV]: blocks: list[tir.schedule.BlockRV]) -> list[tir.schedule.BlockRV]:
# Get the main computation block # Get the main computation block
def is_reduction(block: BlockRV) -> bool: def is_reduction(block: BlockRV) -> bool:
block_stmt = sch.get(block) block_stmt = sch.get(block)
...@@ -286,7 +286,7 @@ def get_reduction_blocks(sch: tir.Schedule, ...@@ -286,7 +286,7 @@ def get_reduction_blocks(sch: tir.Schedule,
def get_coalesced_veclen(block_stmt: tir.Block, target_bits: int = 128) -> int: def get_coalesced_veclen(block_stmt: tir.Block, target_bits: int = 128) -> int:
# gpu memory prefer 128 bits coalesced access (e.g. four banks) # gpu memory prefer 128 bits coalesced access (e.g. four banks)
# 128 bits # 128 bits
buffers: List[tir.Buffer] = [] buffers: list[tir.Buffer] = []
for read in block_stmt.reads: for read in block_stmt.reads:
buffers.append(read.buffer) buffers.append(read.buffer)
for write in block_stmt.writes: for write in block_stmt.writes:
......
+from __future__ import annotations
 from .arch_base import TileDevice
 from .cuda import *
 from .cpu import *
 from .cdna import *
 from .metal import *
-from typing import Union
 from tvm.target import Target
 import torch
-def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
+def get_arch(target: str | Target = "cuda") -> TileDevice:
 if isinstance(target, str):
 target = Target(target)
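Note that the new union syntax stays confined to annotations: the isinstance check in get_arch is untouched, because passing an X | Y union object to isinstance() requires Python 3.10. A hedged sketch with hypothetical names:

# Annotations may use "str | Target"-style unions under the future import, but runtime
# checks keep the classic form on Python 3.8.
from __future__ import annotations

def normalize(target: str | int = "cuda") -> int:
    if isinstance(target, str):  # isinstance(target, str | int) would fail before 3.10
        return len(target)
    return target

print(normalize(), normalize(80))  # 4 80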
......
-from typing import List
+from __future__ import annotations
 class TileDevice:
@@ -14,12 +14,12 @@ class TileDevice:
 0 # The size of a warp, a group of threads that execute instructions in lockstep
 )
 self.sm_partition: int = 0 # The number of streaming multiprocessor partitions
-self.transaction_size: List[int] = [
+self.transaction_size: list[int] = [
 0,
 0,
 ] # The size of memory transactions, typically in bytes
 self.max_smem_usage: int = 0 # The maximum shared memory usage allowed
-self.bandwidth: List[int] = [
+self.bandwidth: list[int] = [
 0,
 0,
 ] # Bandwidth specifications, possibly including peak and sustained rates
@@ -29,9 +29,9 @@ class TileDevice:
 )
 self.l2_cache_size_bytes: int = 0
 # the number of transaction size in bytes
-self.transaction_size: List[int] = [0, 0] # in bytes
+self.transaction_size: list[int] = [0, 0] # in bytes
 # bandwidth in MB/s, will be used for recommend basic tile size
-self.bandwidth: List[int] = [0, 0]
+self.bandwidth: list[int] = [0, 0]
 def get_avaliable_tensorintrin_shapes(self):
 raise NotImplementedError()
+from __future__ import annotations
 import tvm
 from tvm.target import Target
 from .arch_base import TileDevice
-from typing import List, Union
 def is_cdna_arch(arch: TileDevice) -> bool:
@@ -10,7 +10,7 @@ def is_cdna_arch(arch: TileDevice) -> bool:
 class CDNA(TileDevice):
-def __init__(self, target: Union[Target, str]):
+def __init__(self, target: Target | str):
 if isinstance(target, str):
 target = tvm.target.Target(target)
 self.target = target
@@ -27,9 +27,9 @@ class CDNA(TileDevice):
 self.max_smem_usage: int = 2 * self.smem_cap
 self.sm_partition: int = 4
 self.l2_cache_size_bytes: int = target.l2_cache_size_bytes
-self.transaction_size: List[int] = [32, 128] # in bytes
-self.bandwidth: List[int] = [1300, 14000]
+self.transaction_size: list[int] = [32, 128] # in bytes
+self.bandwidth: list[int] = [1300, 14000]
 __all__ = [
......
+from __future__ import annotations
 import tvm
 from tvm.target import Target
 from .arch_base import TileDevice
-from typing import List, Union
 from .driver import cuda_driver
@@ -91,21 +91,21 @@ def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: Til
 raise ValueError(f"Unsupported architecture: {arch}")
-class TensorInstruction(object):
+class TensorInstruction:
 def __init__(
 self,
 name: str,
-shape: List[int],
+shape: list[int],
 ):
 self.name: str = name
 # only hold the shape of M and N
-self.shape: List[int] = shape
+self.shape: list[int] = shape
 class CUDA(TileDevice):
-def __init__(self, target: Union[Target, str]):
+def __init__(self, target: Target | str):
 if isinstance(target, str):
 target = tvm.target.Target(target)
 self.target = target
@@ -126,15 +126,15 @@ class CUDA(TileDevice):
 self.sm_partition: int = 4
 self.l2_cache_size_bytes: int = target.l2_cache_size_bytes
 # the number of transaction size in bytes
-self.transaction_size: List[int] = [32, 128] # in bytes
+self.transaction_size: list[int] = [32, 128] # in bytes
 # bandwidth in MB/s, will be used for recommend basic tile size
 # TODO(lei): find some way to get the real bandwidth
 # However, the ratio of bandwidth between different devices can
 # be similar. The bandwidth can work for another devices as well.
-self.bandwidth: List[int] = [750, 12080]
+self.bandwidth: list[int] = [750, 12080]
 # get the available tensor instructions during runtime to avoid
 # the dependency of the tensor intrinsics registration
-self.available_tensor_instructions: List[TensorInstruction] = None
+self.available_tensor_instructions: list[TensorInstruction] = None
 def get_avaliable_tensorintrin_shapes(self):
 self.available_tensor_instructions = (
......
+from __future__ import annotations
 import ctypes
 import sys
-from typing import Optional
 class cudaDeviceProp(ctypes.Structure):
@@ -77,7 +77,7 @@ class cudaDeviceProp(ctypes.Structure):
 ]
-def get_cuda_device_properties(device_id: int = 0) -> Optional[cudaDeviceProp]:
+def get_cuda_device_properties(device_id: int = 0) -> cudaDeviceProp | None:
 if sys.platform == "win32":
 libcudart = ctypes.windll.LoadLibrary("cudart64_110.dll")
@@ -95,7 +95,7 @@ def get_cuda_device_properties(device_id: int = 0) -> Optional[cudaDeviceProp]:
 raise RuntimeError(f"cudaGetDeviceProperties failed with error {ret}")
-def get_device_name(device_id: int = 0) -> Optional[str]:
+def get_device_name(device_id: int = 0) -> str | None:
 prop = get_cuda_device_properties(device_id)
 if prop:
 return prop.name.decode()
@@ -103,7 +103,7 @@ def get_device_name(device_id: int = 0) -> Optional[str]:
 raise RuntimeError("Failed to get device properties.")
-def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> Optional[int]:
+def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> int | None:
 assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb"
 prop = get_cuda_device_properties(device_id)
 if prop:
@@ -143,7 +143,7 @@ def get_device_attribute(attr: int, device_id: int = 0) -> int:
 return None
-def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") -> Optional[int]:
+def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") -> int | None:
 """
 Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes.
 """
......
+from __future__ import annotations
 from tvm.target import Target
 from .arch_base import TileDevice
......
@@ -19,7 +19,8 @@
 # Modifications Copyright (c) Microsoft.
 # The code below is mostly copied from apache/tvm common_schedules.py in dlight.
 """Common schedule strategies for TIR."""
-from typing import Callable, List
+from __future__ import annotations
+from typing import Callable
 from tvm import tir
 from .utils import retrieve_func_from_module
@@ -28,7 +29,7 @@ from .analysis import BlockInfo
 def get_block(
 sch: tir.Schedule,
-blocks: List[BlockInfo],
+blocks: list[BlockInfo],
 name: str,
 ):
 """Get the target block from a schedule.
@@ -56,7 +57,7 @@ def get_block(
 def get_output_blocks(
 sch: tir.Schedule,
-blocks: List[BlockInfo],
+blocks: list[BlockInfo],
 ):
 """Get the output blocks of a schedule.
@@ -89,8 +90,8 @@ def get_output_blocks(
 def try_inline(
 sch: tir.Schedule,
-blocks: List[BlockInfo],
-) -> List[BlockInfo]:
+blocks: list[BlockInfo],
+) -> list[BlockInfo]:
 """Try to inline as many blocks as possible, and return the remaining blocks.
 Parameters
@@ -127,8 +128,8 @@ def try_inline(
 def try_inline_contiguous_spatial(
 sch: tir.Schedule,
-block_infos: List[BlockInfo],
-) -> List[BlockInfo]:
+block_infos: list[BlockInfo],
+) -> list[BlockInfo]:
 """Try to inline contiguous spatial blocks in a schedule
 Parameters
......
 # pylint: disable=missing-docstring, invalid-name
 """A GEMM schedule rule for GPU operators."""
+from __future__ import annotations
 from dataclasses import dataclass
 from enum import Enum
-from typing import List, Optional, Set, Union, Tuple, Dict
 from tvm import tir
 from tvm.ir import Range
 from tvm.tir import IterVar, PrimExpr, Var, BufferRegion, IndexMap
@@ -57,7 +57,7 @@ def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV):
 def auto_inline_producers(
 sch: tir.Schedule,
 block: tir.schedule.BlockRV,
-skip_blocks: Optional[List[tir.schedule.BlockRV]] = None,
+skip_blocks: list[tir.schedule.BlockRV] | None = None,
 ):
 skip_blocks = skip_blocks or []
 while True:
@@ -118,7 +118,7 @@ def auto_inline_consumer_chain(
 # used to match the similar region with dequantize op.
-def find_first_similar_region(regions: List[BufferRegion], buffer: tir.Buffer):
+def find_first_similar_region(regions: list[BufferRegion], buffer: tir.Buffer):
 for region in regions:
 if len(region.buffer.shape) == len(buffer.shape):
 return region
@@ -126,7 +126,7 @@ def find_first_similar_region(regions: List[BufferRegion], buffer: tir.Buffer):
 # used to match the similar buffer with dequantize op.
-def find_first_similar_buffer(regions: List[BufferRegion], buffer: tir.Buffer):
+def find_first_similar_buffer(regions: list[BufferRegion], buffer: tir.Buffer):
 for region in regions:
 if len(region.buffer.shape) == len(buffer.shape):
 return region.buffer
@@ -134,7 +134,7 @@ def find_first_similar_buffer(regions: List[BufferRegion], buffer: tir.Buffer):
 # find the block that required to be reindex and scope.
-def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Optional[BlockRV]:
+def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> BlockRV | None:
 # block that most near to the arguments
 block = main_block
 buffer = buffer
@@ -209,11 +209,11 @@ class IterTrait:
 def make_iter_fusion_index_map(
-traits: List[IterTrait],
-kind_order: List[IterKind],
+traits: list[IterTrait],
+kind_order: list[IterKind],
 ) -> tir.IndexMap:
-fused_iters: Dict[IterKind, PrimExpr] = {}
-input_iters: List[tir.Var] = []
+fused_iters: dict[IterKind, PrimExpr] = {}
+input_iters: list[tir.Var] = []
 for i, trait in enumerate(traits):
 v_i = tir.Var(f"i{i}", trait.extent.dtype)
 input_iters.append(v_i)
@@ -226,14 +226,14 @@ def make_iter_fusion_index_map(
 else:
 fused_iters[trait.kind] = v_i
-final_indices: List[tir.PrimExpr] = [
+final_indices: list[tir.PrimExpr] = [
 fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order
 ]
 return tir.IndexMap(input_iters, final_indices, None)
-def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
+def detect_iter_traits(block: tir.Block) -> tuple[list[IterTrait]] | None:
 """Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K]
 Parameters
@@ -252,8 +252,8 @@ def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
 if len(block.reads) != 2 or len(block.writes) != 1:
 return None
-def get_access_axes(region: List[Range]) -> Set[Var]:
-axes: Set[Var] = set()
+def get_access_axes(region: list[Range]) -> set[Var]:
+axes: set[Var] = set()
 for r in region:
 if not _is_one(r.extent):
 raise ValueError("Expect elemwise block access")
@@ -267,7 +267,7 @@ def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
 except ValueError:
 return None
-traits: Dict[Var, IterTrait] = {}
+traits: dict[Var, IterTrait] = {}
 for iter_var in block.iter_vars:
 var = iter_var.var
 kind: IterKind
@@ -308,7 +308,7 @@ def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
 def get_index_map(block: tir.Block,
-layout: Optional[List[str]] = None) -> Optional[Tuple[tir.IndexMap, ...]]:
+layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] | None:
 """Get index maps for the block
 Parameters
@@ -334,8 +334,8 @@ def get_index_map(block: tir.Block,
 return None
 A_traits, B_traits, C_traits, block_traits = traits
-def get_ordered_axes(region: List[Range]) -> Set[Var]:
-axes: List[Var] = []
+def get_ordered_axes(region: list[Range]) -> set[Var]:
+axes: list[Var] = []
 for r in region:
 if not _is_one(r.extent):
 raise ValueError("Expect elemwise block access")
@@ -352,11 +352,11 @@ def get_index_map(block: tir.Block,
 vars = collect_vars_from_expr(var)
 return any(is_common_reduce(v) for v in vars)
-def check_last_trait(region: List[Range]):
+def check_last_trait(region: list[Range]):
 axes = get_ordered_axes(region)
 return has_common_reduce(axes[-1])
-def infer_layout(layout: str, region: List[Range], kind: str = "A"):
+def infer_layout(layout: str, region: list[Range], kind: str = "A"):
 """
 Infer the layout based on the region and the kind of buffer
 kind: "A", "B", "C"
@@ -409,7 +409,7 @@ def get_index_map(block: tir.Block,
 )
-def get_in_out_dtypes(block: tir.Block) -> Tuple[str]:
+def get_in_out_dtypes(block: tir.Block) -> tuple[str]:
 """
 Detect In/Out data types for the given block based on the analysis if read/write buffers.
 """
@@ -419,7 +419,7 @@ def get_in_out_dtypes(block: tir.Block) -> Tuple[str]:
 return (in_dtype, out_dtype)
-def get_dequantize_block(sch, blocks) -> Optional[BlockRV]:
+def get_dequantize_block(sch, blocks) -> BlockRV | None:
 # check at least two input and one output
 # at lease one input has uint dtype, and the output dtype is float
 def is_dequantize(block: BlockRV) -> bool:
@@ -445,8 +445,8 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool:
 if not isinstance(block_stmt.body.value, tir.BufferLoad):
 return False, False
-def get_access_vars(region: List[Range]) -> List[Var]:
-axes: List[Var] = []
+def get_access_vars(region: list[Range]) -> list[Var]:
+axes: list[Var] = []
 for r in region:
 if not _is_one(r.extent):
 return None
@@ -475,7 +475,7 @@ def is_transpose_block(block_stmt: tir.Block) -> bool:
 return is_identity_or_transpose_block(block_stmt)[1]
-def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]):
+def inline_transpose_block(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]):
 result_blocks = []
 for block in blocks:
 if not is_transpose_block(sch.get(block)):
@@ -493,7 +493,7 @@ def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]
 def normalize_to_matmul(sch: tir.Schedule,
 main_block: BlockRV,
-layout: Optional[List[str]] = None) -> Optional[tir.Schedule]:
+layout: list[str] | None = None) -> tir.Schedule | None:
 if layout is None:
 layout = ["n", "t", "n"]
 block_stmt = sch.get(main_block)
@@ -521,10 +521,10 @@ def normalize_to_matmul(sch: tir.Schedule,
 def get_tensorized_func_and_tags(
 func: tir.PrimFunc,
 target: Target,
-layout: Optional[List[str]] = None,
+layout: list[str] | None = None,
 skip_normalize: bool = False,
 allow_gemv: bool = False,
-) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]:
+) -> tuple[tir.PrimFunc, dict[str, list[int] | int]]:
 """
 transform function to matmul if necessary (e.g. transform conv2d with im2col)
 """
@@ -554,9 +554,8 @@ def get_tensorized_func_and_tags(
 sm_version = arch.replace("sm_", "")
 return int(sm_version) if sm_version.isdigit() else -1
-def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV,
-target: Target) -> Union[bool, Dict]:
-tags: Dict[str, Union[List[int], int]] = {}
+def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool | dict:
+tags: dict[str, list[int] | int] = {}
 block_stmt = sch.get(block)
 # Nvidia Only Support Tensor Core for
@@ -584,8 +583,8 @@ def get_tensorized_func_and_tags(
 tags["use_async_copy"] = True
 # analysis intrin information
-def get_ordered_axes(region: List[Range]) -> Set[Var]:
-axes: List[Var] = []
+def get_ordered_axes(region: list[Range]) -> set[Var]:
+axes: list[Var] = []
 for r in region:
 if not _is_one(r.extent):
 raise ValueError("Expect elemwise block access")
@@ -602,7 +601,7 @@ def get_tensorized_func_and_tags(
 vars = collect_vars_from_expr(var)
 return any(is_common_reduce(v) for v in vars)
-def check_last_trait(region: List[Range]):
+def check_last_trait(region: list[Range]):
 axes = get_ordered_axes(region)
 return has_common_reduce(axes[-1])
......
@@ -17,7 +17,7 @@ class Block:
 self.end = max(self.end, other.end)
 def __repr__(self) -> str:
-return "<Block offset={} size={}>".format(self.start, self.size())
+return f"<Block offset={self.start} size={self.size()}>"
 class BestFit:
......
"""Hint definition for schedule""" """Hint definition for schedule"""
from __future__ import annotations
from tvm import DataType from tvm import DataType
from typing import Dict, List, Tuple
from . import PrimFuncNode from . import PrimFuncNode
import numpy as np import numpy as np
from .rasterization import * from .rasterization import *
...@@ -13,17 +13,17 @@ class TensorCoreExtraConfig: ...@@ -13,17 +13,17 @@ class TensorCoreExtraConfig:
def __init__( def __init__(
self, self,
AS_shape: Tuple[int], AS_shape: tuple[int],
BS_shape: Tuple[int], BS_shape: tuple[int],
AF_shape: Tuple[int], AF_shape: tuple[int],
BF_shape: Tuple[int], BF_shape: tuple[int],
tc_axis: Tuple[int], tc_axis: tuple[int],
) -> None: ) -> None:
self.AS_shape: Tuple[int] = AS_shape self.AS_shape: tuple[int] = AS_shape
self.BS_shape: Tuple[int] = BS_shape self.BS_shape: tuple[int] = BS_shape
self.AF_shape: Tuple[int] = AF_shape self.AF_shape: tuple[int] = AF_shape
self.BF_shape: Tuple[int] = BF_shape self.BF_shape: tuple[int] = BF_shape
self.tc_axis: Tuple[int] = tc_axis self.tc_axis: tuple[int] = tc_axis
class Stride: class Stride:
...@@ -45,7 +45,7 @@ class Stride: ...@@ -45,7 +45,7 @@ class Stride:
def stride(self) -> int: def stride(self) -> int:
return self._stride return self._stride
def compute_strides_from_shape(self, shape: List[int]) -> List[int]: def compute_strides_from_shape(self, shape: list[int]) -> list[int]:
ndim = len(shape) ndim = len(shape)
strides = [1 for _ in shape] strides = [1 for _ in shape]
for i in range(ndim - 2, -1, -1): for i in range(ndim - 2, -1, -1):
...@@ -55,7 +55,7 @@ class Stride: ...@@ -55,7 +55,7 @@ class Stride:
strides[i] = int(strides[i + 1] * shape[i + 1]) strides[i] = int(strides[i + 1] * shape[i + 1])
return strides return strides
def compute_elements_from_shape(self, shape: List[int]) -> int: def compute_elements_from_shape(self, shape: list[int]) -> int:
original_shape = np.prod(shape) original_shape = np.prod(shape)
if not self.is_valid(): if not self.is_valid():
strided_elem = original_shape strided_elem = original_shape
...@@ -94,10 +94,10 @@ class TileDict: ...@@ -94,10 +94,10 @@ class TileDict:
self.grid_size = -1 self.grid_size = -1
self.valid = True self.valid = True
def get_tile(self, func) -> List[int]: def get_tile(self, func) -> list[int]:
return self.tile_map[func] return self.tile_map[func]
def get_rstep(self, node) -> Dict[str, int]: def get_rstep(self, node) -> dict[str, int]:
return self.rstep_map[node] return self.rstep_map[node]
def __hash__(self) -> int: def __hash__(self) -> int:
...@@ -147,7 +147,7 @@ class IntrinInfo: ...@@ -147,7 +147,7 @@ class IntrinInfo:
return self.weight_transform_kind >= 1 return self.weight_transform_kind >= 1
class Hint(object): class Hint:
""" """
Central configuration class for managing various parameters of computational tasks. Central configuration class for managing various parameters of computational tasks.
""" """
...@@ -178,15 +178,15 @@ class Hint(object): ...@@ -178,15 +178,15 @@ class Hint(object):
# Experimental # Experimental
self._raxis_order = [] self._raxis_order = []
self._step = [] self._step = []
self.vectorize: Dict[str, int] = {} self.vectorize: dict[str, int] = {}
self.pipeline_stage = 1 self.pipeline_stage = 1
self.use_async = False self.use_async = False
self.opt_shapes: Dict[str, int] = {} self.opt_shapes: dict[str, int] = {}
self.intrin_info = IntrinInfo("float16", "float16", True) self.intrin_info = IntrinInfo("float16", "float16", True)
self.shared_scope: str = "shared" self.shared_scope: str = "shared"
self.pass_context: Dict = {} self.pass_context: dict = {}
def to_dict(self) -> Dict: def to_dict(self) -> dict:
dic = {} dic = {}
dic["block"] = self.block dic["block"] = self.block
if self.use_tc: if self.use_tc:
...@@ -218,7 +218,7 @@ class Hint(object): ...@@ -218,7 +218,7 @@ class Hint(object):
return dic return dic
@classmethod @classmethod
def from_dict(cls, dic: Dict) -> "Hint": def from_dict(cls, dic: dict) -> Hint:
hint = cls() hint = cls()
for k, v in dic.items(): for k, v in dic.items():
setattr(hint, k, v) setattr(hint, k, v)
...@@ -231,13 +231,13 @@ class Hint(object): ...@@ -231,13 +231,13 @@ class Hint(object):
return self return self
@property @property
def raxis_order(self) -> List[int]: def raxis_order(self) -> list[int]:
if self._raxis_order != []: if self._raxis_order != []:
return self._raxis_order return self._raxis_order
return list(range(len(self.rstep))) return list(range(len(self.rstep)))
@property @property
def step(self) -> List[int]: def step(self) -> list[int]:
if self._step != []: if self._step != []:
return self._step return self._step
return [1 for _ in self.block] return [1 for _ in self.block]
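One more pattern worth calling out from this file: from_dict previously returned the quoted forward reference "Hint"; with deferred annotation evaluation a method can name its own class directly. A hedged sketch with hypothetical names:

# With "from __future__ import annotations", a classmethod can annotate its return type
# with the enclosing class without quoting it.
from __future__ import annotations

class Config:
    def __init__(self) -> None:
        self.items: dict = {}

    @classmethod
    def from_dict(cls, dic: dict) -> Config:  # no string literal needed
        cfg = cls()
        cfg.items.update(dic)
        return cfg

print(Config.from_dict({"block": [64, 64]}).items)  # {'block': [64, 64]}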
......
"""PrimFunc Wrapper and Block information Analaysis""" """PrimFunc Wrapper and Block information Analaysis"""
from __future__ import annotations
import tvm import tvm
from tvm import tir from tvm import tir
from tvm.tir import IterVar, PrimFunc from tvm.tir import IterVar, PrimFunc
from typing import Any, Dict, List, Tuple, Optional from typing import Any
from tvm.tir.schedule.schedule import BlockRV from tvm.tir.schedule.schedule import BlockRV
import numpy as np import numpy as np
import functools import functools
...@@ -29,11 +30,11 @@ def pre_order_traverse(block_analyzer, blocks, func): ...@@ -29,11 +30,11 @@ def pre_order_traverse(block_analyzer, blocks, func):
_traverse(block) _traverse(block)
class BlockAnalyzer(object): class BlockAnalyzer:
def __init__(self, sch) -> None: def __init__(self, sch) -> None:
self.sch: tir.Schedule = sch self.sch: tir.Schedule = sch
self.block_infos: List[BlockInfo] = normalize_prim_func(self.sch) self.block_infos: list[BlockInfo] = normalize_prim_func(self.sch)
def get_block_name(self, block: BlockRV) -> str: def get_block_name(self, block: BlockRV) -> str:
return self.sch.get(block).name_hint return self.sch.get(block).name_hint
...@@ -44,7 +45,7 @@ class BlockAnalyzer(object): ...@@ -44,7 +45,7 @@ class BlockAnalyzer(object):
return block_info return block_info
return None return None
def get_spatial_axis(self, block: BlockRV) -> List[IterVar]: def get_spatial_axis(self, block: BlockRV) -> list[IterVar]:
block_info = self.get_block_info(block) block_info = self.get_block_info(block)
axis = [] axis = []
for iter in block_info.iters: for iter in block_info.iters:
...@@ -52,7 +53,7 @@ class BlockAnalyzer(object): ...@@ -52,7 +53,7 @@ class BlockAnalyzer(object):
axis.append(iter) axis.append(iter)
return axis return axis
def get_reduce_axis(self, block: BlockRV) -> List[IterVar]: def get_reduce_axis(self, block: BlockRV) -> list[IterVar]:
block_info = self.get_block_info(block) block_info = self.get_block_info(block)
raxis = [] raxis = []
for iter in block_info.iters: for iter in block_info.iters:
...@@ -60,39 +61,39 @@ class BlockAnalyzer(object): ...@@ -60,39 +61,39 @@ class BlockAnalyzer(object):
raxis.append(iter) raxis.append(iter)
return raxis return raxis
def get_input_buffers(self, block: BlockRV) -> List[tir.Buffer]: def get_input_buffers(self, block: BlockRV) -> list[tir.Buffer]:
buffers = [] buffers = []
for read in self.sch.get(block).reads: for read in self.sch.get(block).reads:
buffers.append(read.buffer) buffers.append(read.buffer)
return buffers return buffers
def get_output_buffers(self, block: BlockRV) -> List[tir.Buffer]: def get_output_buffers(self, block: BlockRV) -> list[tir.Buffer]:
buffers = [] buffers = []
for write in self.sch.get(block).writes: for write in self.sch.get(block).writes:
buffers.append(write.buffer) buffers.append(write.buffer)
return buffers return buffers
def get_buffers(self, block: BlockRV) -> List[tir.Buffer]: def get_buffers(self, block: BlockRV) -> list[tir.Buffer]:
return self.get_input_buffers(block) + self.get_output_buffers(block) return self.get_input_buffers(block) + self.get_output_buffers(block)
def get_producer_blocks(self, block: BlockRV) -> List[BlockRV]: def get_producer_blocks(self, block: BlockRV) -> list[BlockRV]:
return self.sch.get_producers(block) return self.sch.get_producers(block)
def get_consumer_blocks(self, block: BlockRV) -> List[BlockRV]: def get_consumer_blocks(self, block: BlockRV) -> list[BlockRV]:
return self.sch.get_consumers(block) return self.sch.get_consumers(block)
@dataclass @dataclass
class Edge: class Edge:
src_node: 'Node' src_node: Node
dst_node: 'Node' dst_node: Node
src_id: int src_id: int
dst_id: int dst_id: int
class Node(object): class Node:
def __init__(self, tags: Optional[Dict] = None, name: str = "Node") -> None: def __init__(self, tags: dict | None = None, name: str = "Node") -> None:
self.name = name self.name = name
if tags is None: if tags is None:
tags = {} tags = {}
...@@ -100,10 +101,10 @@ class Node(object): ...@@ -100,10 +101,10 @@ class Node(object):
self._in_edges = [] self._in_edges = []
self._shapes = [] self._shapes = []
self._dtypes = [] self._dtypes = []
self._tag: Dict = {} self._tag: dict = {}
self.update_tags(tags) self.update_tags(tags)
def update_tags(self, tags: Dict) -> None: def update_tags(self, tags: dict) -> None:
for tag in tags: for tag in tags:
self.add_tag(tag, tags[tag]) self.add_tag(tag, tags[tag])
...@@ -125,11 +126,11 @@ class Node(object): ...@@ -125,11 +126,11 @@ class Node(object):
return False return False
@property @property
def inputs(self) -> List[Edge]: def inputs(self) -> list[Edge]:
return self._in_edges return self._in_edges
@property @property
def outputs(self) -> List[Edge]: def outputs(self) -> list[Edge]:
return self._out_edges return self._out_edges
def set_inputs(self, i: int, edge: Edge): def set_inputs(self, i: int, edge: Edge):
...@@ -153,10 +154,10 @@ class Node(object): ...@@ -153,10 +154,10 @@ class Node(object):
assert self._dtypes[id] == dtype, (self._dtypes, dtype) assert self._dtypes[id] == dtype, (self._dtypes, dtype)
self._dtypes[id] = dtype self._dtypes[id] = dtype
def get_shape(self, id: int = 0) -> List[int]: def get_shape(self, id: int = 0) -> list[int]:
return self._shapes[id] return self._shapes[id]
def set_shape(self, shape: List[int], id=0, overwrite=False) -> None: def set_shape(self, shape: list[int], id=0, overwrite=False) -> None:
if len(self._shapes) <= id: if len(self._shapes) <= id:
self._shapes.extend([None for _ in range(id - len(self._shapes) + 1)]) self._shapes.extend([None for _ in range(id - len(self._shapes) + 1)])
# elif self._shapes[id] is not None and not overwrite: # elif self._shapes[id] is not None and not overwrite:
...@@ -191,15 +192,15 @@ class PrimFuncNode(Node): ...@@ -191,15 +192,15 @@ class PrimFuncNode(Node):
def __init__(self, def __init__(self,
prim_func: PrimFunc, prim_func: PrimFunc,
tags: Optional[Dict] = None, tags: dict | None = None,
name: str = "PrimFuncNode") -> None: name: str = "PrimFuncNode") -> None:
super().__init__(tags, name=name) super().__init__(tags, name=name)
self.prim_func = self._specialize_func(prim_func) self.prim_func = self._specialize_func(prim_func)
self.sch: tir.Schedule = tir.Schedule(self.prim_func) self.sch: tir.Schedule = tir.Schedule(self.prim_func)
self.block_analyzer: BlockAnalyzer = BlockAnalyzer(self.sch) self.block_analyzer: BlockAnalyzer = BlockAnalyzer(self.sch)
self.schedule_stages: List[BlockRV] = [] self.schedule_stages: list[BlockRV] = []
self.blocks: List[BlockRV] = [] self.blocks: list[BlockRV] = []
self.output_blocks: List[BlockRV] = None self.output_blocks: list[BlockRV] = None
self.reduction_block: BlockRV = None self.reduction_block: BlockRV = None
self.raxis = [] self.raxis = []
self.input_buffers = [] self.input_buffers = []
...@@ -219,7 +220,7 @@ class PrimFuncNode(Node): ...@@ -219,7 +220,7 @@ class PrimFuncNode(Node):
self.set_dtype(tvm.DataType(buffer.dtype), output_id) self.set_dtype(tvm.DataType(buffer.dtype), output_id)
def _assign_placeholder_node(self): def _assign_placeholder_node(self):
inputs: List[Node] = [] inputs: list[Node] = []
for buffer in self.input_buffers: for buffer in self.input_buffers:
inputs.append(PlaceHolderNode(buffer.name)) inputs.append(PlaceHolderNode(buffer.name))
...@@ -301,8 +302,8 @@ class PrimFuncNode(Node): ...@@ -301,8 +302,8 @@ class PrimFuncNode(Node):
else: else:
return value return value
@functools.lru_cache() @functools.lru_cache
def get_space_dim(self) -> List[int]: def get_space_dim(self) -> list[int]:
dim_size = [] dim_size = []
if self.reduction_block: if self.reduction_block:
block_info = self.block_analyzer.get_block_info(self.reduction_block) block_info = self.block_analyzer.get_block_info(self.reduction_block)
...@@ -333,7 +334,7 @@ class PrimFuncNode(Node): ...@@ -333,7 +334,7 @@ class PrimFuncNode(Node):
def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType: def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType:
return tvm.DataType(buffer.dtype) return tvm.DataType(buffer.dtype)
def propagate(self, tile, rstep: Optional[Dict] = None, targets=None): def propagate(self, tile, rstep: dict | None = None, targets=None):
if rstep is None: if rstep is None:
rstep = {} rstep = {}
shape = { shape = {
...@@ -343,7 +344,7 @@ class PrimFuncNode(Node): ...@@ -343,7 +344,7 @@ class PrimFuncNode(Node):
} }
return self.ana.infer(shape, rstep, targets) return self.ana.infer(shape, rstep, targets)
def propagate_inputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]: def propagate_inputs(self, tile, rstep: dict | None = None) -> list[list[int]]:
if rstep is None: if rstep is None:
rstep = {} rstep = {}
read_idx_offset = len(self.input_buffers) read_idx_offset = len(self.input_buffers)
...@@ -363,7 +364,7 @@ class PrimFuncNode(Node): ...@@ -363,7 +364,7 @@ class PrimFuncNode(Node):
return results return results
# Propagate inputs only on reduction block # Propagate inputs only on reduction block
def propagate_inputs_on_reduction(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]: def propagate_inputs_on_reduction(self, tile, rstep: dict | None = None) -> list[list[int]]:
if rstep is None: if rstep is None:
rstep = {} rstep = {}
reduction_block = self.reduction_block reduction_block = self.reduction_block
...@@ -386,7 +387,7 @@ class PrimFuncNode(Node): ...@@ -386,7 +387,7 @@ class PrimFuncNode(Node):
results.append(trimmed_shape) results.append(trimmed_shape)
return results return results
def propagate_outputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]: def propagate_outputs(self, tile, rstep: dict | None = None) -> list[list[int]]:
if rstep is None: if rstep is None:
rstep = {} rstep = {}
read_idx_offset = len(self.input_buffers) read_idx_offset = len(self.input_buffers)
...@@ -399,9 +400,7 @@ class PrimFuncNode(Node): ...@@ -399,9 +400,7 @@ class PrimFuncNode(Node):
results.append(trimmed_shape) results.append(trimmed_shape)
return results return results
def propagate_reduction_inputs(self, def propagate_reduction_inputs(self, shape, rstep: dict | None = None) -> dict[str, list[int]]:
shape,
rstep: Optional[Dict] = None) -> Dict[str, List[int]]:
if rstep is None: if rstep is None:
rstep = {} rstep = {}
if self.reduction_block is None: if self.reduction_block is None:
...@@ -418,8 +417,8 @@ class PrimFuncNode(Node): ...@@ -418,8 +417,8 @@ class PrimFuncNode(Node):
for b in self.block_analyzer.get_input_buffers(self.reduction_block) for b in self.block_analyzer.get_input_buffers(self.reduction_block)
} }
@functools.lru_cache() @functools.lru_cache
def infer_tensorcore_axis(self) -> Tuple[int]: def infer_tensorcore_axis(self) -> tuple[int]:
# axis is fixed for one expression, so it is inferred only once and cached # axis is fixed for one expression, so it is inferred only once and cached
assert self.get_tag("tensorcore_config") assert self.get_tag("tensorcore_config")
...@@ -461,7 +460,7 @@ class PrimFuncNode(Node): ...@@ -461,7 +460,7 @@ class PrimFuncNode(Node):
tc_axis = (A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n) tc_axis = (A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n)
return tc_axis return tc_axis
def footprint(self, shape, rstep, stride_map: Optional[Dict] = None) -> int: def footprint(self, shape, rstep, stride_map: dict | None = None) -> int:
if stride_map is None: if stride_map is None:
stride_map = {} stride_map = {}
result = 0 result = 0
...@@ -510,7 +509,7 @@ class PrimFuncNode(Node): ...@@ -510,7 +509,7 @@ class PrimFuncNode(Node):
result += buffer_len result += buffer_len
return result, cached_tensor return result, cached_tensor
def get_input_buffers(self) -> List[tir.Buffer]: def get_input_buffers(self) -> list[tir.Buffer]:
return self.block_analyzer.input_buffers return self.block_analyzer.input_buffers
...@@ -537,7 +536,7 @@ class OutputNode(Node): ...@@ -537,7 +536,7 @@ class OutputNode(Node):
return "output" return "output"
def topo_order(list_of_nodes) -> List[Node]: def topo_order(list_of_nodes) -> list[Node]:
input_ready_count = {node: len(node.inputs) for node in list_of_nodes} input_ready_count = {node: len(node.inputs) for node in list_of_nodes}
ready = list(filter(lambda node: input_ready_count[node] == 0, list_of_nodes)) ready = list(filter(lambda node: input_ready_count[node] == 0, list_of_nodes))
output_list = [] output_list = []
...@@ -557,7 +556,7 @@ def topo_order(list_of_nodes) -> List[Node]: ...@@ -557,7 +556,7 @@ def topo_order(list_of_nodes) -> List[Node]:
return output_list return output_list
def find_topo_sort_priority(output_node_list) -> List[Node]: def find_topo_sort_priority(output_node_list) -> list[Node]:
import sys import sys
sys.setrecursionlimit(10000) sys.setrecursionlimit(10000)
...@@ -591,7 +590,7 @@ def find_topo_sort_priority(output_node_list) -> List[Node]: ...@@ -591,7 +590,7 @@ def find_topo_sort_priority(output_node_list) -> List[Node]:
return topo_order return topo_order
def find_topo_sort(output_node_list) -> List[Node]: def find_topo_sort(output_node_list) -> list[Node]:
def topo_sort_dfs(node, visited, topo_order): def topo_sort_dfs(node, visited, topo_order):
if node in visited: if node in visited:
......
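Two further patterns in this file deserve a note. Adding from __future__ import annotations (the FA rules enforce the import where it is needed) turns every annotation into a lazily evaluated string, which is why the Edge fields can now name Node without quotes even though Node is defined later in the module, and the bare functools.lru_cache decorator replaces the empty-call form, which Python has allowed since 3.8. A small self-contained sketch under those assumptions; the names below are illustrative, not taken from the diff.

from __future__ import annotations  # PEP 563: annotations become lazily evaluated strings

import functools
from dataclasses import dataclass

@dataclass
class Edge:
    # Unquoted forward references: Node is defined below, which is fine
    # because the annotations are never evaluated eagerly.
    src_node: Node
    dst_node: Node

class Node:

    @functools.lru_cache  # bare decorator form, allowed since Python 3.8
    def space_dim(self) -> int:
        # Cached per (self,) key, so repeated calls on the same node hit the cache.
        return 42

a, b = Node(), Node()
edge = Edge(a, b)
print(edge.src_node is a, a.space_dim())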
from typing import List from __future__ import annotations
import numpy as np import numpy as np
def get_all_factors(n: int) -> List[int]: def get_all_factors(n: int) -> list[int]:
# Calculate the square root of n and round it up to the nearest integer # Calculate the square root of n and round it up to the nearest integer
n0 = int(np.ceil(np.sqrt(n))) n0 = int(np.ceil(np.sqrt(n)))
...@@ -16,7 +16,7 @@ def get_all_factors(n: int) -> List[int]: ...@@ -16,7 +16,7 @@ def get_all_factors(n: int) -> List[int]:
return [int(x) for x in np.concatenate([val, mid, n // val[::-1]])] return [int(x) for x in np.concatenate([val, mid, n // val[::-1]])]
def factorize(n: int) -> List[int]: def factorize(n: int) -> list[int]:
i = 2 # Start with the smallest prime number i = 2 # Start with the smallest prime number
result = [] result = []
...@@ -30,7 +30,7 @@ def factorize(n: int) -> List[int]: ...@@ -30,7 +30,7 @@ def factorize(n: int) -> List[int]:
return result return result
def coalesced_factor(subtensor: List[int], tensor: List[int]) -> int: def coalesced_factor(subtensor: list[int], tensor: list[int]) -> int:
# If the last dimension of the subtensor and tensor differ, or subtensor has only one dimension # If the last dimension of the subtensor and tensor differ, or subtensor has only one dimension
if subtensor[-1] != tensor[-1] or len(subtensor) == 1: if subtensor[-1] != tensor[-1] or len(subtensor) == 1:
return subtensor[-1] return subtensor[-1]
...@@ -39,7 +39,7 @@ def coalesced_factor(subtensor: List[int], tensor: List[int]) -> int: ...@@ -39,7 +39,7 @@ def coalesced_factor(subtensor: List[int], tensor: List[int]) -> int:
return subtensor[-1] * coalesced_factor(subtensor[:-1], tensor[:-1]) return subtensor[-1] * coalesced_factor(subtensor[:-1], tensor[:-1])
def coalesced_tensor_shape(subtensor: List[int], tensor: List[int], transaction_size: int) -> int: def coalesced_tensor_shape(subtensor: list[int], tensor: list[int], transaction_size: int) -> int:
# Calculate the total number of elements in the subtensor # Calculate the total number of elements in the subtensor
bytes = int(np.prod(subtensor)) bytes = int(np.prod(subtensor))
......
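As a quick sanity check on the recursion shown in the hunks above (assuming the two visible branches make up the whole of coalesced_factor), here is a runnable sketch with a couple of worked values: the trailing dimension always contributes, and earlier dimensions keep multiplying in only while the sub-tensor and the full tensor agree.

from __future__ import annotations

def coalesced_factor(subtensor: list[int], tensor: list[int]) -> int:
    # Stop once the trailing dimensions diverge or only one dimension is left.
    if subtensor[-1] != tensor[-1] or len(subtensor) == 1:
        return subtensor[-1]
    # Otherwise fold the trailing dimension in and recurse on the remaining ones.
    return subtensor[-1] * coalesced_factor(subtensor[:-1], tensor[:-1])

print(coalesced_factor([4, 8], [4, 8]))  # 32 = 8 * 4: both dimensions match the full tensor
print(coalesced_factor([2, 8], [4, 8]))  # 16 = 8 * 2: the leading dimension differs, so recursion stops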