Unverified commit f14fb111, authored by Yichen Yan, committed by GitHub

[Lint] Enable pyupgrade linter in ruff (#963)

* update rules

* ruff check

* other fixes

* fmt

* do not touch examples

* fmt
parent 4f3523dc
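
This commit enables ruff's pyupgrade rule family ("UP") and applies its auto-fixes across the codebase; per the commit message, examples are deliberately left untouched. The configuration change itself is not shown in this excerpt, but enabling the family typically amounts to adding "UP" to the selected rules in the ruff configuration. A minimal before/after sketch of the rewrite patterns that recur throughout the diff (hypothetical function, not code from the repository):

```python
# Illustrative only: the kinds of pyupgrade fixes this commit applies.
from __future__ import annotations  # added so new-style annotations stay lazy (PEP 563)


class Policy:  # was: class Policy(object):  (UP004)
    def emit(self, tags: dict | None = None) -> list[str]:
        # was: tags: Optional[Dict] = None) -> List[str]:  (UP006/UP007/UP035)
        name = "kernel"
        # was: "%s:%s" % (name, k) or "{}:{}".format(name, k)  (UP031/UP032)
        return [f"{name}:{k}" for k in (tags or {})]
```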
"""Policy for cuda core schedule"""
+from __future__ import annotations
import functools
import math
from queue import PriorityQueue
-from typing import Iterable, Dict, List, Optional
+from typing import Iterable
import numpy as np
import tvm
@@ -22,11 +23,11 @@ class DefaultPolicy:
"""
func: tvm.tir.PrimFunc
-nodes: List[PrimFuncNode] = []
+nodes: list[PrimFuncNode] = []
arch: TileDevice
-tags: Dict
+tags: dict
-def __init__(self, arch: TileDevice, tags: Optional[Dict] = None) -> None:
+def __init__(self, arch: TileDevice, tags: dict | None = None) -> None:
if tags is None:
tags = {}
@@ -38,20 +39,17 @@ class DefaultPolicy:
def from_prim_func(cls,
func: tvm.tir.PrimFunc,
arch: TileDevice,
-tags: Optional[Dict] = None,
+tags: dict | None = None,
name: str = "PrimFuncNode"):
return cls(arch, tags)._init_with_prim_func(func, name)
@classmethod
-def from_output_nodes(cls,
-nodes: List[OutputNode],
-arch: TileDevice,
-tags: Optional[Dict] = None):
+def from_output_nodes(cls, nodes: list[OutputNode], arch: TileDevice, tags: dict | None = None):
return cls(arch, tags)._init_with_output_nodes(nodes)
def _init_with_prim_func(self,
func: tvm.tir.PrimFunc,
name: str = "PrimFuncNode") -> "DefaultPolicy":
name: str = "PrimFuncNode") -> DefaultPolicy:
if func is not None and isinstance(func, tvm.tir.PrimFunc):
self.func = func
self.prim_func_node = PrimFuncNode(self.func, tags=self.tags, name=name)
@@ -61,7 +59,7 @@ class DefaultPolicy:
self._init_with_output_nodes(output_nodes)
return self
-def _init_with_output_nodes(self, output_nodes: List[OutputNode]):
+def _init_with_output_nodes(self, output_nodes: list[OutputNode]):
self.ordered_nodes = list(
filter(lambda n: not n.is_placeholder() and not n.is_output(),
find_topo_sort(output_nodes)))
@@ -78,7 +76,7 @@ class DefaultPolicy:
self.output_nodes.append(node)
return self
-def emit_config(self, topk: int) -> List[Hint]:
+def emit_config(self, topk: int) -> list[Hint]:
base_tile = self.get_base_tile()
if base_tile is None:
return []
@@ -557,7 +555,7 @@ class DefaultPolicy:
node, td)
td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map
-def compute_tile_dict(self, output_tile: List[int], rstep_map) -> TileDict:
+def compute_tile_dict(self, output_tile: list[int], rstep_map) -> TileDict:
"""
Computes and returns a TileDict object for a given output tile configuration and reduction step map.
@@ -624,7 +622,7 @@ class DefaultPolicy:
return True
-def recommend_block_size(self, td: TileDict) -> List[int]:
+def recommend_block_size(self, td: TileDict) -> list[int]:
"""
Recommends optimal block sizes based on the TileDict configuration.
......
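
A note on why these annotation rewrites are safe even if the package still supports Python versions older than 3.10: with `from __future__ import annotations` (PEP 563), annotations are stored as strings and never evaluated at runtime, so the PEP 604 `X | None` and PEP 585 `list[...]` spellings above only need to parse. Expressions evaluated at runtime would still require 3.10+. A self-contained sketch (not from the diff):

```python
from __future__ import annotations


def emit_config(topk: int, tags: dict | None = None) -> list[int]:
    # The annotations are kept as the strings "dict | None" and "list[int]"
    # and never evaluated, so this runs on any Python with the future import.
    return list(range(topk))


print(emit_config(3))  # [0, 1, 2]
# By contrast, a runtime evaluation such as isinstance(1, int | str)
# genuinely requires Python 3.10+.
```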
"""Policy for tensorcore schedule"""
+from __future__ import annotations
import tvm
-from typing import Dict, List, Tuple, Optional
import numpy as np
import logging
from ..hint import Hint, Stride, TileDict, IntrinInfo
@@ -19,9 +19,9 @@ class TensorCorePolicy(DefaultPolicy):
wmma_k: int = 16
pipeline_stage: int = 1
use_async_copy: bool = False
-block_reduction_depth: Optional[int] = None
+block_reduction_depth: int | None = None
-def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: Optional[str] = None):
+def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: str | None = None):
super()._init_with_prim_func(func, name)
self._legalize_info()
return self
@@ -52,9 +52,9 @@ class TensorCorePolicy(DefaultPolicy):
def _compute_tc_strides(
self,
node: PrimFuncNode,
-tile: List[int],
-rstep: Optional[Dict[str, int]] = None,
-) -> Tuple[Stride, Stride, Stride]:
+tile: list[int],
+rstep: dict[str, int] | None = None,
+) -> tuple[Stride, Stride, Stride]:
if rstep is None:
rstep = {}
# strides were used for shared memory padding, which is necessary for avoiding
......
"""Rasteration Plan For L2 Cache Locality"""
-from typing import List
+from __future__ import annotations
class Rasterization:
@@ -10,7 +9,7 @@ class Rasterization:
def __init__(self) -> None:
pass
-def get_code(self) -> List[str]:
+def get_code(self) -> list[str]:
raise NotImplementedError()
@property
@@ -27,7 +26,7 @@ class NoRasterization(Rasterization):
def __repr__(self) -> str:
return "<NoRasterization>"
-def get_code(self) -> List[str]:
+def get_code(self) -> list[str]:
return []
@@ -47,7 +46,7 @@ class Rasterization2DRow(Rasterization):
def __repr__(self) -> str:
return f"<Rasterization2DRow({self.panel_width_})>"
-def get_code(self) -> List[str]:
+def get_code(self) -> list[str]:
raise NotImplementedError()
@@ -84,10 +83,10 @@ __device__ __inline__ dim3 rasterization2DColumn(const int panel_width) {
}
"""
-def get_code(self, panel_width: int = None) -> List[str]:
+def get_code(self, panel_width: int = None) -> list[str]:
if panel_width is None:
panel_width = self.panel_width_
return [
self.get_device_function(),
"const dim3 blockIdx = rasterization2DColumn({});\n".format(panel_width),
f"const dim3 blockIdx = rasterization2DColumn({panel_width});\n",
]
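
One pattern the linter deliberately leaves alone is visible in this hunk: `panel_width: int = None` keeps a `None` default that the annotation does not admit (the same implicit-Optional shape recurs below in `rstep: dict[str, int] = None` and `func: Callable | bool = None`). pyupgrade does not auto-fix implicit Optional; a stricter spelling, reconstructed from the hunk above and not part of the commit, would be:

```python
from __future__ import annotations


class Rasterization2DColumn:
    def __init__(self, panel_width: int = 4) -> None:
        self.panel_width_ = panel_width

    # Explicit-Optional version of the signature shown in the hunk.
    def get_code(self, panel_width: int | None = None) -> list[str]:
        if panel_width is None:
            panel_width = self.panel_width_
        return [f"const dim3 blockIdx = rasterization2DColumn({panel_width});\n"]
```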
+from __future__ import annotations
from collections import OrderedDict
-from typing import Dict, List
from tvm import arith
-class Statement():
+class Statement:
def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict,
range_map: OrderedDict):
@@ -18,12 +18,12 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound):
return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value))
-class InputShapeInference():
+class InputShapeInference:
-def __init__(self, deps: List[Statement]):
+def __init__(self, deps: list[Statement]):
self.deps = deps
-def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, int]):
+def _infer(self, shape: dict[str, list[arith.ConstIntBound]], rstep: dict[str, int]):
shape = shape.copy()
ana = arith.Analyzer()
for dep in reversed(self.deps):
@@ -44,7 +44,7 @@ class InputShapeInference():
shape[name] = [c.max_value - c.min_value + 1 for c in bounds]
return shape
-def infer(self, shape, rstep: Dict[str, int] = None):
+def infer(self, shape, rstep: dict[str, int] = None):
if rstep is None:
rstep = {}
if isinstance(shape, (list, tuple)):
......
-from typing import Dict, List, Tuple, Set, Mapping
+from __future__ import annotations
+from typing import Mapping
from tvm.tir.schedule.schedule import BlockRV
from tvm.ir import structural_equal
from tvm import arith, tir
@@ -15,7 +16,7 @@ class Statement:
self.reverse_bound_inference = {}
-def make_reverse(self, input_name: str, input_iter: List[tir.PrimExpr]):
+def make_reverse(self, input_name: str, input_iter: list[tir.PrimExpr]):
if len(self.block_analyzer.get_reduce_axis(self.block)) > 0:
return None
if len(self.dependent_region[input_name]) != 1:
@@ -47,7 +48,7 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound):
return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value))
-class TensorDepNode(object):
+class TensorDepNode:
"""
For tensor dependency analysis.
"""
@@ -76,7 +77,7 @@ class TensorDepNode(object):
return self.name
-class DependencyAnalysis(object):
+class DependencyAnalysis:
def __init__(self, deps):
self.deps = deps
@@ -89,7 +90,7 @@ class DependencyAnalysis(object):
This is a workaround for the issue that we have two same ops' fuse case.
See https://github.com/apache/tvm/issues/16433
"""
-_names: Set = set()
+_names: set = set()
name2dep: Mapping = {}
for dep in deps:
output_buffer = dep.block_analyzer.get_output_buffers(dep.block)[0]
@@ -168,7 +169,7 @@ class DependencyAnalysis(object):
class InputShapeInference:
-def __init__(self, deps: List[Statement]):
+def __init__(self, deps: list[Statement]):
self.deps = deps
self.target_mapping = {}
self.buffer_mapping = {}
@@ -179,7 +180,7 @@ class InputShapeInference:
self.dep_analysis = DependencyAnalysis(self.deps)
self.dep_analysis.analyze()
-def construct_dependency_target(self, targets: Tuple[str]):
+def construct_dependency_target(self, targets: tuple[str]):
if targets in self.target_mapping:
return self.target_mapping[targets]
# should be buffer name instead of block name
@@ -242,8 +243,8 @@ class InputShapeInference:
return input_vars, mapping
def infer(self,
-shape: Dict[str, List[arith.ConstIntBound]],
-rstep: Dict[str, int] = None,
+shape: dict[str, list[arith.ConstIntBound]],
+rstep: dict[str, int] = None,
targets=None):
if rstep is None:
rstep = {}
@@ -351,10 +352,10 @@ def walk_indice(expr):
elif isinstance(expr, tir.Call):
return None
else:
raise Exception("Unhandled node type in walk_indice(): %s" % expr)
raise Exception(f"Unhandled node type in walk_indice(): {expr}")
-def _extract_dependent_region(block_analyzer, block: BlockRV) -> Dict[str, List[tir.PrimExpr]]:
+def _extract_dependent_region(block_analyzer, block: BlockRV) -> dict[str, list[tir.PrimExpr]]:
input_buffers = block_analyzer.get_input_buffers(block)
dependent_region = {buffer.name: [] for buffer in input_buffers}
......
# Import necessary modules and classes
+from __future__ import annotations
from abc import ABC, abstractmethod # For defining abstract base classes
from dataclasses import dataclass, field # For defining data classes
from ..arch import ( # Import architecture-related utilities and classes
TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch)
from ..roller.hint import Hint # Import the Hint class
from ..roller.node import OutputNode # Import the OutputNode class
-from typing import List # For type hinting
from tvm.tir import PrimFunc # Import PrimFunc for handling tensor IR functions
@@ -24,10 +24,10 @@ class BaseTemplate(ABC):
_func: PrimFunc = field(default=None, init=False, repr=False)
# The outputs nodes associated with this template, initially None
-_output_nodes: List[OutputNode] = field(default=None, init=False, repr=False)
+_output_nodes: list[OutputNode] = field(default=None, init=False, repr=False)
@abstractmethod
-def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
"""
Abstract method that must be implemented by subclasses.
It should return a list of hardware-aware configurations (hints)
@@ -42,7 +42,7 @@ class BaseTemplate(ABC):
"""
pass
-def with_arch(self, arch: TileDevice) -> "BaseTemplate":
+def with_arch(self, arch: TileDevice) -> BaseTemplate:
"""
Sets the architecture for this template and returns itself.
@@ -110,7 +110,7 @@ class BaseTemplate(ABC):
"""
raise NotImplementedError("initialize_function is not implemented")
-def set_function(self, func: PrimFunc) -> "BaseTemplate":
+def set_function(self, func: PrimFunc) -> BaseTemplate:
"""
Sets the function for this template and returns itself.
@@ -123,7 +123,7 @@ class BaseTemplate(ABC):
self._func = func
return self
-def set_output_nodes(self, output_nodes: List[OutputNode]) -> "BaseTemplate":
+def set_output_nodes(self, output_nodes: list[OutputNode]) -> BaseTemplate:
"""
Sets the output nodes for this template and returns itself.
@@ -136,7 +136,7 @@ class BaseTemplate(ABC):
self._output_nodes = output_nodes
return self
-def recommend_hints(self, topk: int = 10) -> List[Hint]:
+def recommend_hints(self, topk: int = 10) -> list[Hint]:
"""
Provides a list of recommended hardware-aware configurations.
@@ -159,7 +159,7 @@ class BaseTemplate(ABC):
return self._arch
@property
-def output_nodes(self) -> List[OutputNode]:
+def output_nodes(self) -> list[OutputNode]:
"""
Returns the output nodes associated with this template.
......
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te, tir
from ..roller import Hint
-from typing import List
from ..utils import get_roller_hints_from_func
@@ -44,7 +44,7 @@ class ConvTemplate(BaseTemplate):
accum_dtype: str = "float16" # Data type for accumulation
with_bias: bool = False # Whether to add a bias term
-def get_hardware_aware_configs(self, arch=None, topk=10) -> List[Hint]:
+def get_hardware_aware_configs(self, arch=None, topk=10) -> list[Hint]:
"""
Retrieves optimized hardware-aware configurations.
......
# Import necessary modules
+from __future__ import annotations
from dataclasses import dataclass # Used for defining data classes
from .base import BaseTemplate # Importing the base class for templates
from tvm import te # Importing TVM's tensor expression module
from ..arch import TileDevice # Importing TileDevice for hardware-specific configurations
from ..roller import Hint # Importing Hint for optimization hints
-from typing import List # Importing List type hint
from ..utils import get_roller_hints_from_func # Function to obtain optimization hints
@@ -19,10 +19,10 @@ class ElementwiseTemplate(BaseTemplate):
"""
# OP Related Config
-shape: List[int] = None # Shape of the tensor
+shape: list[int] = None # Shape of the tensor
dtype: str = "float16" # Data type of the tensor
-def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
"""
Retrieves hardware-aware optimization configurations.
......
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
from ..roller import PrimFuncNode, OutputNode, Edge
-from typing import List
from ..utils import get_roller_hints_from_output_nodes, get_tensorized_func_and_tags
@dataclass
class FlashAttentionTemplate(BaseTemplate):
-_output_nodes: List[OutputNode] = None
+_output_nodes: list[OutputNode] = None
# Operation-related configuration parameters
batch_size: int = 1
@@ -26,7 +26,7 @@ class FlashAttentionTemplate(BaseTemplate):
out_dtype: str = "float16"
accum_dtype: str = "float16"
-def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
"""
Retrieves optimized hardware-aware configurations.
......
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
-from typing import List
from ..utils import get_roller_hints_from_func
@@ -25,7 +25,7 @@ class GEMVTemplate(BaseTemplate):
accum_dtype: str = "float16" # Accumulation data type
with_bias: bool = False # Whether to add a bias term
-def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
"""
Retrieves optimized hardware-aware configurations.
......
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
-from typing import List, Union
from ..utils import get_roller_hints_from_func
@@ -11,11 +11,11 @@ from ..utils import get_roller_hints_from_func
class GeneralReductionTemplate(BaseTemplate):
# OP Related Config
-structure: Union[str, List[str]] = None
-shape: List[int] = None
+structure: str | list[str] = None
+shape: list[int] = None
dtype: str = "float16"
-def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
roller_hints = get_roller_hints_from_func(
self._func, arch=arch, topk=topk, allow_gemv=False)
return roller_hints
......
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
-from typing import List
from ..utils import get_roller_hints_from_func
@@ -38,7 +38,7 @@ class MatmulTemplate(BaseTemplate):
accum_dtype: str = "float16" # Data type for accumulation
with_bias: bool = False # Whether to add a bias term
-def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
"""
Retrieves optimized hardware-aware configurations.
......
-from typing import List, Optional, Union
+from __future__ import annotations
from tvm import tir, IRModule
from tvm.tir import PrimFunc
from .arch import TileDevice
@@ -26,11 +26,11 @@ def get_rasterization_code(pannel_width: int = 8) -> str:
"""
-def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
+def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
arch: TileDevice,
topk: int = 10,
tensorcore_only: bool = False,
-allow_gemv: bool = False) -> Optional[List[Hint]]:
+allow_gemv: bool = False) -> list[Hint] | None:
func = None
if isinstance(func_or_module, tir.PrimFunc):
func = func_or_module
@@ -69,11 +69,10 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
return roller_hints
-def get_roller_hints_from_output_nodes(
-output_nodes: List[OutputNode],
-arch: TileDevice,
-topk: int = 10,
-extra_tags: Optional[List[str]] = None) -> Optional[List[Hint]]:
+def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],
+arch: TileDevice,
+topk: int = 10,
+extra_tags: list[str] | None = None) -> list[Hint] | None:
assert isinstance(output_nodes, list), "The input should be a list of functions."
lints = []
......
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Util to invoke C/C++ compilers in the system."""
+from __future__ import annotations
import functools
import os
import shutil
@@ -23,7 +24,6 @@ import platform
# pylint: disable=invalid-name
import sys
-from typing import Dict
from tvm.base import py_str
from tvm.contrib import tar as _tar
@@ -208,7 +208,7 @@ def create_executable(output, objects, options=None, cc=None, cwd=None, ccache_e
raise ValueError("Unsupported platform")
-def get_global_symbol_section_map(path, *, nm=None) -> Dict[str, str]:
+def get_global_symbol_section_map(path, *, nm=None) -> dict[str, str]:
"""Get global symbols from a library via nm -g
Parameters
......
@@ -54,7 +54,7 @@ def compile_hip(code,
if target_format not in ["hsaco"]:
raise ValueError("target_format must be hsaco")
temp_code = temp.relpath("my_kernel.cc")
-temp_target = temp.relpath("my_kernel.%s" % target_format)
+temp_target = temp.relpath(f"my_kernel.{target_format}")
with open(temp_code, "w") as out_file:
out_file.write(code)
......
@@ -2,11 +2,11 @@
# modified from apache tvm python/tvm/contrib/nvcc.py
"""Utility to invoke nvcc compiler in the system"""
from __future__ import absolute_import as _abs
+from __future__ import annotations
import os
import subprocess
import warnings
-from typing import Tuple
from tilelang.env import CUDA_HOME
import tvm.ffi
@@ -299,7 +299,7 @@ def get_target_compute_version(target=None):
"Try specifying it by adding '-arch=sm_xx' to your target.")
-def parse_compute_version(compute_version) -> Tuple[int, int]:
+def parse_compute_version(compute_version) -> tuple[int, int]:
"""Parse compute capability string to divide major and minor version
Parameters
......
+from __future__ import annotations
import cuda.bindings.nvrtc as nvrtc
-from typing import Literal, Union, List, Optional, Tuple
+from typing import Literal
from tvm.target import Target
from .nvcc import get_target_compute_version, parse_compute_version
-def get_nvrtc_version() -> Tuple[int, int]:
+def get_nvrtc_version() -> tuple[int, int]:
result, major, minor = nvrtc.nvrtcVersion()
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get NVRTC version: {result}"
return (major, minor)
@@ -12,8 +13,8 @@ def get_nvrtc_version() -> Tuple[int, int]:
def compile_cuda(code: str,
target_format: Literal["ptx", "cubin"] = "ptx",
-arch: Optional[int] = None,
-options: Optional[Union[str, List[str]]] = None,
+arch: int | None = None,
+options: str | list[str] | None = None,
verbose: bool = False) -> bytearray:
"""Compile cuda code with NVRTC.
......
-from typing import Callable, Union
+from __future__ import annotations
+from typing import Callable
from tvm import register_func
from tvm.target import Target
@@ -25,7 +26,7 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = T
register_func("tilelang_callback_hip_postproc", f=func, override=override)
-def register_cuda_postproc_callback(func: Union[Callable, bool] = None, override: bool = True):
+def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True):
"""Decorator for registering CUDA post-processing callback function.
Can be used with or without parentheses:
@@ -58,7 +59,7 @@ def register_cuda_postproc_callback(func: Union[Callable, bool] = None, override
raise TypeError("Invalid decorator usage")
-def register_hip_postproc_callback(func: Union[Callable, bool] = None, override: bool = True):
+def register_hip_postproc_callback(func: Callable | bool = None, override: bool = True):
"""Decorator for registering HIP post-processing callback function.
Can be used with or without parentheses:
......
"""The compiler for TL programs."""
+from __future__ import annotations
import os
import os.path as osp
-from typing import Union, Optional, Callable, List
+from typing import Callable
import tilelang.transform
from tilelang import tvm as tvm
from tvm import tir
@@ -114,7 +115,7 @@ def tilelang_callback_hip_compile(code, target):
return hsaco
-def extrac_params(func: tir.PrimFunc) -> List[KernelParam]:
+def extrac_params(func: tir.PrimFunc) -> list[KernelParam]:
tensor_types = []
for var in func.params:
if var in func.buffer_map:
@@ -124,7 +125,7 @@ def extrac_params(func: tir.PrimFunc) -> List[KernelParam]:
return tensor_types
-def canon_target_host(target: Union[str, Target], target_host: Optional[Union[str, Target]]):
+def canon_target_host(target: str | Target, target_host: str | Target | None):
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "c"
@@ -190,9 +191,9 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
def lower(
-func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
-target: Union[str, Target] = "auto",
-target_host: Optional[Union[str, Target]] = None,
+func_or_mod: tir.PrimFunc | tvm.IRModule,
+target: str | Target = "auto",
+target_host: str | Target | None = None,
runtime_only=False,
enable_host_codegen=False,
enable_device_compile=False,
......
"""The profiler and convert to torch utils"""
+from __future__ import annotations
from dataclasses import dataclass
-from typing import List, Union, Optional
import torch
from tilelang import tvm as tvm
from tvm.tir import Buffer, IntImm, Var, PrimExpr
@@ -15,7 +15,7 @@ class KernelParam:
Used to describe tensor or scalar parameters in TVM/PyTorch interop.
"""
dtype: torch.dtype # PyTorch data type of the parameter
-shape: List[Union[int, Var]] # List of dimensions, can be integers or TVM variables
+shape: list[int | Var] # List of dimensions, can be integers or TVM variables
@classmethod
def from_buffer(cls, buffer: Buffer):
@@ -111,7 +111,6 @@ class CompiledArtifact:
"""
host_mod: tvm.IRModule # Host-side TVM IR module for managing kernel execution
device_mod: tvm.IRModule # Device-side TVM IR module containing the actual kernel code
-params: List[KernelParam] # List of parameters (tensors/scalars) used by the kernel
+params: list[KernelParam] # List of parameters (tensors/scalars) used by the kernel
kernel_source: str # Raw source code of the generated kernel
-rt_mod: Optional[
-tvm.runtime.Module] = None # Runtime module for execution, may be lazily initialized
+rt_mod: tvm.runtime.Module | None = None # Runtime module for execution, may be lazily initialized