Unverified Commit f14fb111 authored by Yichen Yan, committed by GitHub

[Lint] Enable pyupgrade linter in ruff (#963)

* update rules

* ruff check

* other fixes

* fmt

* do not touch examples

* fmt
parent 4f3523dc
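The hunks that follow are largely mechanical rewrites driven by ruff's pyupgrade (UP) and flake8-future-annotations (FA) rules: typing.List/Dict/Optional/Union become builtin generics and | unions, redundant "r" modes are dropped from open(), old-style string formatting becomes f-strings (as in the conf.py hunk below), and files gain "from __future__ import annotations" so the new syntax stays legal on the py38 target. A rough illustrative sketch of the before/after shape, with made-up names not taken from this diff:

# Illustrative only; the function and file names here are invented.
from __future__ import annotations  # added by the FA rule; keeps the hints below valid on 3.8

def read_lines(path: str, limit: int | None = None) -> list[str]:
    # UP rewrites typing.Optional[int] -> int | None, typing.List[str] -> list[str],
    # and open(path, "r") -> open(path), since "r" is already the default mode.
    with open(path) as f:
        lines = f.read().splitlines()
    return lines if limit is None else lines[:limit]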
# -*- coding: utf-8 -*-
# General information about the project.
project = "Tile Language <br>"
author = "Tile Lang Contributors"
copyright = "2025-2025, %s" % author
copyright = f"2025-2025, {author}"
# Version information.
with open("../VERSION", "r") as f:
with open("../VERSION") as f:
version = f.read().strip()
release = version
......
......@@ -87,6 +87,17 @@ target-version = "py38"
line-length = 100
output-format = "full"
exclude = [
"3rdparty",
"examples/deepseek_v32/inference",
]
[tool.ruff.lint.per-file-ignores]
# Do not upgrade type hints in testing and examples.
# See https://github.com/tile-ai/tilelang/issues/1079 for more information.
"testing/**.py" = ["UP", "FA"]
"examples/**.py" = ["UP", "FA"]
[tool.ruff.lint]
select = [
# pycodestyle
......@@ -94,7 +105,7 @@ select = [
# Pyflakes
"F",
# pyupgrade
# "UP",
"UP", "FA",
# flake8-bugbear
"B",
# flake8-simplify
......@@ -115,6 +126,8 @@ ignore = [
"SIM108",
# key in dict.keys()
"SIM118",
# open file without a context manager
"SIM115",
# memory leaks
"B019",
# zip without explicit strict
......@@ -122,9 +135,6 @@ ignore = [
# No such file or directory
"E902",
]
[tool.ruff.lint.per-file-ignores]
"3rdparty/**/*" = ["ALL"]
"examples/deepseek_v32/inference/**/*" = ["ALL"]
[tool.pytest.ini_options]
verbosity_assertions = 3
......
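One of the newly ignored rules above, SIM115, flags files opened without a context manager. A minimal sketch of the pattern the ignore now tolerates (hypothetical file name, not taken from this repository):

# Hypothetical example of code SIM115 would normally flag; with the ignore in place,
# long-lived handles like this no longer require a with-block.
log_file = open("autotune.log", "a")
log_file.write("tuning started\n")
log_file.flush()
log_file.close()  # closed explicitly later instead of via a context manager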
from __future__ import annotations
import threading
from typing import List, Any, Optional
from typing import Any
# Use thread-local storage for the stack
# This is to avoid cross-thread interference
......@@ -87,7 +88,7 @@ class AutotuneInputsCapture:
__slots__ = ("tensors")
def __init__(self, tensors: List[Any]):
def __init__(self, tensors: list[Any]):
self.tensors = tensors
def __enter__(self) -> None:
......@@ -118,7 +119,7 @@ def set_autotune_inputs(*args) -> AutotuneInputsCapture:
return AutotuneInputsCapture(tensors)
def get_autotune_inputs() -> Optional[List[Any]]:
def get_autotune_inputs() -> list[Any] | None:
"""
Get the current autotune inputs from the stack.
"""
......
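The list[Any] | None spelling above only works on the project's Python 3.8 floor because PEP 563 (from __future__ import annotations) stops annotations from being evaluated at definition time. A minimal sketch of the failure it avoids, with made-up names:

# Without the future import, Python 3.8 evaluates the annotation eagerly and raises
# TypeError ("'type' object is not subscriptable") as soon as the module is imported.
from __future__ import annotations  # annotations are stored as lazy strings instead

def get_inputs() -> list[int] | None:  # accepted on 3.8 only because of the import above
    return None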
"""The auto-tune parameters.
"""
from __future__ import annotations
import tilelang
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tvm.target import Target
from typing import Callable, List, Literal, Any, Optional, Union, Dict
from typing import Callable, Literal, Any
from dataclasses import dataclass
from pathlib import Path
......@@ -40,12 +41,12 @@ class CompileArgs:
Refer to `tilelang.PassConfigKey` for supported options.
"""
out_idx: Optional[Union[List[int], int]] = None
out_idx: list[int] | int | None = None
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython"
target: Literal['auto', 'cuda', 'hip'] = 'auto'
target_host: Union[str, Target] = None
target_host: str | Target = None
verbose: bool = False
pass_configs: Optional[Dict[str, Any]] = None
pass_configs: dict[str, Any] | None = None
def compile_program(self, program: PrimFunc):
return tilelang.compile(
......@@ -135,12 +136,12 @@ class AutotuneResult:
func: Optimized function.
kernel: Compiled kernel function.
"""
latency: Optional[float] = None
config: Optional[dict] = None
ref_latency: Optional[float] = None
libcode: Optional[str] = None
func: Optional[Callable] = None
kernel: Optional[Callable] = None
latency: float | None = None
config: dict | None = None
ref_latency: float | None = None
libcode: str | None = None
func: Callable | None = None
kernel: Callable | None = None
def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: bool = False):
"""
......@@ -204,9 +205,9 @@ class AutotuneResult:
def _load_kernel_from_disk(
self,
cache_path: Path,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
out_idx: Optional[Union[List[int], int]] = None,
target: str | Target = "auto",
target_host: str | Target = None,
out_idx: list[int] | int | None = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
pass_configs: dict = None,
func: Callable = None,
......@@ -232,14 +233,14 @@ class AutotuneResult:
if not os.path.exists(cache_path):
return None
kernel_global_source: Optional[str] = None
kernel_params: Optional[List[KernelParam]] = None
kernel_global_source: str | None = None
kernel_params: list[KernelParam] | None = None
try:
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
if verbose:
logger.debug(f"Loading wrapped kernel source code from file: {wrapped_kernel_path}")
with open(wrapped_kernel_path, "r") as f:
with open(wrapped_kernel_path) as f:
kernel_global_source = f.read()
except Exception as e:
logger.error(f"Error loading wrapped kernel source code from disk: {e}")
......@@ -300,7 +301,7 @@ class AutotuneResult:
self._save_kernel_to_disk(path, self.kernel)
@classmethod
def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> 'AutotuneResult':
def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> AutotuneResult:
if not os.path.exists(path):
return None
......@@ -308,7 +309,7 @@ class AutotuneResult:
# load best config
if verbose:
logger.debug(f"Loading best config from file: {path / BEST_CONFIG_PATH}")
with open(path / BEST_CONFIG_PATH, "r") as f:
with open(path / BEST_CONFIG_PATH) as f:
config = json.load(f)
# load function
......@@ -320,7 +321,7 @@ class AutotuneResult:
# load latency
if verbose:
logger.debug(f"Loading latency from file: {path / LATENCY_PATH}")
with open(path / LATENCY_PATH, "r") as f:
with open(path / LATENCY_PATH) as f:
latency = json.load(f)
latency, ref_latency = latency["latency"], latency["ref_latency"]
......
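The same lazy-annotation behaviour is what lets load_from_disk above drop the quotes around its own class name in the return annotation. A small self-referential sketch with a hypothetical class:

from __future__ import annotations

class Result:
    @classmethod
    def load(cls, data: dict) -> Result:  # no "Result" string literal needed: the
        obj = cls()                       # annotation is never evaluated while the
        obj.__dict__.update(data)         # class body is still executing
        return obj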
......@@ -3,6 +3,7 @@
This module provides functionality for auto-tuning tilelang programs, including JIT compilation
and performance optimization through configuration search.
"""
from __future__ import annotations
import tilelang
from tilelang import tvm as tvm
......@@ -10,7 +11,7 @@ from tvm.tir import PrimFunc, Var
from tvm.target import Target
import inspect
from functools import partial
from typing import (Callable, List, Literal, Any, Optional, Union, Dict, overload, Tuple)
from typing import (Callable, Literal, Any, overload)
from tqdm import tqdm
import logging
import functools
......@@ -103,8 +104,8 @@ class AutoTuner:
compile_args = CompileArgs()
profile_args = ProfileArgs()
_kernel_parameters: Optional[Tuple[str, ...]] = None
_function_parameters: Optional[Dict[str, Any]] = None
_kernel_parameters: tuple[str, ...] | None = None
_function_parameters: dict[str, Any] | None = None
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"
......@@ -131,12 +132,12 @@ class AutoTuner:
return cls(kernel, configs)
def set_compile_args(self,
out_idx: Union[List[int], int, None] = None,
out_idx: list[int] | int | None = None,
target: Literal['auto', 'cuda', 'hip'] = 'auto',
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
target_host: Union[str, Target] = None,
target_host: str | Target = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
pass_configs: dict[str, Any] | None = None):
"""Set compilation arguments for the auto-tuner.
Args:
......@@ -223,12 +224,12 @@ class AutoTuner:
return self
def set_kernel_parameters(self, k_parameters: Tuple[str, ...], f_parameters: Dict[str, Any]):
def set_kernel_parameters(self, k_parameters: tuple[str, ...], f_parameters: dict[str, Any]):
# for cache key generation
self._kernel_parameters = k_parameters
self._function_parameters = f_parameters
def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
def generate_cache_key(self, parameters: dict[str, Any]) -> AutotuneResult | None:
"""Generate a cache key for the auto-tuning process.
"""
......@@ -307,8 +308,8 @@ class AutoTuner:
return result
best_latency: float = 1e8
best_config: Optional[Dict[str, Any]] = None
best_kernel: Optional[tilelang.JITKernel] = None
best_config: dict[str, Any] | None = None
best_kernel: tilelang.JITKernel | None = None
def _compile(**config_arg) -> tilelang.JITKernel:
compile_args = self.compile_args
......@@ -591,7 +592,7 @@ class _AutoTunerImplementation:
warmup: int = 25
rep: int = 100
timeout: int = 100
configs: Union[Dict, Callable] = None
configs: dict | Callable = None
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
ref_prog: Callable = None
supply_prog: Callable = None
......@@ -603,7 +604,7 @@ class _AutoTunerImplementation:
cache_input_tensors: bool = False
def __init__(self,
configs: Union[Dict, Callable],
configs: dict | Callable,
warmup: int = 25,
rep: int = 100,
timeout: int = 100,
......@@ -653,12 +654,12 @@ class _AutoTunerImplementation:
self.cache_input_tensors = cache_input_tensors # Reuse inputs
# Cache for storing tuned kernel implementations
self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
self._tuner_cache: dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
# This tells the type checker what the *wrapper* function will return.
# this is for linting, please do not remove it.
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, AutotuneResult]]:
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, AutotuneResult]]:
...
@overload
......@@ -720,9 +721,9 @@ class _AutoTunerImplementation:
def autotune( # This is the new public interface
func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
func: Callable[_P, _RProg] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only
configs: Union[Dict, Callable],
configs: dict | Callable,
# profile arguments
warmup: int = 25,
rep: int = 100,
......
"""The cache utils with class and database persistence - Init file"""
from __future__ import annotations
from typing import List, Union, Literal, Optional
from typing import Literal
from tvm.target import Target
from tvm.tir import PrimFunc
from tilelang.jit import JITKernel
......@@ -13,14 +14,14 @@ _kernel_cache_instance = KernelCache()
def cached(
func: PrimFunc = None,
out_idx: List[int] = None,
out_idx: list[int] = None,
*args,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = "cython",
verbose: Optional[bool] = False,
pass_configs: Optional[dict] = None,
compile_flags: Optional[Union[List[str], str]] = None,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] | None = "cython",
verbose: bool | None = False,
pass_configs: dict | None = None,
compile_flags: list[str] | str | None = None,
) -> JITKernel:
"""
Caches and reuses compiled kernels (using KernelCache class).
......
"""The cache utils with class and database persistence - KernelCache Class"""
from __future__ import annotations
import json
import logging
......@@ -7,7 +8,7 @@ import shutil
import threading
import uuid
from hashlib import sha256
from typing import Callable, List, Literal, Optional, Union
from typing import Callable, Literal
import cloudpickle
from tvm.target import Target
......@@ -67,13 +68,13 @@ class KernelCache:
def _generate_key(
self,
func: Callable,
out_idx: List[int],
out_idx: list[int],
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
args=None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
target: str | Target = "auto",
target_host: str | Target = None,
pass_configs: dict = None,
compile_flags: Optional[Union[List[str], str]] = None,
compile_flags: list[str] | str | None = None,
) -> str:
"""
Generates a unique hash key for caching compiled kernels.
......@@ -112,14 +113,14 @@ class KernelCache:
def cached(
self,
func: PrimFunc = None,
out_idx: List[int] = None,
out_idx: list[int] = None,
*args,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False,
pass_configs: dict = None,
compile_flags: Optional[Union[List[str], str]] = None,
compile_flags: list[str] | str | None = None,
) -> JITKernel:
"""
Caches and reuses compiled kernels to avoid redundant compilation.
......@@ -322,15 +323,15 @@ class KernelCache:
def _load_kernel_from_disk(
self,
key: str,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
out_idx: List[int] = None,
target: str | Target = "auto",
target_host: str | Target = None,
out_idx: list[int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
pass_configs: dict = None,
compile_flags: Optional[Union[List[str], str]] = None,
compile_flags: list[str] | str | None = None,
func: Callable = None,
verbose: bool = False,
) -> Optional[JITKernel]:
) -> JITKernel | None:
"""
Loads a previously compiled kernel from disk cache.
......@@ -355,15 +356,15 @@ class KernelCache:
if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]):
return None
kernel_global_source: Optional[str] = None
kernel_params: Optional[List[KernelParam]] = None
kernel_global_source: str | None = None
kernel_params: list[KernelParam] | None = None
# Load the kernel source file (optional)
try:
if verbose:
self.logger.debug(
f"Loading wrapped kernel source code from file: {wrapped_kernel_path}")
with open(wrapped_kernel_path, "r") as f:
with open(wrapped_kernel_path) as f:
kernel_global_source = f.read()
except Exception as e:
self.logger.error(f"Error loading wrapped kernel source code from disk: {e}")
......
"""Analysis on TIR blocks, loops and functions."""
from typing import List, Optional, Set, Union
from __future__ import annotations
from typing_extensions import Literal
from tvm import ir, tir, DataType
......@@ -31,7 +31,7 @@ class IterInfo:
self.loop_rv = loop_rv
@property
def dom(self) -> Union[int, tir.PrimExpr]:
def dom(self) -> int | tir.PrimExpr:
"""The iteration domain of the loop."""
return int(self._dom) if isinstance(self._dom, tir.IntImm) else self._dom
......@@ -46,14 +46,14 @@ class BlockInfo:
"""Information about a TIR block."""
name: str
iters: List[IterInfo]
iters: list[IterInfo]
block_rv: tir.schedule.BlockRV
_reduction_block: bool
def __init__(
self,
name: str,
iters: List[IterInfo],
iters: list[IterInfo],
block_rv: tir.schedule.BlockRV,
reduction_block: bool = False,
):
......@@ -63,7 +63,7 @@ class BlockInfo:
self.iters = iters
self._reduction_block = reduction_block
def dom(self) -> List[Union[int, tir.PrimExpr]]:
def dom(self) -> list[int | tir.PrimExpr]:
"""The iteration domain of the block."""
return [i.dom for i in self.iters]
......@@ -118,7 +118,7 @@ class BlockInfo:
_normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc")
def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]:
def normalize_prim_func(sch: tir.Schedule) -> list[BlockInfo] | None:
"""Normalize the primfunc to normal form"""
try:
result = _normalize_prim_func(sch)
......@@ -133,7 +133,7 @@ def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]:
tir.IterVar.CommReduce: "R",
}.get(i.iter_type, "O")
blocks: List[BlockInfo] = []
blocks: list[BlockInfo] = []
for block, loops, iters, is_reduction in zip(*result):
blocks.append(
BlockInfo(
......@@ -203,7 +203,7 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
def collect_block_iter_vars_used_in_access_region(block: tir.Block,
region: List[ir.Range]) -> Set[tir.Var]:
region: list[ir.Range]) -> set[tir.Var]:
"""Collect the block iter variables used in the access region of a buffer region."""
tir_vars = set()
for expr in region:
......@@ -214,7 +214,7 @@ def collect_block_iter_vars_used_in_access_region(block: tir.Block,
return tir_vars
def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> Set[tir.Var]:
def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> set[tir.Var]:
"""Collect the variables used in the PrimExpr."""
tir_vars = set()
......@@ -259,7 +259,7 @@ def is_broadcast_epilogue(
def get_reduction_blocks(sch: tir.Schedule,
blocks: List[tir.schedule.BlockRV]) -> List[tir.schedule.BlockRV]:
blocks: list[tir.schedule.BlockRV]) -> list[tir.schedule.BlockRV]:
# Get the main computation block
def is_reduction(block: BlockRV) -> bool:
block_stmt = sch.get(block)
......@@ -286,7 +286,7 @@ def get_reduction_blocks(sch: tir.Schedule,
def get_coalesced_veclen(block_stmt: tir.Block, target_bits: int = 128) -> int:
# GPU memory prefers 128-bit coalesced access (e.g. four banks)
# 128 bits
buffers: List[tir.Buffer] = []
buffers: list[tir.Buffer] = []
for read in block_stmt.reads:
buffers.append(read.buffer)
for write in block_stmt.writes:
......
from __future__ import annotations
from .arch_base import TileDevice
from .cuda import *
from .cpu import *
from .cdna import *
from .metal import *
from typing import Union
from tvm.target import Target
import torch
def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
def get_arch(target: str | Target = "cuda") -> TileDevice:
if isinstance(target, str):
target = Target(target)
......
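A hedged usage sketch of get_arch as declared above; the import path is an assumption based on this diff's layout and may not match the actual package:

# Assumed import path, not confirmed by this diff. Both call forms should resolve to
# the same TileDevice subclass, since a plain string is wrapped in a tvm Target first.
from tvm.target import Target
from tilelang.carver.arch import get_arch  # hypothetical path

dev_from_str = get_arch("cuda")
dev_from_target = get_arch(Target("cuda"))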
from typing import List
from __future__ import annotations
class TileDevice:
......@@ -14,12 +14,12 @@ class TileDevice:
0 # The size of a warp, a group of threads that execute instructions in lockstep
)
self.sm_partition: int = 0 # The number of streaming multiprocessor partitions
self.transaction_size: List[int] = [
self.transaction_size: list[int] = [
0,
0,
] # The size of memory transactions, typically in bytes
self.max_smem_usage: int = 0 # The maximum shared memory usage allowed
self.bandwidth: List[int] = [
self.bandwidth: list[int] = [
0,
0,
] # Bandwidth specifications, possibly including peak and sustained rates
......@@ -29,9 +29,9 @@ class TileDevice:
)
self.l2_cache_size_bytes: int = 0
# transaction sizes in bytes
self.transaction_size: List[int] = [0, 0] # in bytes
self.transaction_size: list[int] = [0, 0] # in bytes
# bandwidth in MB/s, used to recommend a basic tile size
self.bandwidth: List[int] = [0, 0]
self.bandwidth: list[int] = [0, 0]
def get_avaliable_tensorintrin_shapes(self):
raise NotImplementedError()
from __future__ import annotations
import tvm
from tvm.target import Target
from .arch_base import TileDevice
from typing import List, Union
def is_cdna_arch(arch: TileDevice) -> bool:
......@@ -10,7 +10,7 @@ def is_cdna_arch(arch: TileDevice) -> bool:
class CDNA(TileDevice):
def __init__(self, target: Union[Target, str]):
def __init__(self, target: Target | str):
if isinstance(target, str):
target = tvm.target.Target(target)
self.target = target
......@@ -27,9 +27,9 @@ class CDNA(TileDevice):
self.max_smem_usage: int = 2 * self.smem_cap
self.sm_partition: int = 4
self.l2_cache_size_bytes: int = target.l2_cache_size_bytes
self.transaction_size: List[int] = [32, 128] # in bytes
self.transaction_size: list[int] = [32, 128] # in bytes
self.bandwidth: List[int] = [1300, 14000]
self.bandwidth: list[int] = [1300, 14000]
__all__ = [
......
from __future__ import annotations
import tvm
from tvm.target import Target
from .arch_base import TileDevice
from typing import List, Union
from .driver import cuda_driver
......@@ -91,21 +91,21 @@ def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: Til
raise ValueError(f"Unsupported architecture: {arch}")
class TensorInstruction(object):
class TensorInstruction:
def __init__(
self,
name: str,
shape: List[int],
shape: list[int],
):
self.name: str = name
# only hold the shape of M and N
self.shape: List[int] = shape
self.shape: list[int] = shape
class CUDA(TileDevice):
def __init__(self, target: Union[Target, str]):
def __init__(self, target: Target | str):
if isinstance(target, str):
target = tvm.target.Target(target)
self.target = target
......@@ -126,15 +126,15 @@ class CUDA(TileDevice):
self.sm_partition: int = 4
self.l2_cache_size_bytes: int = target.l2_cache_size_bytes
# transaction sizes in bytes
self.transaction_size: List[int] = [32, 128] # in bytes
self.transaction_size: list[int] = [32, 128] # in bytes
# bandwidth in MB/s, used to recommend a basic tile size
# TODO(lei): find some way to get the real bandwidth
# However, the bandwidth ratio between different devices tends to be
# similar, so these values can work for other devices as well.
self.bandwidth: List[int] = [750, 12080]
self.bandwidth: list[int] = [750, 12080]
# get the available tensor instructions during runtime to avoid
# the dependency of the tensor intrinsics registration
self.available_tensor_instructions: List[TensorInstruction] = None
self.available_tensor_instructions: list[TensorInstruction] = None
def get_avaliable_tensorintrin_shapes(self):
self.available_tensor_instructions = (
......
from __future__ import annotations
import ctypes
import sys
from typing import Optional
class cudaDeviceProp(ctypes.Structure):
......@@ -77,7 +77,7 @@ class cudaDeviceProp(ctypes.Structure):
]
def get_cuda_device_properties(device_id: int = 0) -> Optional[cudaDeviceProp]:
def get_cuda_device_properties(device_id: int = 0) -> cudaDeviceProp | None:
if sys.platform == "win32":
libcudart = ctypes.windll.LoadLibrary("cudart64_110.dll")
......@@ -95,7 +95,7 @@ def get_cuda_device_properties(device_id: int = 0) -> Optional[cudaDeviceProp]:
raise RuntimeError(f"cudaGetDeviceProperties failed with error {ret}")
def get_device_name(device_id: int = 0) -> Optional[str]:
def get_device_name(device_id: int = 0) -> str | None:
prop = get_cuda_device_properties(device_id)
if prop:
return prop.name.decode()
......@@ -103,7 +103,7 @@ def get_device_name(device_id: int = 0) -> Optional[str]:
raise RuntimeError("Failed to get device properties.")
def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> Optional[int]:
def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> int | None:
assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb"
prop = get_cuda_device_properties(device_id)
if prop:
......@@ -143,7 +143,7 @@ def get_device_attribute(attr: int, device_id: int = 0) -> int:
return None
def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") -> Optional[int]:
def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") -> int | None:
"""
Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes.
"""
......
from __future__ import annotations
from tvm.target import Target
from .arch_base import TileDevice
......
......@@ -19,7 +19,8 @@
# Modifications Copyright (c) Microsoft.
# The code below is mostly copied from apache/tvm common_schedules.py in dlight.
"""Common schedule strategies for TIR."""
from typing import Callable, List
from __future__ import annotations
from typing import Callable
from tvm import tir
from .utils import retrieve_func_from_module
......@@ -28,7 +29,7 @@ from .analysis import BlockInfo
def get_block(
sch: tir.Schedule,
blocks: List[BlockInfo],
blocks: list[BlockInfo],
name: str,
):
"""Get the target block from a schedule.
......@@ -56,7 +57,7 @@ def get_block(
def get_output_blocks(
sch: tir.Schedule,
blocks: List[BlockInfo],
blocks: list[BlockInfo],
):
"""Get the output blocks of a schedule.
......@@ -89,8 +90,8 @@ def get_output_blocks(
def try_inline(
sch: tir.Schedule,
blocks: List[BlockInfo],
) -> List[BlockInfo]:
blocks: list[BlockInfo],
) -> list[BlockInfo]:
"""Try to inline as many blocks as possible, and return the remaining blocks.
Parameters
......@@ -127,8 +128,8 @@ def try_inline(
def try_inline_contiguous_spatial(
sch: tir.Schedule,
block_infos: List[BlockInfo],
) -> List[BlockInfo]:
block_infos: list[BlockInfo],
) -> list[BlockInfo]:
"""Try to inline contiguous spatial blocks in a schedule
Parameters
......
# pylint: disable=missing-docstring, invalid-name
"""A GEMM schedule rule for GPU operators."""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Set, Union, Tuple, Dict
from tvm import tir
from tvm.ir import Range
from tvm.tir import IterVar, PrimExpr, Var, BufferRegion, IndexMap
......@@ -57,7 +57,7 @@ def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV):
def auto_inline_producers(
sch: tir.Schedule,
block: tir.schedule.BlockRV,
skip_blocks: Optional[List[tir.schedule.BlockRV]] = None,
skip_blocks: list[tir.schedule.BlockRV] | None = None,
):
skip_blocks = skip_blocks or []
while True:
......@@ -118,7 +118,7 @@ def auto_inline_consumer_chain(
# used to match the similar region with dequantize op.
def find_first_similar_region(regions: List[BufferRegion], buffer: tir.Buffer):
def find_first_similar_region(regions: list[BufferRegion], buffer: tir.Buffer):
for region in regions:
if len(region.buffer.shape) == len(buffer.shape):
return region
......@@ -126,7 +126,7 @@ def find_first_similar_region(regions: List[BufferRegion], buffer: tir.Buffer):
# used to match the similar buffer with dequantize op.
def find_first_similar_buffer(regions: List[BufferRegion], buffer: tir.Buffer):
def find_first_similar_buffer(regions: list[BufferRegion], buffer: tir.Buffer):
for region in regions:
if len(region.buffer.shape) == len(buffer.shape):
return region.buffer
......@@ -134,7 +134,7 @@ def find_first_similar_buffer(regions: List[BufferRegion], buffer: tir.Buffer):
# find the block that required to be reindex and scope.
def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Optional[BlockRV]:
def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> BlockRV | None:
# the block closest to the function arguments
block = main_block
buffer = buffer
......@@ -209,11 +209,11 @@ class IterTrait:
def make_iter_fusion_index_map(
traits: List[IterTrait],
kind_order: List[IterKind],
traits: list[IterTrait],
kind_order: list[IterKind],
) -> tir.IndexMap:
fused_iters: Dict[IterKind, PrimExpr] = {}
input_iters: List[tir.Var] = []
fused_iters: dict[IterKind, PrimExpr] = {}
input_iters: list[tir.Var] = []
for i, trait in enumerate(traits):
v_i = tir.Var(f"i{i}", trait.extent.dtype)
input_iters.append(v_i)
......@@ -226,14 +226,14 @@ def make_iter_fusion_index_map(
else:
fused_iters[trait.kind] = v_i
final_indices: List[tir.PrimExpr] = [
final_indices: list[tir.PrimExpr] = [
fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order
]
return tir.IndexMap(input_iters, final_indices, None)
def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
def detect_iter_traits(block: tir.Block) -> tuple[list[IterTrait]] | None:
"""Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K]
Parameters
......@@ -252,8 +252,8 @@ def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
if len(block.reads) != 2 or len(block.writes) != 1:
return None
def get_access_axes(region: List[Range]) -> Set[Var]:
axes: Set[Var] = set()
def get_access_axes(region: list[Range]) -> set[Var]:
axes: set[Var] = set()
for r in region:
if not _is_one(r.extent):
raise ValueError("Expect elemwise block access")
......@@ -267,7 +267,7 @@ def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
except ValueError:
return None
traits: Dict[Var, IterTrait] = {}
traits: dict[Var, IterTrait] = {}
for iter_var in block.iter_vars:
var = iter_var.var
kind: IterKind
......@@ -308,7 +308,7 @@ def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
def get_index_map(block: tir.Block,
layout: Optional[List[str]] = None) -> Optional[Tuple[tir.IndexMap, ...]]:
layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] | None:
"""Get index maps for the block
Parameters
......@@ -334,8 +334,8 @@ def get_index_map(block: tir.Block,
return None
A_traits, B_traits, C_traits, block_traits = traits
def get_ordered_axes(region: List[Range]) -> Set[Var]:
axes: List[Var] = []
def get_ordered_axes(region: list[Range]) -> set[Var]:
axes: list[Var] = []
for r in region:
if not _is_one(r.extent):
raise ValueError("Expect elemwise block access")
......@@ -352,11 +352,11 @@ def get_index_map(block: tir.Block,
vars = collect_vars_from_expr(var)
return any(is_common_reduce(v) for v in vars)
def check_last_trait(region: List[Range]):
def check_last_trait(region: list[Range]):
axes = get_ordered_axes(region)
return has_common_reduce(axes[-1])
def infer_layout(layout: str, region: List[Range], kind: str = "A"):
def infer_layout(layout: str, region: list[Range], kind: str = "A"):
"""
Infer the layout based on the region and the kind of buffer
kind: "A", "B", "C"
......@@ -409,7 +409,7 @@ def get_index_map(block: tir.Block,
)
def get_in_out_dtypes(block: tir.Block) -> Tuple[str]:
def get_in_out_dtypes(block: tir.Block) -> tuple[str]:
"""
Detect in/out data types for the given block based on analysis of its read/write buffers.
"""
......@@ -419,7 +419,7 @@ def get_in_out_dtypes(block: tir.Block) -> Tuple[str]:
return (in_dtype, out_dtype)
def get_dequantize_block(sch, blocks) -> Optional[BlockRV]:
def get_dequantize_block(sch, blocks) -> BlockRV | None:
# check at least two inputs and one output
# at least one input has a uint dtype, and the output dtype is float
def is_dequantize(block: BlockRV) -> bool:
......@@ -445,8 +445,8 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool:
if not isinstance(block_stmt.body.value, tir.BufferLoad):
return False, False
def get_access_vars(region: List[Range]) -> List[Var]:
axes: List[Var] = []
def get_access_vars(region: list[Range]) -> list[Var]:
axes: list[Var] = []
for r in region:
if not _is_one(r.extent):
return None
......@@ -475,7 +475,7 @@ def is_transpose_block(block_stmt: tir.Block) -> bool:
return is_identity_or_transpose_block(block_stmt)[1]
def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]):
def inline_transpose_block(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]):
result_blocks = []
for block in blocks:
if not is_transpose_block(sch.get(block)):
......@@ -493,7 +493,7 @@ def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]
def normalize_to_matmul(sch: tir.Schedule,
main_block: BlockRV,
layout: Optional[List[str]] = None) -> Optional[tir.Schedule]:
layout: list[str] | None = None) -> tir.Schedule | None:
if layout is None:
layout = ["n", "t", "n"]
block_stmt = sch.get(main_block)
......@@ -521,10 +521,10 @@ def normalize_to_matmul(sch: tir.Schedule,
def get_tensorized_func_and_tags(
func: tir.PrimFunc,
target: Target,
layout: Optional[List[str]] = None,
layout: list[str] | None = None,
skip_normalize: bool = False,
allow_gemv: bool = False,
) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]:
) -> tuple[tir.PrimFunc, dict[str, list[int] | int]]:
"""
Transform the function into matmul form if necessary (e.g. transform conv2d with im2col).
"""
......@@ -554,9 +554,8 @@ def get_tensorized_func_and_tags(
sm_version = arch.replace("sm_", "")
return int(sm_version) if sm_version.isdigit() else -1
def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV,
target: Target) -> Union[bool, Dict]:
tags: Dict[str, Union[List[int], int]] = {}
def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool | dict:
tags: dict[str, list[int] | int] = {}
block_stmt = sch.get(block)
# Nvidia Only Support Tensor Core for
......@@ -584,8 +583,8 @@ def get_tensorized_func_and_tags(
tags["use_async_copy"] = True
# analysis intrin information
def get_ordered_axes(region: List[Range]) -> Set[Var]:
axes: List[Var] = []
def get_ordered_axes(region: list[Range]) -> set[Var]:
axes: list[Var] = []
for r in region:
if not _is_one(r.extent):
raise ValueError("Expect elemwise block access")
......@@ -602,7 +601,7 @@ def get_tensorized_func_and_tags(
vars = collect_vars_from_expr(var)
return any(is_common_reduce(v) for v in vars)
def check_last_trait(region: List[Range]):
def check_last_trait(region: list[Range]):
axes = get_ordered_axes(region)
return has_common_reduce(axes[-1])
......
......@@ -17,7 +17,7 @@ class Block:
self.end = max(self.end, other.end)
def __repr__(self) -> str:
return "<Block offset={} size={}>".format(self.start, self.size())
return f"<Block offset={self.start} size={self.size()}>"
class BestFit:
......
"""Hint definition for schedule"""
from __future__ import annotations
from tvm import DataType
from typing import Dict, List, Tuple
from . import PrimFuncNode
import numpy as np
from .rasterization import *
......@@ -13,17 +13,17 @@ class TensorCoreExtraConfig:
def __init__(
self,
AS_shape: Tuple[int],
BS_shape: Tuple[int],
AF_shape: Tuple[int],
BF_shape: Tuple[int],
tc_axis: Tuple[int],
AS_shape: tuple[int],
BS_shape: tuple[int],
AF_shape: tuple[int],
BF_shape: tuple[int],
tc_axis: tuple[int],
) -> None:
self.AS_shape: Tuple[int] = AS_shape
self.BS_shape: Tuple[int] = BS_shape
self.AF_shape: Tuple[int] = AF_shape
self.BF_shape: Tuple[int] = BF_shape
self.tc_axis: Tuple[int] = tc_axis
self.AS_shape: tuple[int] = AS_shape
self.BS_shape: tuple[int] = BS_shape
self.AF_shape: tuple[int] = AF_shape
self.BF_shape: tuple[int] = BF_shape
self.tc_axis: tuple[int] = tc_axis
class Stride:
......@@ -45,7 +45,7 @@ class Stride:
def stride(self) -> int:
return self._stride
def compute_strides_from_shape(self, shape: List[int]) -> List[int]:
def compute_strides_from_shape(self, shape: list[int]) -> list[int]:
ndim = len(shape)
strides = [1 for _ in shape]
for i in range(ndim - 2, -1, -1):
......@@ -55,7 +55,7 @@ class Stride:
strides[i] = int(strides[i + 1] * shape[i + 1])
return strides
def compute_elements_from_shape(self, shape: List[int]) -> int:
def compute_elements_from_shape(self, shape: list[int]) -> int:
original_shape = np.prod(shape)
if not self.is_valid():
strided_elem = original_shape
......@@ -94,10 +94,10 @@ class TileDict:
self.grid_size = -1
self.valid = True
def get_tile(self, func) -> List[int]:
def get_tile(self, func) -> list[int]:
return self.tile_map[func]
def get_rstep(self, node) -> Dict[str, int]:
def get_rstep(self, node) -> dict[str, int]:
return self.rstep_map[node]
def __hash__(self) -> int:
......@@ -147,7 +147,7 @@ class IntrinInfo:
return self.weight_transform_kind >= 1
class Hint(object):
class Hint:
"""
Central configuration class for managing various parameters of computational tasks.
"""
......@@ -178,15 +178,15 @@ class Hint(object):
# Experimental
self._raxis_order = []
self._step = []
self.vectorize: Dict[str, int] = {}
self.vectorize: dict[str, int] = {}
self.pipeline_stage = 1
self.use_async = False
self.opt_shapes: Dict[str, int] = {}
self.opt_shapes: dict[str, int] = {}
self.intrin_info = IntrinInfo("float16", "float16", True)
self.shared_scope: str = "shared"
self.pass_context: Dict = {}
self.pass_context: dict = {}
def to_dict(self) -> Dict:
def to_dict(self) -> dict:
dic = {}
dic["block"] = self.block
if self.use_tc:
......@@ -218,7 +218,7 @@ class Hint(object):
return dic
@classmethod
def from_dict(cls, dic: Dict) -> "Hint":
def from_dict(cls, dic: dict) -> Hint:
hint = cls()
for k, v in dic.items():
setattr(hint, k, v)
......@@ -231,13 +231,13 @@ class Hint(object):
return self
@property
def raxis_order(self) -> List[int]:
def raxis_order(self) -> list[int]:
if self._raxis_order != []:
return self._raxis_order
return list(range(len(self.rstep)))
@property
def step(self) -> List[int]:
def step(self) -> list[int]:
if self._step != []:
return self._step
return [1 for _ in self.block]
......
"""PrimFunc Wrapper and Block information Analaysis"""
from __future__ import annotations
import tvm
from tvm import tir
from tvm.tir import IterVar, PrimFunc
from typing import Any, Dict, List, Tuple, Optional
from typing import Any
from tvm.tir.schedule.schedule import BlockRV
import numpy as np
import functools
......@@ -29,11 +30,11 @@ def pre_order_traverse(block_analyzer, blocks, func):
_traverse(block)
class BlockAnalyzer(object):
class BlockAnalyzer:
def __init__(self, sch) -> None:
self.sch: tir.Schedule = sch
self.block_infos: List[BlockInfo] = normalize_prim_func(self.sch)
self.block_infos: list[BlockInfo] = normalize_prim_func(self.sch)
def get_block_name(self, block: BlockRV) -> str:
return self.sch.get(block).name_hint
......@@ -44,7 +45,7 @@ class BlockAnalyzer(object):
return block_info
return None
def get_spatial_axis(self, block: BlockRV) -> List[IterVar]:
def get_spatial_axis(self, block: BlockRV) -> list[IterVar]:
block_info = self.get_block_info(block)
axis = []
for iter in block_info.iters:
......@@ -52,7 +53,7 @@ class BlockAnalyzer(object):
axis.append(iter)
return axis
def get_reduce_axis(self, block: BlockRV) -> List[IterVar]:
def get_reduce_axis(self, block: BlockRV) -> list[IterVar]:
block_info = self.get_block_info(block)
raxis = []
for iter in block_info.iters:
......@@ -60,39 +61,39 @@ class BlockAnalyzer(object):
raxis.append(iter)
return raxis
def get_input_buffers(self, block: BlockRV) -> List[tir.Buffer]:
def get_input_buffers(self, block: BlockRV) -> list[tir.Buffer]:
buffers = []
for read in self.sch.get(block).reads:
buffers.append(read.buffer)
return buffers
def get_output_buffers(self, block: BlockRV) -> List[tir.Buffer]:
def get_output_buffers(self, block: BlockRV) -> list[tir.Buffer]:
buffers = []
for write in self.sch.get(block).writes:
buffers.append(write.buffer)
return buffers
def get_buffers(self, block: BlockRV) -> List[tir.Buffer]:
def get_buffers(self, block: BlockRV) -> list[tir.Buffer]:
return self.get_input_buffers(block) + self.get_output_buffers(block)
def get_producer_blocks(self, block: BlockRV) -> List[BlockRV]:
def get_producer_blocks(self, block: BlockRV) -> list[BlockRV]:
return self.sch.get_producers(block)
def get_consumer_blocks(self, block: BlockRV) -> List[BlockRV]:
def get_consumer_blocks(self, block: BlockRV) -> list[BlockRV]:
return self.sch.get_consumers(block)
@dataclass
class Edge:
src_node: 'Node'
dst_node: 'Node'
src_node: Node
dst_node: Node
src_id: int
dst_id: int
class Node(object):
class Node:
def __init__(self, tags: Optional[Dict] = None, name: str = "Node") -> None:
def __init__(self, tags: dict | None = None, name: str = "Node") -> None:
self.name = name
if tags is None:
tags = {}
......@@ -100,10 +101,10 @@ class Node(object):
self._in_edges = []
self._shapes = []
self._dtypes = []
self._tag: Dict = {}
self._tag: dict = {}
self.update_tags(tags)
def update_tags(self, tags: Dict) -> None:
def update_tags(self, tags: dict) -> None:
for tag in tags:
self.add_tag(tag, tags[tag])
......@@ -125,11 +126,11 @@ class Node(object):
return False
@property
def inputs(self) -> List[Edge]:
def inputs(self) -> list[Edge]:
return self._in_edges
@property
def outputs(self) -> List[Edge]:
def outputs(self) -> list[Edge]:
return self._out_edges
def set_inputs(self, i: int, edge: Edge):
......@@ -153,10 +154,10 @@ class Node(object):
assert self._dtypes[id] == dtype, (self._dtypes, dtype)
self._dtypes[id] = dtype
def get_shape(self, id: int = 0) -> List[int]:
def get_shape(self, id: int = 0) -> list[int]:
return self._shapes[id]
def set_shape(self, shape: List[int], id=0, overwrite=False) -> None:
def set_shape(self, shape: list[int], id=0, overwrite=False) -> None:
if len(self._shapes) <= id:
self._shapes.extend([None for _ in range(id - len(self._shapes) + 1)])
# elif self._shapes[id] is not None and not overwrite:
......@@ -191,15 +192,15 @@ class PrimFuncNode(Node):
def __init__(self,
prim_func: PrimFunc,
tags: Optional[Dict] = None,
tags: dict | None = None,
name: str = "PrimFuncNode") -> None:
super().__init__(tags, name=name)
self.prim_func = self._specialize_func(prim_func)
self.sch: tir.Schedule = tir.Schedule(self.prim_func)
self.block_analyzer: BlockAnalyzer = BlockAnalyzer(self.sch)
self.schedule_stages: List[BlockRV] = []
self.blocks: List[BlockRV] = []
self.output_blocks: List[BlockRV] = None
self.schedule_stages: list[BlockRV] = []
self.blocks: list[BlockRV] = []
self.output_blocks: list[BlockRV] = None
self.reduction_block: BlockRV = None
self.raxis = []
self.input_buffers = []
......@@ -219,7 +220,7 @@ class PrimFuncNode(Node):
self.set_dtype(tvm.DataType(buffer.dtype), output_id)
def _assign_placeholder_node(self):
inputs: List[Node] = []
inputs: list[Node] = []
for buffer in self.input_buffers:
inputs.append(PlaceHolderNode(buffer.name))
......@@ -301,8 +302,8 @@ class PrimFuncNode(Node):
else:
return value
@functools.lru_cache()
def get_space_dim(self) -> List[int]:
@functools.lru_cache
def get_space_dim(self) -> list[int]:
dim_size = []
if self.reduction_block:
block_info = self.block_analyzer.get_block_info(self.reduction_block)
......@@ -333,7 +334,7 @@ class PrimFuncNode(Node):
def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType:
return tvm.DataType(buffer.dtype)
def propagate(self, tile, rstep: Optional[Dict] = None, targets=None):
def propagate(self, tile, rstep: dict | None = None, targets=None):
if rstep is None:
rstep = {}
shape = {
......@@ -343,7 +344,7 @@ class PrimFuncNode(Node):
}
return self.ana.infer(shape, rstep, targets)
def propagate_inputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
def propagate_inputs(self, tile, rstep: dict | None = None) -> list[list[int]]:
if rstep is None:
rstep = {}
read_idx_offset = len(self.input_buffers)
......@@ -363,7 +364,7 @@ class PrimFuncNode(Node):
return results
# Propagate inputs only on reduction block
def propagate_inputs_on_reduction(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
def propagate_inputs_on_reduction(self, tile, rstep: dict | None = None) -> list[list[int]]:
if rstep is None:
rstep = {}
reduction_block = self.reduction_block
......@@ -386,7 +387,7 @@ class PrimFuncNode(Node):
results.append(trimmed_shape)
return results
def propagate_outputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
def propagate_outputs(self, tile, rstep: dict | None = None) -> list[list[int]]:
if rstep is None:
rstep = {}
read_idx_offset = len(self.input_buffers)
......@@ -399,9 +400,7 @@ class PrimFuncNode(Node):
results.append(trimmed_shape)
return results
def propagate_reduction_inputs(self,
shape,
rstep: Optional[Dict] = None) -> Dict[str, List[int]]:
def propagate_reduction_inputs(self, shape, rstep: dict | None = None) -> dict[str, list[int]]:
if rstep is None:
rstep = {}
if self.reduction_block is None:
......@@ -418,8 +417,8 @@ class PrimFuncNode(Node):
for b in self.block_analyzer.get_input_buffers(self.reduction_block)
}
@functools.lru_cache()
def infer_tensorcore_axis(self) -> Tuple[int]:
@functools.lru_cache
def infer_tensorcore_axis(self) -> tuple[int]:
# the axis is fixed for a given expression, so it is inferred once and cached
assert self.get_tag("tensorcore_config")
......@@ -461,7 +460,7 @@ class PrimFuncNode(Node):
tc_axis = (A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n)
return tc_axis
def footprint(self, shape, rstep, stride_map: Optional[Dict] = None) -> int:
def footprint(self, shape, rstep, stride_map: dict | None = None) -> int:
if stride_map is None:
stride_map = {}
result = 0
......@@ -510,7 +509,7 @@ class PrimFuncNode(Node):
result += buffer_len
return result, cached_tensor
def get_input_buffers(self) -> List[tir.Buffer]:
def get_input_buffers(self) -> list[tir.Buffer]:
return self.block_analyzer.input_buffers
......@@ -537,7 +536,7 @@ class OutputNode(Node):
return "output"
def topo_order(list_of_nodes) -> List[Node]:
def topo_order(list_of_nodes) -> list[Node]:
input_ready_count = {node: len(node.inputs) for node in list_of_nodes}
ready = list(filter(lambda node: input_ready_count[node] == 0, list_of_nodes))
output_list = []
......@@ -557,7 +556,7 @@ def topo_order(list_of_nodes) -> List[Node]:
return output_list
def find_topo_sort_priority(output_node_list) -> List[Node]:
def find_topo_sort_priority(output_node_list) -> list[Node]:
import sys
sys.setrecursionlimit(10000)
......@@ -591,7 +590,7 @@ def find_topo_sort_priority(output_node_list) -> List[Node]:
return topo_order
def find_topo_sort(output_node_list) -> List[Node]:
def find_topo_sort(output_node_list) -> list[Node]:
def topo_sort_dfs(node, visited, topo_order):
if node in visited:
......
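Two details in the hunk above are worth spelling out: functools.lru_cache can be used as a bare decorator from Python 3.8 onward, which is why pyupgrade drops the empty parentheses, and caching bound methods is exactly the pattern the B019 entry in the ruff ignore list tolerates. A standalone illustration:

import functools

@functools.lru_cache              # equivalent to @functools.lru_cache() on Python 3.8+
def fib(n: int) -> int:
    return n if n < 2 else fib(n - 1) + fib(n - 2)

# When applied to methods (as in PrimFuncNode above), the cache also keeps self alive,
# which is the memory-leak pattern flake8-bugbear B019 normally warns about.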
from typing import List
from __future__ import annotations
import numpy as np
def get_all_factors(n: int) -> List[int]:
def get_all_factors(n: int) -> list[int]:
# Calculate the square root of n and round it up to the nearest integer
n0 = int(np.ceil(np.sqrt(n)))
......@@ -16,7 +16,7 @@ def get_all_factors(n: int) -> List[int]:
return [int(x) for x in np.concatenate([val, mid, n // val[::-1]])]
def factorize(n: int) -> List[int]:
def factorize(n: int) -> list[int]:
i = 2 # Start with the smallest prime number
result = []
......@@ -30,7 +30,7 @@ def factorize(n: int) -> List[int]:
return result
def coalesced_factor(subtensor: List[int], tensor: List[int]) -> int:
def coalesced_factor(subtensor: list[int], tensor: list[int]) -> int:
# If the last dimensions of the subtensor and tensor differ, or the subtensor has only one dimension
if subtensor[-1] != tensor[-1] or len(subtensor) == 1:
return subtensor[-1]
......@@ -39,7 +39,7 @@ def coalesced_factor(subtensor: List[int], tensor: List[int]) -> int:
return subtensor[-1] * coalesced_factor(subtensor[:-1], tensor[:-1])
def coalesced_tensor_shape(subtensor: List[int], tensor: List[int], transaction_size: int) -> int:
def coalesced_tensor_shape(subtensor: list[int], tensor: list[int], transaction_size: int) -> int:
# Calculate the total number of elements in the subtensor
bytes = int(np.prod(subtensor))
......