"docs/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "131302bc93291e78f54afb5aaf5d8d7b90bc0149"
Unverified commit f14fb111 authored by Yichen Yan, committed by GitHub

[Lint] Enable pyupgrade linter in ruff (#963)

* update rules

* ruff check

* other fixes

* fmt

* do not touch examples

* fmt
parent 4f3523dc
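
For context, here is a minimal, hypothetical sketch of the rewrites that ruff's pyupgrade rule family ("UP", typically enabled by adding it to the `select` list of the ruff lint configuration) performs; the hunks below are these same transformations applied across the tilelang sources: `typing.List`/`Dict`/`Tuple` become builtin generics, `Optional[X]` and `Union[X, Y]` become PEP 604 unions, quoted forward references become plain names, and `str.format` calls become f-strings. The names in the sketch are illustrative and not taken from the repository.

```python
# Hypothetical before/after sketch of the pyupgrade-style rewrites in this commit;
# the function and variable names are illustrative, not from the tilelang codebase.
from __future__ import annotations  # makes PEP 585/604 spellings valid in annotations on older Pythons


class Result:
    """Placeholder return type so the annotation below resolves."""


# Old style (flagged by ruff's UP rules):
#     from typing import Dict, List, Optional
#     def lookup(keys: List[str], table: Optional[Dict[str, int]] = None) -> "Result": ...
#     message = "missing key: {}".format(key)

# New style (what `ruff check --fix` rewrites the above to):
def lookup(keys: list[str], table: dict[str, int] | None = None) -> Result:
    table = table or {}  # builtin generics and `X | None` replace the typing aliases
    _ = [table.get(k) for k in keys]
    return Result()


key = "alpha"
message = f"missing key: {key}"  # f-string replaces str.format
```

Because `from __future__ import annotations` turns annotations into lazily evaluated strings at runtime, the PEP 585/604 spellings stay compatible with the older Python versions the project supports, which is why the diff adds that import to the files whose annotations it rewrites.
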
"""The language interface for tl programs.""" """The language interface for tl programs."""
from __future__ import annotations
from tvm.script.ir_builder.tir.frame import TIRFrame from tvm.script.ir_builder.tir.frame import TIRFrame
from tvm.ffi import register_object from tvm.ffi import register_object
from tilelang import _ffi_api from tilelang import _ffi_api
from .kernel import get_thread_bindings, get_thread_extents from .kernel import get_thread_bindings, get_thread_extents
from typing import List
@register_object("tl.WarpSpecializeFrame") @register_object("tl.WarpSpecializeFrame")
...@@ -45,7 +45,7 @@ def WarpSpecialize(*warp_group_idx): ...@@ -45,7 +45,7 @@ def WarpSpecialize(*warp_group_idx):
# only available for nvidia gpus. # only available for nvidia gpus.
warp_group_size = 128 warp_group_size = 128
warp_group_ids: List[int] = [] warp_group_ids: list[int] = []
for warp_group_id in warp_group_idx: for warp_group_id in warp_group_idx:
warp_group_ids.append(warp_group_id) warp_group_ids.append(warp_group_id)
......
"""Wrapping Layouts.""" """Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation # pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations
import tvm import tvm
from tvm.ir import Range from tvm.ir import Range
from tvm.tir import IterVar, Var, PrimExpr, IndexMap from tvm.tir import IterVar, Var, PrimExpr, IndexMap
from tilelang import _ffi_api from tilelang import _ffi_api
from tilelang.layout import Layout from tilelang.layout import Layout
from typing import List
@tvm.ffi.register_object("tl.Fragment") @tvm.ffi.register_object("tl.Fragment")
...@@ -123,7 +123,7 @@ class Fragment(Layout): ...@@ -123,7 +123,7 @@ class Fragment(Layout):
def repeat(self, def repeat(self,
repeats, repeats,
repeat_on_thread: bool = False, repeat_on_thread: bool = False,
lower_dim_first: bool = True) -> "Fragment": lower_dim_first: bool = True) -> Fragment:
""" """
Returns a new Fragment that repeats the iteration space a given number of times. Returns a new Fragment that repeats the iteration space a given number of times.
...@@ -143,7 +143,7 @@ class Fragment(Layout): ...@@ -143,7 +143,7 @@ class Fragment(Layout):
""" """
return _ffi_api.Fragment_repeat(self, repeats, repeat_on_thread, lower_dim_first) return _ffi_api.Fragment_repeat(self, repeats, repeat_on_thread, lower_dim_first)
def replicate(self, replicate: int) -> "Fragment": def replicate(self, replicate: int) -> Fragment:
""" """
Replicate the Fragment across a new thread dimension. Replicate the Fragment across a new thread dimension.
...@@ -159,7 +159,7 @@ class Fragment(Layout): ...@@ -159,7 +159,7 @@ class Fragment(Layout):
""" """
return _ffi_api.Fragment_replicate(self, replicate) return _ffi_api.Fragment_replicate(self, replicate)
def condense_rep_var(self) -> "Fragment": def condense_rep_var(self) -> Fragment:
""" """
Condense or fold the replicate variable into the existing iteration space. Condense or fold the replicate variable into the existing iteration space.
This operation may be used to reduce dimensionality if the replicate variable This operation may be used to reduce dimensionality if the replicate variable
...@@ -172,7 +172,7 @@ class Fragment(Layout): ...@@ -172,7 +172,7 @@ class Fragment(Layout):
""" """
return _ffi_api.Fragment_condense_rep_var(self) return _ffi_api.Fragment_condense_rep_var(self)
def map_forward_thread(self, indices: List[PrimExpr]) -> PrimExpr: def map_forward_thread(self, indices: list[PrimExpr]) -> PrimExpr:
""" """
Get the thread mapping expression for a given set of argument indices. Get the thread mapping expression for a given set of argument indices.
...@@ -206,7 +206,7 @@ class Fragment(Layout): ...@@ -206,7 +206,7 @@ class Fragment(Layout):
""" """
return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>" return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>"
def is_equal(self, other: "Fragment") -> bool: def is_equal(self, other: Fragment) -> bool:
""" """
Check if the current fragment is equal to another fragment. Check if the current fragment is equal to another fragment.
""" """
......
"""Wrapping Layouts.""" """Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation # pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations
from typing import Optional
import tvm import tvm
import tilelang.language as T import tilelang.language as T
import warnings import warnings
from tilelang.contrib import nvcc from tilelang.contrib import nvcc
from typing import List
from math import prod from math import prod
def decompose_col_major(index_1d: int, basis: List[int]) -> List[int]: def decompose_col_major(index_1d: int, basis: list[int]) -> list[int]:
res = [] res = []
for x in basis: for x in basis:
res.append(index_1d % x) res.append(index_1d % x)
...@@ -136,7 +135,7 @@ def _make_metadata_layout_sm8x_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str): ...@@ -136,7 +135,7 @@ def _make_metadata_layout_sm8x_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str):
def make_metadata_layout(buffer: tvm.tir.Buffer, def make_metadata_layout(buffer: tvm.tir.Buffer,
mma_dtype: str = "float16", mma_dtype: str = "float16",
backend: str = 'cutlass', backend: str = 'cutlass',
arch: Optional[str] = None, arch: str | None = None,
**extra_args): **extra_args):
if arch is None: if arch is None:
arch = nvcc.get_target_compute_version() arch = nvcc.get_target_compute_version()
......
"""Wrapping Layouts.""" """Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation # pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations
import tvm import tvm
from tvm.ir import Node, Range from tvm.ir import Node, Range
from tvm.tir import IterVar, Var, PrimExpr, IndexMap from tvm.tir import IterVar, Var, PrimExpr, IndexMap
from tilelang import _ffi_api from tilelang import _ffi_api
from typing import List
# Register the Layout class as a TVM object under the name "tl.Layout" # Register the Layout class as a TVM object under the name "tl.Layout"
...@@ -92,7 +92,7 @@ class Layout(Node): ...@@ -92,7 +92,7 @@ class Layout(Node):
def get_forward_index(self): def get_forward_index(self):
return self.index return self.index
def map_forward_index(self, indices: List[PrimExpr]) -> PrimExpr: def map_forward_index(self, indices: list[PrimExpr]) -> PrimExpr:
""" """
Compute the forward index mapping for a given set of input indices. Compute the forward index mapping for a given set of input indices.
...@@ -122,7 +122,7 @@ class Layout(Node): ...@@ -122,7 +122,7 @@ class Layout(Node):
# Map the provided indices using the constructed index mapping # Map the provided indices using the constructed index mapping
return index_map.map_indices(indices) return index_map.map_indices(indices)
def inverse(self) -> "Layout": def inverse(self) -> Layout:
""" """
Compute the inverse of the current layout transformation. Compute the inverse of the current layout transformation.
...@@ -133,7 +133,7 @@ class Layout(Node): ...@@ -133,7 +133,7 @@ class Layout(Node):
""" """
return _ffi_api.Layout_inverse(self) return _ffi_api.Layout_inverse(self)
def is_equal(self, other: "Layout") -> bool: def is_equal(self, other: Layout) -> bool:
""" """
Check if the current layout is equal to another layout. Check if the current layout is equal to another layout.
......
-from typing import Optional
+from __future__ import annotations
 from tvm import tir
 from tilelang.utils import is_local, is_fragment, is_shared
 from tilelang.primitives.gemm.base import GemmWarpPolicy
@@ -12,11 +13,11 @@ def gemm(
     C: tir.Buffer,
     transpose_A: bool = False,
     transpose_B: bool = False,
-    block_row_warps: Optional[int] = None,
-    block_col_warps: Optional[int] = None,
-    warp_row_tiles: Optional[int] = None,
-    warp_col_tiles: Optional[int] = None,
-    chunk: Optional[int] = None,
+    block_row_warps: int | None = None,
+    block_col_warps: int | None = None,
+    warp_row_tiles: int | None = None,
+    warp_col_tiles: int | None = None,
+    chunk: int | None = None,
     policy: GemmWarpPolicy = GemmWarpPolicy.Square,
     k_pack: int = 1,
 ):
...
+from __future__ import annotations
 from enum import IntEnum
 from dataclasses import dataclass
-from typing import Optional
 from tvm import tir
@@ -161,7 +161,7 @@ class GemmWarpPolicy(IntEnum):
         return m_warp, n_warp
     @classmethod
-    def from_warp_partition(cls, m_warp: int, n_warp: int) -> 'GemmWarpPolicy':
+    def from_warp_partition(cls, m_warp: int, n_warp: int) -> GemmWarpPolicy:
         """
         Determine the warp policy based on the given warp partitioning.
@@ -197,11 +197,11 @@ class GemmBaseParams:
     transpose_A: bool = False
     transpose_B: bool = False
-    block_row_warps: Optional[int] = None
-    block_col_warps: Optional[int] = None
-    warp_row_tiles: Optional[int] = None
-    warp_col_tiles: Optional[int] = None
-    chunk: Optional[int] = None
+    block_row_warps: int | None = None
+    block_col_warps: int | None = None
+    warp_row_tiles: int | None = None
+    warp_col_tiles: int | None = None
+    chunk: int | None = None
     policy: GemmWarpPolicy = GemmWarpPolicy.Square,
     k_pack: int = 1
@@ -226,7 +226,7 @@ class GemmBaseParams:
             "k_pack": self.k_pack,
         }
-    def infer_block_partition(self, threads: Optional[int]) -> None:
+    def infer_block_partition(self, threads: int | None) -> None:
         """
         Infer and set block partition parameters (e.g., block_row_warps,
         block_col_warps, warp_row_tiles, warp_col_tiles, chunk) based on the
...
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from __future__ import annotations
from typing import List, Optional, Callable, Any, Literal from typing import Callable, Any, Literal
from functools import partial from functools import partial
import torch import torch
from contextlib import suppress from contextlib import suppress
...@@ -28,17 +29,17 @@ class Profiler: ...@@ -28,17 +29,17 @@ class Profiler:
adapter: Optional kernel adapter for interfacing with different backends adapter: Optional kernel adapter for interfacing with different backends
""" """
params: List[KernelParam] params: list[KernelParam]
result_idx: List[int] result_idx: list[int]
supply_type: TensorSupplyType supply_type: TensorSupplyType
adapter: Optional[BaseKernelAdapter] = None adapter: BaseKernelAdapter | None = None
def __post_init__(self): def __post_init__(self):
"""Initialize tensor supply after dataclass initialization""" """Initialize tensor supply after dataclass initialization"""
self.result_idx = self._legalize_result_idx(self.result_idx) self.result_idx = self._legalize_result_idx(self.result_idx)
self.supply = get_tensor_supply(self.supply_type) self.supply = get_tensor_supply(self.supply_type)
def _legalize_result_idx(self, result_idx: Optional[List[int]] = None) -> List[int]: def _legalize_result_idx(self, result_idx: list[int] | None = None) -> list[int]:
params = self.params params = self.params
# result_idx is a list of indices of the output tensors # result_idx is a list of indices of the output tensors
if result_idx is None: if result_idx is None:
...@@ -55,7 +56,7 @@ class Profiler: ...@@ -55,7 +56,7 @@ class Profiler:
return result_idx return result_idx
def with_default_adapter(self, adapter: BaseKernelAdapter) -> "Profiler": def with_default_adapter(self, adapter: BaseKernelAdapter) -> Profiler:
self.adapter = adapter self.adapter = adapter
return self return self
...@@ -76,7 +77,7 @@ class Profiler: ...@@ -76,7 +77,7 @@ class Profiler:
def assert_allclose( def assert_allclose(
self, self,
reference_program: Callable, reference_program: Callable,
input_tensors: Optional[List[torch.Tensor]] = None, input_tensors: list[torch.Tensor] | None = None,
atol: float = 1e-2, atol: float = 1e-2,
rtol: float = 1e-2, rtol: float = 1e-2,
max_mismatched_ratio=0.01, max_mismatched_ratio=0.01,
...@@ -147,7 +148,7 @@ class Profiler: ...@@ -147,7 +148,7 @@ class Profiler:
def manual_assert_close( def manual_assert_close(
self, self,
reference_program: Callable, reference_program: Callable,
input_tensors: Optional[List[torch.Tensor]] = None, input_tensors: list[torch.Tensor] | None = None,
manual_check_prog: Callable = None, manual_check_prog: Callable = None,
): ):
"""Validates kernel output against a reference implementation. """Validates kernel output against a reference implementation.
...@@ -194,13 +195,13 @@ class Profiler: ...@@ -194,13 +195,13 @@ class Profiler:
rhs, rhs,
] ]
def run_once(self, func: Optional[Callable] = None): def run_once(self, func: Callable | None = None):
ins = self._get_inputs() ins = self._get_inputs()
if not func: if not func:
func = self.__call__ func = self.__call__
return func(*ins) return func(*ins)
def determine_profiler(self, func: Optional[Callable] = None): def determine_profiler(self, func: Callable | None = None):
"""Determines which profiler backend to use based on function type. """Determines which profiler backend to use based on function type.
Args: Args:
...@@ -217,14 +218,14 @@ class Profiler: ...@@ -217,14 +218,14 @@ class Profiler:
def do_bench( def do_bench(
self, self,
func: Optional[Callable] = None, func: Callable | None = None,
warmup: int = 25, warmup: int = 25,
rep: int = 100, rep: int = 100,
n_warmup: int = 1, n_warmup: int = 1,
n_repeat: int = 1, n_repeat: int = 1,
input_tensors: List[torch.Tensor] = None, input_tensors: list[torch.Tensor] = None,
backend: Literal["event", "cupti"] = "event", backend: Literal["event", "cupti"] = "event",
quantiles: Optional[List[float]] = None, quantiles: list[float] | None = None,
return_mode: Literal["min", "max", "mean", "median"] = "mean", return_mode: Literal["min", "max", "mean", "median"] = "mean",
) -> float: ) -> float:
"""Benchmarks the execution time of a given function. """Benchmarks the execution time of a given function.
......
"""Profiler and benchmarking utilities for PyTorch functions.""" """Profiler and benchmarking utilities for PyTorch functions."""
from __future__ import annotations
import os import os
import sys import sys
from typing import Callable, List, Literal, Optional, Union from typing import Callable, Literal
import torch import torch
...@@ -65,11 +66,11 @@ def do_bench( ...@@ -65,11 +66,11 @@ def do_bench(
rep: float = 100, rep: float = 100,
_n_warmup: int = 0, _n_warmup: int = 0,
_n_repeat: int = 0, _n_repeat: int = 0,
quantiles: Optional[List[float]] = None, quantiles: list[float] | None = None,
fast_flush: bool = True, fast_flush: bool = True,
backend: Literal["event", "cupti"] = "event", backend: Literal["event", "cupti"] = "event",
return_mode: Literal["min", "max", "mean", "median"] = "mean", return_mode: Literal["min", "max", "mean", "median"] = "mean",
) -> Union[float, List[float]]: ) -> float | list[float]:
"""Benchmark the runtime of a PyTorch function with L2 cache management. """Benchmark the runtime of a PyTorch function with L2 cache management.
This function provides accurate GPU kernel timing by: This function provides accurate GPU kernel timing by:
...@@ -138,9 +139,9 @@ def _bench_with_cuda_events( ...@@ -138,9 +139,9 @@ def _bench_with_cuda_events(
fn: Callable, fn: Callable,
cache: torch.Tensor, cache: torch.Tensor,
n_repeat: int, n_repeat: int,
quantiles: Optional[List[float]], quantiles: list[float] | None,
return_mode: str, return_mode: str,
) -> Union[float, List[float]]: ) -> float | list[float]:
"""Benchmark using CUDA events for timing.""" """Benchmark using CUDA events for timing."""
# Create timing events # Create timing events
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)] start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
......
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from typing import Dict, Literal
+from __future__ import annotations
+from typing import Literal
 decode_i4_to_f16 = """
 template <typename T1, typename T2, bool isSigned = false>
@@ -1096,7 +1097,7 @@ def get_lop3_intrin_group(
     with_zeros: bool = False,
     zeros_mode: Literal["original", "rescale", "quantized"] = "original",
     storage_scope: str = "local",
-) -> Dict[str, str]:
+) -> dict[str, str]:
     """
     This function is used to get the intrinsic group of the LOP3 operation to avoid the overhead of fast decoding.
     LOP3 is a type of logic operation that takes three inputs. The intrinsic group refers to the set of
@@ -1186,9 +1187,9 @@ def get_lop3_intrin_group(
     elif out_dtype == "int4":
         d4f = "i4s"
     else:
-        raise ValueError("Unsupported target dtype: {}".format(target_dtype))
+        raise ValueError(f"Unsupported target dtype: {target_dtype}")
     source_symbol = "u" if source_format == "uint" else "s"
-    func_name = "decode_i{}{}_to_{}".format(source_bit, source_symbol, d4f)
+    func_name = f"decode_i{source_bit}{source_symbol}_to_{d4f}"
     if with_scaling:
         func_name += "_scale"
     if with_zeros:
...
-from typing import Literal, Dict
+from __future__ import annotations
+from typing import Literal
 # Implementation asm for fp4 to bf16, using twiddling
 # Reference: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py#L11-L18
@@ -54,7 +55,7 @@ def get_mxfp_intrin_group(
     source_bit: int = 4,
     storage_dtype: Literal["int32", "int8", "uint8"] = "uint8",
     use_twiddling: bool = False,
-) -> Dict[str, str]:
+) -> dict[str, str]:
     """
     Return metadata for an MXFP decoding intrinsic: function name and C source string.
...
@@ -223,7 +223,7 @@ def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str):
     e4 = val & tir.const(0x40, "uint16")
     prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"),
                         tir.const(0x4000, "uint16"))
-    e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | prefix
+    e_f16 = ((val & tir.const(63, "uint16")) << tir.const(7, "uint16")) | prefix
     return tir.reinterpret("float16", s_f16 | e_f16)
@@ -232,7 +232,7 @@ def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
     assert dtype == "float16"
     s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
     e4 = val & tir.const(0x40, "uint16")
-    e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16"))
+    e_f16 = ((val & tir.const(63, "uint16")) << tir.const(7, "uint16")) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16"))
     e_f16 = e_f16 ^ tir.const(0x2000, "uint16")
     return tir.reinterpret("float16", s_f16 | e_f16)
...
@@ -9,7 +9,7 @@ from tvm.ir import PrimExpr
 @dataclass
-class GemmBase(object):
+class GemmBase:
     gemm_node: Node
     def infer_layout(self, target: Target, thread_nums: int):
...
+from __future__ import annotations
 import numpy as np
 from dataclasses import dataclass
 from tilelang import tvm
 from tvm.tir.stmt_functor import ir_transform
 import logging
-from typing import Optional
 # Configuration for different hardware architectures.
 # Each entry contains: (cores per SM, default clock (GHz), FLOPs per cycle, max SM count)
 ARCH_CONFIGS = {"80": (128, 1.41, 2, 108), "86": (128, 1.70, 2, 84), "89": (128, 2.52, 2, 128)}
@@ -168,7 +168,7 @@ class Analyzer:
            AnalysisResult: The calculated performance metrics.
        """
-        def get_peak_tflops(device) -> Optional[float]:
+        def get_peak_tflops(device) -> float | None:
            """
            Get the peak TFLOPS for the target device.
            Args:
...
+from __future__ import annotations
 from tvm.tir import (BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm)
 from tvm.tir.stmt_functor import ir_transform, post_order_visit
 from tvm.tir.transform import prim_func_pass
-from typing import Tuple, List, Dict
 def AddWrapperForSingleBufStore():
@@ -42,7 +42,7 @@ def AddWrapperForSingleBufStore():
         post_order_visit(operation, visit_variable)
         return used_variables
-    def collect_buffer_accesses(statement) -> Tuple[List[Buffer], List[Buffer]]:
+    def collect_buffer_accesses(statement) -> tuple[list[Buffer], list[Buffer]]:
         """
         Categorizes buffers accessed in the statement by their scope.
@@ -69,7 +69,7 @@ def AddWrapperForSingleBufStore():
             local_buffers.append(buffer)
         return local_buffers, fragment_buffers
-    def collect_buffer_indices(statement) -> Dict[Buffer, List[int]]:
+    def collect_buffer_indices(statement) -> dict[Buffer, list[int]]:
         """
         Maps each buffer to its access indices.
...
+from __future__ import annotations
 from tilelang import tvm as tvm
 from tvm import IRModule
 from tvm.tir import PrimFunc
-from typing import Union, Callable
+from typing import Callable
 from . import _ffi_api
@@ -27,8 +28,7 @@ def Simplify(simplify_arguments: bool = False):
     return _ffi_api.Simplify(simplify_arguments)  # type: ignore
-def _Simplify(stmt: Union[PrimFunc, IRModule],
-              inline_let: bool = False) -> Union[PrimFunc, IRModule]:
+def _Simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc | IRModule:
     if isinstance(stmt, PrimFunc):
         if inline_let:
             mod = LetInline()(IRModule.from_expr(stmt))
@@ -53,13 +53,12 @@ def _Simplify(stmt: Union[PrimFunc, IRModule],
 def simplify_prim_func(func: Callable) -> Callable:
     def wrapper(*args, **kwargs):
-        stmt: Union[PrimFunc, IRModule] = (func)(*args, **kwargs)
+        stmt: PrimFunc | IRModule = (func)(*args, **kwargs)
         return _Simplify(stmt)
     return wrapper
-def apply_simplify(stmt: Union[PrimFunc, IRModule],
-                   inline_let: bool = False) -> Union[PrimFunc, IRModule]:
+def apply_simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc | IRModule:
     """Apply Simplify pass to a PrimFunc or IRModule."""
     return _Simplify(stmt, inline_let)
+from __future__ import annotations
 from tvm.tir import Buffer
-from typing import List, Optional
 from functools import reduce
 from tvm import IRModule
 from tvm.tir import PrimFunc
@@ -85,7 +85,7 @@ def get_buffer_elems(buffer: Buffer) -> int:
     return reduce(lambda x, y: x * y, buffer.shape)
-def array_reduce(array: List[int]) -> int:
+def array_reduce(array: list[int]) -> int:
     """
     Reduce an array of integers to a single integer.
@@ -121,7 +121,7 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
     return func
-def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> Optional[tir.BufferRegion]:
+def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> tir.BufferRegion | None:
     """
     Get the buffer region from a buffer load.
...
+from __future__ import annotations
 import os
 import torch
 import warnings
-from typing import Optional, Tuple
 from tilelang.contrib import nvcc
 from torch.utils.cpp_extension import load, _import_module_from_library
 from tilelang import env
@@ -44,7 +44,7 @@ def _get_cached_lib():
 def compress_sm90(A: torch.Tensor, block_k: int,
-                  transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
+                  transposed: bool) -> tuple[torch.Tensor, torch.Tensor]:
     if block_k > 128:
         block_k = 128
     # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
@@ -56,7 +56,7 @@ def compress_sm90(A: torch.Tensor, block_k: int,
     return compress_lib.compress_sm90(A, block_k, transposed)
-def compress_sm80(A: torch.Tensor, transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
+def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torch.Tensor]:
     try:
         from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
     except ImportError as err:
@@ -75,8 +75,8 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> Tuple[torch.Tensor, torc
 def compress(A: torch.Tensor,
              transposed: bool,
-             arch: Optional[str] = None,
-             **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+             arch: str | None = None,
+             **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Compress a tensor using the appropriate method based on the CUDA architecture.
     """
...
+from __future__ import annotations
 from platform import mac_ver
-from typing import Dict, Literal, Union
+from typing import Literal
 from tilelang import tvm as tvm
 from tilelang import _ffi_api
 from tvm.target import Target
 from tvm.contrib import rocm
 from tilelang.contrib import nvcc
-SUPPORTED_TARGETS: Dict[str, str] = {
+SUPPORTED_TARGETS: dict[str, str] = {
     "auto": "Auto-detect CUDA/HIP/Metal based on availability.",
     "cuda": "CUDA GPU target (supports options such as `cuda -arch=sm_80`).",
     "hip": "ROCm HIP target (supports options like `hip -mcpu=gfx90a`).",
@@ -17,7 +18,7 @@ SUPPORTED_TARGETS: Dict[str, str] = {
 }
-def describe_supported_targets() -> Dict[str, str]:
+def describe_supported_targets() -> dict[str, str]:
     """
     Return a mapping of supported target names to usage descriptions.
     """
@@ -58,8 +59,8 @@ def check_metal_availability() -> bool:
     return arch == 'arm64'
-def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
-                     return_object: bool = False) -> Union[str, Target]:
+def determine_target(target: str | Target | Literal["auto"] = "auto",
+                     return_object: bool = False) -> str | Target:
     """
     Determine the appropriate target for compilation (CUDA, HIP, or manual selection).
@@ -76,7 +77,7 @@ def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
         AssertionError: If the target is invalid.
     """
-    return_var: Union[str, Target] = target
+    return_var: str | Target = target
     if target == "auto":
         target = tvm.target.Target.current(allow_none=True)
...
@@ -3,7 +3,6 @@ from __future__ import annotations
 import os
 import platform
 import subprocess
-from typing import Optional
 from pathlib import Path
 ROOT = Path(__file__).parent
@@ -17,13 +16,12 @@ def _read_cmake_bool(i: str | None, default=False):
     return i.lower() not in ('0', 'false', 'off', 'no', 'n', '')
-def get_git_commit_id() -> Optional[str]:
+def get_git_commit_id() -> str | None:
     """Get the current git commit hash by running git in the current file's directory."""
     r = subprocess.run(['git', 'rev-parse', 'HEAD'],
                        cwd=ROOT,
-                       stdout=subprocess.PIPE,
-                       stderr=subprocess.PIPE,
+                       capture_output=True,
                        encoding='utf-8')
     if r.returncode == 0:
         return r.stdout.strip()
...