"testing/vscode:/vscode.git/clone" did not exist on "6efeb7437c8d529fc9e9461c6c924b0c91d064a1"
Unverified commit f14fb111, authored by Yichen Yan, committed by GitHub

[Lint] Enable pyupgrade linter in ruff (#963)

* update rules

* ruff check

* other fixes

* fmt

* do not touch examples

* fmt
parent 4f3523dc
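pyupgrade corresponds to ruff's "UP" rule family, normally enabled through the lint rule selection (select / extend-select) in the project's ruff configuration; the exact configuration change is not part of this diff excerpt. The hunks below are the mechanical rewrites those rules produce. A minimal illustrative sketch of the patterns involved (placeholder names, not repository code):

# Sketch only: representative pyupgrade ("UP") rewrites, not code from this repository.
from __future__ import annotations


# Before: def emit(tags: Optional[Dict[str, int]] = None) -> List[str]:
# After:  built-in generics (PEP 585) and X | None unions (PEP 604)
def emit(tags: dict[str, int] | None = None) -> list[str]:
    if tags is None:
        tags = {}
    # Before: "count=%d" % len(tags)  or  "count={}".format(len(tags))
    # After:  f-strings
    return [f"count={len(tags)}"]


# Before: class Statement(object):
# After:  the redundant `object` base class is dropped
class Statement:
    pass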
"""Policy for cuda core schedule""" """Policy for cuda core schedule"""
from __future__ import annotations
import functools import functools
import math import math
from queue import PriorityQueue from queue import PriorityQueue
from typing import Iterable, Dict, List, Optional from typing import Iterable
import numpy as np import numpy as np
import tvm import tvm
...@@ -22,11 +23,11 @@ class DefaultPolicy: ...@@ -22,11 +23,11 @@ class DefaultPolicy:
""" """
func: tvm.tir.PrimFunc func: tvm.tir.PrimFunc
nodes: List[PrimFuncNode] = [] nodes: list[PrimFuncNode] = []
arch: TileDevice arch: TileDevice
tags: Dict tags: dict
def __init__(self, arch: TileDevice, tags: Optional[Dict] = None) -> None: def __init__(self, arch: TileDevice, tags: dict | None = None) -> None:
if tags is None: if tags is None:
tags = {} tags = {}
...@@ -38,20 +39,17 @@ class DefaultPolicy: ...@@ -38,20 +39,17 @@ class DefaultPolicy:
def from_prim_func(cls, def from_prim_func(cls,
func: tvm.tir.PrimFunc, func: tvm.tir.PrimFunc,
arch: TileDevice, arch: TileDevice,
tags: Optional[Dict] = None, tags: dict | None = None,
name: str = "PrimFuncNode"): name: str = "PrimFuncNode"):
return cls(arch, tags)._init_with_prim_func(func, name) return cls(arch, tags)._init_with_prim_func(func, name)
@classmethod @classmethod
def from_output_nodes(cls, def from_output_nodes(cls, nodes: list[OutputNode], arch: TileDevice, tags: dict | None = None):
nodes: List[OutputNode],
arch: TileDevice,
tags: Optional[Dict] = None):
return cls(arch, tags)._init_with_output_nodes(nodes) return cls(arch, tags)._init_with_output_nodes(nodes)
def _init_with_prim_func(self, def _init_with_prim_func(self,
func: tvm.tir.PrimFunc, func: tvm.tir.PrimFunc,
name: str = "PrimFuncNode") -> "DefaultPolicy": name: str = "PrimFuncNode") -> DefaultPolicy:
if func is not None and isinstance(func, tvm.tir.PrimFunc): if func is not None and isinstance(func, tvm.tir.PrimFunc):
self.func = func self.func = func
self.prim_func_node = PrimFuncNode(self.func, tags=self.tags, name=name) self.prim_func_node = PrimFuncNode(self.func, tags=self.tags, name=name)
...@@ -61,7 +59,7 @@ class DefaultPolicy: ...@@ -61,7 +59,7 @@ class DefaultPolicy:
self._init_with_output_nodes(output_nodes) self._init_with_output_nodes(output_nodes)
return self return self
def _init_with_output_nodes(self, output_nodes: List[OutputNode]): def _init_with_output_nodes(self, output_nodes: list[OutputNode]):
self.ordered_nodes = list( self.ordered_nodes = list(
filter(lambda n: not n.is_placeholder() and not n.is_output(), filter(lambda n: not n.is_placeholder() and not n.is_output(),
find_topo_sort(output_nodes))) find_topo_sort(output_nodes)))
...@@ -78,7 +76,7 @@ class DefaultPolicy: ...@@ -78,7 +76,7 @@ class DefaultPolicy:
self.output_nodes.append(node) self.output_nodes.append(node)
return self return self
def emit_config(self, topk: int) -> List[Hint]: def emit_config(self, topk: int) -> list[Hint]:
base_tile = self.get_base_tile() base_tile = self.get_base_tile()
if base_tile is None: if base_tile is None:
return [] return []
...@@ -557,7 +555,7 @@ class DefaultPolicy: ...@@ -557,7 +555,7 @@ class DefaultPolicy:
node, td) node, td)
td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map
def compute_tile_dict(self, output_tile: List[int], rstep_map) -> TileDict: def compute_tile_dict(self, output_tile: list[int], rstep_map) -> TileDict:
""" """
Computes and returns a TileDict object for a given output tile configuration and reduction step map. Computes and returns a TileDict object for a given output tile configuration and reduction step map.
...@@ -624,7 +622,7 @@ class DefaultPolicy: ...@@ -624,7 +622,7 @@ class DefaultPolicy:
return True return True
def recommend_block_size(self, td: TileDict) -> List[int]: def recommend_block_size(self, td: TileDict) -> list[int]:
""" """
Recommends optimal block sizes based on the TileDict configuration. Recommends optimal block sizes based on the TileDict configuration.
......
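Each rewritten module also gains "from __future__ import annotations" at the top. A minimal sketch of why that import matters here, under the assumption that the package still supports Python versions older than 3.10 (the Hint class below is a stand-in, not repository code):

# Sketch: with PEP 563, annotations are stored as strings and never evaluated at import
# time, so the PEP 585/604 spellings are safe in annotations on Python 3.8/3.9 as well.
from __future__ import annotations


class Hint:  # stand-in for tilelang's Hint
    pass


def emit_config(topk: int) -> list[Hint] | None:
    # Without the future import, defining this function on Python 3.8 would fail with
    # "TypeError: 'type' object is not subscriptable" when the annotation is evaluated.
    return [Hint() for _ in range(topk)] if topk > 0 else None


print(emit_config.__annotations__)  # {'topk': 'int', 'return': 'list[Hint] | None'}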
"""Policy for tensorcore schedule""" """Policy for tensorcore schedule"""
from __future__ import annotations
import tvm import tvm
from typing import Dict, List, Tuple, Optional
import numpy as np import numpy as np
import logging import logging
from ..hint import Hint, Stride, TileDict, IntrinInfo from ..hint import Hint, Stride, TileDict, IntrinInfo
...@@ -19,9 +19,9 @@ class TensorCorePolicy(DefaultPolicy): ...@@ -19,9 +19,9 @@ class TensorCorePolicy(DefaultPolicy):
wmma_k: int = 16 wmma_k: int = 16
pipeline_stage: int = 1 pipeline_stage: int = 1
use_async_copy: bool = False use_async_copy: bool = False
block_reduction_depth: Optional[int] = None block_reduction_depth: int | None = None
def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: Optional[str] = None): def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: str | None = None):
super()._init_with_prim_func(func, name) super()._init_with_prim_func(func, name)
self._legalize_info() self._legalize_info()
return self return self
...@@ -52,9 +52,9 @@ class TensorCorePolicy(DefaultPolicy): ...@@ -52,9 +52,9 @@ class TensorCorePolicy(DefaultPolicy):
def _compute_tc_strides( def _compute_tc_strides(
self, self,
node: PrimFuncNode, node: PrimFuncNode,
tile: List[int], tile: list[int],
rstep: Optional[Dict[str, int]] = None, rstep: dict[str, int] | None = None,
) -> Tuple[Stride, Stride, Stride]: ) -> tuple[Stride, Stride, Stride]:
if rstep is None: if rstep is None:
rstep = {} rstep = {}
# strides was used for shared memory padding. which is necessary for avoiding # strides was used for shared memory padding. which is necessary for avoiding
......
"""Rasteration Plan For L2 Cache Locality""" """Rasteration Plan For L2 Cache Locality"""
from __future__ import annotations
from typing import List
class Rasterization: class Rasterization:
...@@ -10,7 +9,7 @@ class Rasterization: ...@@ -10,7 +9,7 @@ class Rasterization:
def __init__(self) -> None: def __init__(self) -> None:
pass pass
def get_code(self) -> List[str]: def get_code(self) -> list[str]:
raise NotImplementedError() raise NotImplementedError()
@property @property
...@@ -27,7 +26,7 @@ class NoRasterization(Rasterization): ...@@ -27,7 +26,7 @@ class NoRasterization(Rasterization):
def __repr__(self) -> str: def __repr__(self) -> str:
return "<NoRasterization>" return "<NoRasterization>"
def get_code(self) -> List[str]: def get_code(self) -> list[str]:
return [] return []
...@@ -47,7 +46,7 @@ class Rasterization2DRow(Rasterization): ...@@ -47,7 +46,7 @@ class Rasterization2DRow(Rasterization):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Rasterization2DRow({self.panel_width_})>" return f"<Rasterization2DRow({self.panel_width_})>"
def get_code(self) -> List[str]: def get_code(self) -> list[str]:
raise NotImplementedError() raise NotImplementedError()
...@@ -84,10 +83,10 @@ __device__ __inline__ dim3 rasterization2DColumn(const int panel_width) { ...@@ -84,10 +83,10 @@ __device__ __inline__ dim3 rasterization2DColumn(const int panel_width) {
} }
""" """
def get_code(self, panel_width: int = None) -> List[str]: def get_code(self, panel_width: int = None) -> list[str]:
if panel_width is None: if panel_width is None:
panel_width = self.panel_width_ panel_width = self.panel_width_
return [ return [
self.get_device_function(), self.get_device_function(),
"const dim3 blockIdx = rasterization2DColumn({});\n".format(panel_width), f"const dim3 blockIdx = rasterization2DColumn({panel_width});\n",
] ]
+from __future__ import annotations
from collections import OrderedDict
-from typing import Dict, List
from tvm import arith
-class Statement():
+class Statement:
    def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict,
                 range_map: OrderedDict):
@@ -18,12 +18,12 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound):
    return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value))
-class InputShapeInference():
+class InputShapeInference:
-    def __init__(self, deps: List[Statement]):
+    def __init__(self, deps: list[Statement]):
        self.deps = deps
-    def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, int]):
+    def _infer(self, shape: dict[str, list[arith.ConstIntBound]], rstep: dict[str, int]):
        shape = shape.copy()
        ana = arith.Analyzer()
        for dep in reversed(self.deps):
@@ -44,7 +44,7 @@ class InputShapeInference():
            shape[name] = [c.max_value - c.min_value + 1 for c in bounds]
        return shape
-    def infer(self, shape, rstep: Dict[str, int] = None):
+    def infer(self, shape, rstep: dict[str, int] = None):
        if rstep is None:
            rstep = {}
        if isinstance(shape, (list, tuple)):
...
-from typing import Dict, List, Tuple, Set, Mapping
+from __future__ import annotations
+from typing import Mapping
from tvm.tir.schedule.schedule import BlockRV
from tvm.ir import structural_equal
from tvm import arith, tir
@@ -15,7 +16,7 @@ class Statement:
        self.reverse_bound_inference = {}
-    def make_reverse(self, input_name: str, input_iter: List[tir.PrimExpr]):
+    def make_reverse(self, input_name: str, input_iter: list[tir.PrimExpr]):
        if len(self.block_analyzer.get_reduce_axis(self.block)) > 0:
            return None
        if len(self.dependent_region[input_name]) != 1:
@@ -47,7 +48,7 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound):
    return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value))
-class TensorDepNode(object):
+class TensorDepNode:
    """
    For tensor dependency analysis.
    """
@@ -76,7 +77,7 @@ class TensorDepNode(object):
        return self.name
-class DependencyAnalysis(object):
+class DependencyAnalysis:
    def __init__(self, deps):
        self.deps = deps
@@ -89,7 +90,7 @@ class DependencyAnalysis(object):
        This is a workaround for the issue that we have two same ops' fuse case.
        See https://github.com/apache/tvm/issues/16433
        """
-        _names: Set = set()
+        _names: set = set()
        name2dep: Mapping = {}
        for dep in deps:
            output_buffer = dep.block_analyzer.get_output_buffers(dep.block)[0]
@@ -168,7 +169,7 @@ class DependencyAnalysis(object):
class InputShapeInference:
-    def __init__(self, deps: List[Statement]):
+    def __init__(self, deps: list[Statement]):
        self.deps = deps
        self.target_mapping = {}
        self.buffer_mapping = {}
@@ -179,7 +180,7 @@ class InputShapeInference:
        self.dep_analysis = DependencyAnalysis(self.deps)
        self.dep_analysis.analyze()
-    def construct_dependency_target(self, targets: Tuple[str]):
+    def construct_dependency_target(self, targets: tuple[str]):
        if targets in self.target_mapping:
            return self.target_mapping[targets]
        # should be buffer name instead of block name
@@ -242,8 +243,8 @@ class InputShapeInference:
        return input_vars, mapping
    def infer(self,
-              shape: Dict[str, List[arith.ConstIntBound]],
-              rstep: Dict[str, int] = None,
+              shape: dict[str, list[arith.ConstIntBound]],
+              rstep: dict[str, int] = None,
              targets=None):
        if rstep is None:
            rstep = {}
@@ -351,10 +352,10 @@ def walk_indice(expr):
    elif isinstance(expr, tir.Call):
        return None
    else:
-        raise Exception("Unhandled node type in walk_indice(): %s" % expr)
+        raise Exception(f"Unhandled node type in walk_indice(): {expr}")
-def _extract_dependent_region(block_analyzer, block: BlockRV) -> Dict[str, List[tir.PrimExpr]]:
+def _extract_dependent_region(block_analyzer, block: BlockRV) -> dict[str, list[tir.PrimExpr]]:
    input_buffers = block_analyzer.get_input_buffers(block)
    dependent_region = {buffer.name: [] for buffer in input_buffers}
...
# Import necessary modules and classes
+from __future__ import annotations
from abc import ABC, abstractmethod # For defining abstract base classes
from dataclasses import dataclass, field # For defining data classes
from ..arch import ( # Import architecture-related utilities and classes
    TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch)
from ..roller.hint import Hint # Import the Hint class
from ..roller.node import OutputNode # Import the OutputNode class
-from typing import List # For type hinting
from tvm.tir import PrimFunc # Import PrimFunc for handling tensor IR functions
@@ -24,10 +24,10 @@ class BaseTemplate(ABC):
    _func: PrimFunc = field(default=None, init=False, repr=False)
    # The outputs nodes associated with this template, initially None
-    _output_nodes: List[OutputNode] = field(default=None, init=False, repr=False)
+    _output_nodes: list[OutputNode] = field(default=None, init=False, repr=False)
    @abstractmethod
-    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
        """
        Abstract method that must be implemented by subclasses.
        It should return a list of hardware-aware configurations (hints)
@@ -42,7 +42,7 @@ class BaseTemplate(ABC):
        """
        pass
-    def with_arch(self, arch: TileDevice) -> "BaseTemplate":
+    def with_arch(self, arch: TileDevice) -> BaseTemplate:
        """
        Sets the architecture for this template and returns itself.
@@ -110,7 +110,7 @@ class BaseTemplate(ABC):
        """
        raise NotImplementedError("initialize_function is not implemented")
-    def set_function(self, func: PrimFunc) -> "BaseTemplate":
+    def set_function(self, func: PrimFunc) -> BaseTemplate:
        """
        Sets the function for this template and returns itself.
@@ -123,7 +123,7 @@ class BaseTemplate(ABC):
        self._func = func
        return self
-    def set_output_nodes(self, output_nodes: List[OutputNode]) -> "BaseTemplate":
+    def set_output_nodes(self, output_nodes: list[OutputNode]) -> BaseTemplate:
        """
        Sets the output nodes for this template and returns itself.
@@ -136,7 +136,7 @@ class BaseTemplate(ABC):
        self._output_nodes = output_nodes
        return self
-    def recommend_hints(self, topk: int = 10) -> List[Hint]:
+    def recommend_hints(self, topk: int = 10) -> list[Hint]:
        """
        Provides a list of recommended hardware-aware configurations.
@@ -159,7 +159,7 @@ class BaseTemplate(ABC):
        return self._arch
    @property
-    def output_nodes(self) -> List[OutputNode]:
+    def output_nodes(self) -> list[OutputNode]:
        """
        Returns the output nodes associated with this template.
...
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te, tir
from ..roller import Hint
-from typing import List
from ..utils import get_roller_hints_from_func
@@ -44,7 +44,7 @@ class ConvTemplate(BaseTemplate):
    accum_dtype: str = "float16" # Data type for accumulation
    with_bias: bool = False # Whether to add a bias term
-    def get_hardware_aware_configs(self, arch=None, topk=10) -> List[Hint]:
+    def get_hardware_aware_configs(self, arch=None, topk=10) -> list[Hint]:
        """
        Retrieves optimized hardware-aware configurations.
...
# Import necessary modules
+from __future__ import annotations
from dataclasses import dataclass # Used for defining data classes
from .base import BaseTemplate # Importing the base class for templates
from tvm import te # Importing TVM's tensor expression module
from ..arch import TileDevice # Importing TileDevice for hardware-specific configurations
from ..roller import Hint # Importing Hint for optimization hints
-from typing import List # Importing List type hint
from ..utils import get_roller_hints_from_func # Function to obtain optimization hints
@@ -19,10 +19,10 @@ class ElementwiseTemplate(BaseTemplate):
    """
    # OP Related Config
-    shape: List[int] = None # Shape of the tensor
+    shape: list[int] = None # Shape of the tensor
    dtype: str = "float16" # Data type of the tensor
-    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
        """
        Retrieves hardware-aware optimization configurations.
...
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
from ..roller import PrimFuncNode, OutputNode, Edge
-from typing import List
from ..utils import get_roller_hints_from_output_nodes, get_tensorized_func_and_tags
@dataclass
class FlashAttentionTemplate(BaseTemplate):
-    _output_nodes: List[OutputNode] = None
+    _output_nodes: list[OutputNode] = None
    # Operation-related configuration parameters
    batch_size: int = 1
@@ -26,7 +26,7 @@ class FlashAttentionTemplate(BaseTemplate):
    out_dtype: str = "float16"
    accum_dtype: str = "float16"
-    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
        """
        Retrieves optimized hardware-aware configurations.
...
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
-from typing import List
from ..utils import get_roller_hints_from_func
@@ -25,7 +25,7 @@ class GEMVTemplate(BaseTemplate):
    accum_dtype: str = "float16" # Accumulation data type
    with_bias: bool = False # Whether to add a bias term
-    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
        """
        Retrieves optimized hardware-aware configurations.
...
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
-from typing import List, Union
from ..utils import get_roller_hints_from_func
@@ -11,11 +11,11 @@ from ..utils import get_roller_hints_from_func
class GeneralReductionTemplate(BaseTemplate):
    # OP Related Config
-    structure: Union[str, List[str]] = None
+    structure: str | list[str] = None
-    shape: List[int] = None
+    shape: list[int] = None
    dtype: str = "float16"
-    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
        roller_hints = get_roller_hints_from_func(
            self._func, arch=arch, topk=topk, allow_gemv=False)
        return roller_hints
...
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
-from typing import List
from ..utils import get_roller_hints_from_func
@@ -38,7 +38,7 @@ class MatmulTemplate(BaseTemplate):
    accum_dtype: str = "float16" # Data type for accumulation
    with_bias: bool = False # Whether to add a bias term
-    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
        """
        Retrieves optimized hardware-aware configurations.
...
-from typing import List, Optional, Union
+from __future__ import annotations
from tvm import tir, IRModule
from tvm.tir import PrimFunc
from .arch import TileDevice
@@ -26,11 +26,11 @@ def get_rasterization_code(pannel_width: int = 8) -> str:
    """
-def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
+def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
                               arch: TileDevice,
                               topk: int = 10,
                               tensorcore_only: bool = False,
-                              allow_gemv: bool = False) -> Optional[List[Hint]]:
+                              allow_gemv: bool = False) -> list[Hint] | None:
    func = None
    if isinstance(func_or_module, tir.PrimFunc):
        func = func_or_module
@@ -69,11 +69,10 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
    return roller_hints
-def get_roller_hints_from_output_nodes(
-        output_nodes: List[OutputNode],
+def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],
        arch: TileDevice,
        topk: int = 10,
-        extra_tags: Optional[List[str]] = None) -> Optional[List[Hint]]:
+        extra_tags: list[str] | None = None) -> list[Hint] | None:
    assert isinstance(output_nodes, list), "The input should be a list of functions."
    lints = []
...
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Util to invoke C/C++ compilers in the system."""
+from __future__ import annotations
import functools
import os
import shutil
@@ -23,7 +24,6 @@ import platform
# pylint: disable=invalid-name
import sys
-from typing import Dict
from tvm.base import py_str
from tvm.contrib import tar as _tar
@@ -208,7 +208,7 @@ def create_executable(output, objects, options=None, cc=None, cwd=None, ccache_e
    raise ValueError("Unsupported platform")
-def get_global_symbol_section_map(path, *, nm=None) -> Dict[str, str]:
+def get_global_symbol_section_map(path, *, nm=None) -> dict[str, str]:
    """Get global symbols from a library via nm -g
    Parameters
...
@@ -54,7 +54,7 @@ def compile_hip(code,
    if target_format not in ["hsaco"]:
        raise ValueError("target_format must be hsaco")
    temp_code = temp.relpath("my_kernel.cc")
-    temp_target = temp.relpath("my_kernel.%s" % target_format)
+    temp_target = temp.relpath(f"my_kernel.{target_format}")
    with open(temp_code, "w") as out_file:
        out_file.write(code)
...
@@ -2,11 +2,11 @@
# modified from apache tvm python/tvm/contrib/nvcc.py
"""Utility to invoke nvcc compiler in the system"""
from __future__ import absolute_import as _abs
+from __future__ import annotations
import os
import subprocess
import warnings
-from typing import Tuple
from tilelang.env import CUDA_HOME
import tvm.ffi
@@ -299,7 +299,7 @@ def get_target_compute_version(target=None):
        "Try specifying it by adding '-arch=sm_xx' to your target.")
-def parse_compute_version(compute_version) -> Tuple[int, int]:
+def parse_compute_version(compute_version) -> tuple[int, int]:
    """Parse compute capability string to divide major and minor version
    Parameters
...
+from __future__ import annotations
import cuda.bindings.nvrtc as nvrtc
-from typing import Literal, Union, List, Optional, Tuple
+from typing import Literal
from tvm.target import Target
from .nvcc import get_target_compute_version, parse_compute_version
-def get_nvrtc_version() -> Tuple[int, int]:
+def get_nvrtc_version() -> tuple[int, int]:
    result, major, minor = nvrtc.nvrtcVersion()
    assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get NVRTC version: {result}"
    return (major, minor)
@@ -12,8 +13,8 @@ def get_nvrtc_version() -> Tuple[int, int]:
def compile_cuda(code: str,
                 target_format: Literal["ptx", "cubin"] = "ptx",
-                 arch: Optional[int] = None,
-                 options: Optional[Union[str, List[str]]] = None,
+                 arch: int | None = None,
+                 options: str | list[str] | None = None,
                 verbose: bool = False) -> bytearray:
    """Compile cuda code with NVRTC.
...
-from typing import Callable, Union
+from __future__ import annotations
+from typing import Callable
from tvm import register_func
from tvm.target import Target
@@ -25,7 +26,7 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = T
    register_func("tilelang_callback_hip_postproc", f=func, override=override)
-def register_cuda_postproc_callback(func: Union[Callable, bool] = None, override: bool = True):
+def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True):
    """Decorator for registering CUDA post-processing callback function.
    Can be used with or without parentheses:
@@ -58,7 +59,7 @@ def register_cuda_postproc_callback(func: Union[Callable, bool] = None, override
        raise TypeError("Invalid decorator usage")
-def register_hip_postproc_callback(func: Union[Callable, bool] = None, override: bool = True):
+def register_hip_postproc_callback(func: Callable | bool = None, override: bool = True):
    """Decorator for registering HIP post-processing callback function.
    Can be used with or without parentheses:
...
"""The compiler for TL programs.""" """The compiler for TL programs."""
from __future__ import annotations
import os import os
import os.path as osp import os.path as osp
from typing import Union, Optional, Callable, List from typing import Callable
import tilelang.transform import tilelang.transform
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm import tir from tvm import tir
...@@ -114,7 +115,7 @@ def tilelang_callback_hip_compile(code, target): ...@@ -114,7 +115,7 @@ def tilelang_callback_hip_compile(code, target):
return hsaco return hsaco
def extrac_params(func: tir.PrimFunc) -> List[KernelParam]: def extrac_params(func: tir.PrimFunc) -> list[KernelParam]:
tensor_types = [] tensor_types = []
for var in func.params: for var in func.params:
if var in func.buffer_map: if var in func.buffer_map:
...@@ -124,7 +125,7 @@ def extrac_params(func: tir.PrimFunc) -> List[KernelParam]: ...@@ -124,7 +125,7 @@ def extrac_params(func: tir.PrimFunc) -> List[KernelParam]:
return tensor_types return tensor_types
def canon_target_host(target: Union[str, Target], target_host: Optional[Union[str, Target]]): def canon_target_host(target: str | Target, target_host: str | Target | None):
if not target_host: if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "c" target_host = "llvm" if tvm.runtime.enabled("llvm") else "c"
...@@ -190,9 +191,9 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> ...@@ -190,9 +191,9 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
def lower( def lower(
func_or_mod: Union[tir.PrimFunc, tvm.IRModule], func_or_mod: tir.PrimFunc | tvm.IRModule,
target: Union[str, Target] = "auto", target: str | Target = "auto",
target_host: Optional[Union[str, Target]] = None, target_host: str | Target | None = None,
runtime_only=False, runtime_only=False,
enable_host_codegen=False, enable_host_codegen=False,
enable_device_compile=False, enable_device_compile=False,
......
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Union, Optional
import torch import torch
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm.tir import Buffer, IntImm, Var, PrimExpr from tvm.tir import Buffer, IntImm, Var, PrimExpr
...@@ -15,7 +15,7 @@ class KernelParam: ...@@ -15,7 +15,7 @@ class KernelParam:
Used to describe tensor or scalar parameters in TVM/PyTorch interop. Used to describe tensor or scalar parameters in TVM/PyTorch interop.
""" """
dtype: torch.dtype # PyTorch data type of the parameter dtype: torch.dtype # PyTorch data type of the parameter
shape: List[Union[int, Var]] # List of dimensions, can be integers or TVM variables shape: list[int | Var] # List of dimensions, can be integers or TVM variables
@classmethod @classmethod
def from_buffer(cls, buffer: Buffer): def from_buffer(cls, buffer: Buffer):
...@@ -111,7 +111,6 @@ class CompiledArtifact: ...@@ -111,7 +111,6 @@ class CompiledArtifact:
""" """
host_mod: tvm.IRModule # Host-side TVM IR module for managing kernel execution host_mod: tvm.IRModule # Host-side TVM IR module for managing kernel execution
device_mod: tvm.IRModule # Device-side TVM IR module containing the actual kernel code device_mod: tvm.IRModule # Device-side TVM IR module containing the actual kernel code
params: List[KernelParam] # List of parameters (tensors/scalars) used by the kernel params: list[KernelParam] # List of parameters (tensors/scalars) used by the kernel
kernel_source: str # Raw source code of the generated kernel kernel_source: str # Raw source code of the generated kernel
rt_mod: Optional[ rt_mod: tvm.runtime.Module | None = None # Runtime module for execution, may be lazily initialized
tvm.runtime.Module] = None # Runtime module for execution, may be lazily initialized