Unverified commit f14fb111 authored by Yichen Yan, committed by GitHub

[Lint] Enable pyupgrade linter in ruff (#963)

* update rules

* ruff check

* other fixes

* fmt

* do not touch examples

* fmt
parent 4f3523dc
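
Context for reviewers: pyupgrade corresponds to ruff's UP rule family, typically enabled with `extend-select = ["UP"]` under `[tool.ruff.lint]` in pyproject.toml. The hunks below are mostly mechanical rewrites of three kinds; a minimal runnable sketch (rule codes approximate, per ruff's docs):

from __future__ import annotations   # lets annotations reference names lazily
from typing import List, Optional    # now unnecessary; kept only to show the diff

# UP006: typing.List[int]   -> list[int]
# UP007: Optional[int]      -> int | None
# UP037: -> "Fragment"      -> -> Fragment (quotes dropped)
def pick(xs: list[int], d: int | None = None) -> int | None:
    return xs[d] if d is not None else None

class Fragment:
    def copy(self) -> Fragment:   # self-reference works thanks to the future import
        return self
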
"""The language interface for tl programs."""
from __future__ import annotations
from tvm.script.ir_builder.tir.frame import TIRFrame
from tvm.ffi import register_object
from tilelang import _ffi_api
from .kernel import get_thread_bindings, get_thread_extents
from typing import List
@register_object("tl.WarpSpecializeFrame")
......@@ -45,7 +45,7 @@ def WarpSpecialize(*warp_group_idx):
# only available for nvidia gpus.
warp_group_size = 128
warp_group_ids: List[int] = []
warp_group_ids: list[int] = []
for warp_group_id in warp_group_idx:
warp_group_ids.append(warp_group_id)
......
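
For intuition, the warp-group bookkeeping above is plain integer arithmetic; a worked sketch (numbers only, not the TileLang launch API):

warp_group_size = 128                       # fixed for NVIDIA GPUs, per the comment above
threads = 256                               # example block size
num_groups = threads // warp_group_size     # -> 2 warp groups
# Thread t falls in warp group t // warp_group_size:
assert 0 // warp_group_size == 0            # threads 0..127   -> group 0
assert 255 // warp_group_size == 1          # threads 128..255 -> group 1
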
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations
import tvm
from tvm.ir import Range
from tvm.tir import IterVar, Var, PrimExpr, IndexMap
from tilelang import _ffi_api
from tilelang.layout import Layout
from typing import List
@tvm.ffi.register_object("tl.Fragment")
......@@ -123,7 +123,7 @@ class Fragment(Layout):
def repeat(self,
repeats,
repeat_on_thread: bool = False,
lower_dim_first: bool = True) -> "Fragment":
lower_dim_first: bool = True) -> Fragment:
"""
Returns a new Fragment that repeats the iteration space a given number of times.
......@@ -143,7 +143,7 @@ class Fragment(Layout):
"""
return _ffi_api.Fragment_repeat(self, repeats, repeat_on_thread, lower_dim_first)
def replicate(self, replicate: int) -> "Fragment":
def replicate(self, replicate: int) -> Fragment:
"""
Replicate the Fragment across a new thread dimension.
......@@ -159,7 +159,7 @@ class Fragment(Layout):
"""
return _ffi_api.Fragment_replicate(self, replicate)
def condense_rep_var(self) -> "Fragment":
def condense_rep_var(self) -> Fragment:
"""
Condense or fold the replicate variable into the existing iteration space.
This operation may be used to reduce dimensionality if the replicate variable
......@@ -172,7 +172,7 @@ class Fragment(Layout):
"""
return _ffi_api.Fragment_condense_rep_var(self)
def map_forward_thread(self, indices: List[PrimExpr]) -> PrimExpr:
def map_forward_thread(self, indices: list[PrimExpr]) -> PrimExpr:
"""
Get the thread mapping expression for a given set of argument indices.
......@@ -206,7 +206,7 @@ class Fragment(Layout):
"""
return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>"
def is_equal(self, other: "Fragment") -> bool:
def is_equal(self, other: Fragment) -> bool:
"""
Check if the current fragment is equal to another fragment.
"""
......
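
The quoted-to-bare annotation changes in this file (e.g. `-> "Fragment"` becoming `-> Fragment`) are safe because the file already has `from __future__ import annotations` (PEP 563): annotations are stored as strings and never evaluated at definition time. A minimal demonstration:

from __future__ import annotations

class Fragment:
    # Without the future import, the bare `Fragment` below would raise
    # NameError when the method is defined, since the class name is not
    # bound until the class body finishes executing.
    def replicate(self, replicate: int) -> Fragment:
        return self

f = Fragment().replicate(2)    # runs fine; the annotation stays a string
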
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations
from typing import Optional
import tvm
import tilelang.language as T
import warnings
from tilelang.contrib import nvcc
from typing import List
from math import prod
def decompose_col_major(index_1d: int, basis: List[int]) -> List[int]:
def decompose_col_major(index_1d: int, basis: list[int]) -> list[int]:
res = []
for x in basis:
res.append(index_1d % x)
......@@ -136,7 +135,7 @@ def _make_metadata_layout_sm8x_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str):
def make_metadata_layout(buffer: tvm.tir.Buffer,
mma_dtype: str = "float16",
backend: str = 'cutlass',
arch: Optional[str] = None,
arch: str | None = None,
**extra_args):
if arch is None:
arch = nvcc.get_target_compute_version()
......
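
Aside on decompose_col_major, whose body is truncated by the hunk above: it peels mixed-radix digits least-significant-first. A self-contained sketch, assuming the elided loop body also floor-divides index_1d by each basis element:

def decompose_col_major(index_1d: int, basis: list[int]) -> list[int]:
    res = []
    for x in basis:
        res.append(index_1d % x)   # current digit
        index_1d //= x             # assumed from the elided lines
    return res

# Column-major in a (2, 4) grid: 5 == 1 + 2 * 2 -> coordinates [1, 2].
assert decompose_col_major(5, [2, 4]) == [1, 2]
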
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations
import tvm
from tvm.ir import Node, Range
from tvm.tir import IterVar, Var, PrimExpr, IndexMap
from tilelang import _ffi_api
from typing import List
# Register the Layout class as a TVM object under the name "tl.Layout"
......@@ -92,7 +92,7 @@ class Layout(Node):
def get_forward_index(self):
return self.index
def map_forward_index(self, indices: List[PrimExpr]) -> PrimExpr:
def map_forward_index(self, indices: list[PrimExpr]) -> PrimExpr:
"""
Compute the forward index mapping for a given set of input indices.
......@@ -122,7 +122,7 @@ class Layout(Node):
# Map the provided indices using the constructed index mapping
return index_map.map_indices(indices)
def inverse(self) -> "Layout":
def inverse(self) -> Layout:
"""
Compute the inverse of the current layout transformation.
......@@ -133,7 +133,7 @@ class Layout(Node):
"""
return _ffi_api.Layout_inverse(self)
def is_equal(self, other: "Layout") -> bool:
def is_equal(self, other: Layout) -> bool:
"""
Check if the current layout is equal to another layout.
......
from typing import Optional
from __future__ import annotations
from tvm import tir
from tilelang.utils import is_local, is_fragment, is_shared
from tilelang.primitives.gemm.base import GemmWarpPolicy
......@@ -12,11 +13,11 @@ def gemm(
C: tir.Buffer,
transpose_A: bool = False,
transpose_B: bool = False,
block_row_warps: Optional[int] = None,
block_col_warps: Optional[int] = None,
warp_row_tiles: Optional[int] = None,
warp_col_tiles: Optional[int] = None,
chunk: Optional[int] = None,
block_row_warps: int | None = None,
block_col_warps: int | None = None,
warp_row_tiles: int | None = None,
warp_col_tiles: int | None = None,
chunk: int | None = None,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
k_pack: int = 1,
):
......
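
All of the `int | None` defaults above follow the "None means infer" convention spelled out by infer_block_partition later in this diff; the pattern in isolation (hypothetical helper with a made-up heuristic, not tilelang code):

def resolve(block_row_warps: int | None, m: int) -> int:
    # Respect an explicit override; otherwise derive a value from the
    # problem size (the heuristic here is invented for illustration).
    return block_row_warps if block_row_warps is not None else max(1, m // 64)

assert resolve(None, 256) == 4
assert resolve(2, 256) == 2
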
from __future__ import annotations
from enum import IntEnum
from dataclasses import dataclass
from typing import Optional
from tvm import tir
......@@ -161,7 +161,7 @@ class GemmWarpPolicy(IntEnum):
return m_warp, n_warp
@classmethod
def from_warp_partition(cls, m_warp: int, n_warp: int) -> 'GemmWarpPolicy':
def from_warp_partition(cls, m_warp: int, n_warp: int) -> GemmWarpPolicy:
"""
Determine the warp policy based on the given warp partitioning.
......@@ -197,11 +197,11 @@ class GemmBaseParams:
transpose_A: bool = False
transpose_B: bool = False
block_row_warps: Optional[int] = None
block_col_warps: Optional[int] = None
warp_row_tiles: Optional[int] = None
warp_col_tiles: Optional[int] = None
chunk: Optional[int] = None
block_row_warps: int | None = None
block_col_warps: int | None = None
warp_row_tiles: int | None = None
warp_col_tiles: int | None = None
chunk: int | None = None
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
k_pack: int = 1
......@@ -226,7 +226,7 @@ class GemmBaseParams:
"k_pack": self.k_pack,
}
def infer_block_partition(self, threads: Optional[int]) -> None:
def infer_block_partition(self, threads: int | None) -> None:
"""
Infer and set block partition parameters (e.g., block_row_warps,
block_col_warps, warp_row_tiles, warp_col_tiles, chunk) based on the
......
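
For intuition on what these warp parameters describe (standard tensor-core GEMM tiling arithmetic, not code from this repo): a block of block_row_warps x block_col_warps warps covers an output tile of (block_row_warps * warp_row_tiles) x (block_col_warps * warp_col_tiles).

block_row_warps, block_col_warps = 2, 2     # 4 warps -> 128 threads on NVIDIA
warp_row_tiles, warp_col_tiles = 64, 64     # each warp owns a 64x64 sub-tile
block_M = block_row_warps * warp_row_tiles  # 128
block_N = block_col_warps * warp_col_tiles  # 128
threads = block_row_warps * block_col_warps * 32
assert (block_M, block_N, threads) == (128, 128, 128)
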
"""The profiler and convert to torch utils"""
from __future__ import annotations
from typing import List, Optional, Callable, Any, Literal
from typing import Callable, Any, Literal
from functools import partial
import torch
from contextlib import suppress
......@@ -28,17 +29,17 @@ class Profiler:
adapter: Optional kernel adapter for interfacing with different backends
"""
params: List[KernelParam]
result_idx: List[int]
params: list[KernelParam]
result_idx: list[int]
supply_type: TensorSupplyType
adapter: Optional[BaseKernelAdapter] = None
adapter: BaseKernelAdapter | None = None
def __post_init__(self):
"""Initialize tensor supply after dataclass initialization"""
self.result_idx = self._legalize_result_idx(self.result_idx)
self.supply = get_tensor_supply(self.supply_type)
def _legalize_result_idx(self, result_idx: Optional[List[int]] = None) -> List[int]:
def _legalize_result_idx(self, result_idx: list[int] | None = None) -> list[int]:
params = self.params
# result_idx is a list of indices of the output tensors
if result_idx is None:
......@@ -55,7 +56,7 @@ class Profiler:
return result_idx
def with_default_adapter(self, adapter: BaseKernelAdapter) -> "Profiler":
def with_default_adapter(self, adapter: BaseKernelAdapter) -> Profiler:
self.adapter = adapter
return self
......@@ -76,7 +77,7 @@ class Profiler:
def assert_allclose(
self,
reference_program: Callable,
input_tensors: Optional[List[torch.Tensor]] = None,
input_tensors: list[torch.Tensor] | None = None,
atol: float = 1e-2,
rtol: float = 1e-2,
max_mismatched_ratio=0.01,
......@@ -147,7 +148,7 @@ class Profiler:
def manual_assert_close(
self,
reference_program: Callable,
input_tensors: Optional[List[torch.Tensor]] = None,
input_tensors: list[torch.Tensor] | None = None,
manual_check_prog: Callable = None,
):
"""Validates kernel output against a reference implementation.
......@@ -194,13 +195,13 @@ class Profiler:
rhs,
]
def run_once(self, func: Optional[Callable] = None):
def run_once(self, func: Callable | None = None):
ins = self._get_inputs()
if not func:
func = self.__call__
return func(*ins)
def determine_profiler(self, func: Optional[Callable] = None):
def determine_profiler(self, func: Callable | None = None):
"""Determines which profiler backend to use based on function type.
Args:
......@@ -217,14 +218,14 @@ class Profiler:
def do_bench(
self,
func: Optional[Callable] = None,
func: Callable | None = None,
warmup: int = 25,
rep: int = 100,
n_warmup: int = 1,
n_repeat: int = 1,
input_tensors: List[torch.Tensor] = None,
input_tensors: list[torch.Tensor] = None,
backend: Literal["event", "cupti"] = "event",
quantiles: Optional[List[float]] = None,
quantiles: list[float] | None = None,
return_mode: Literal["min", "max", "mean", "median"] = "mean",
) -> float:
"""Benchmarks the execution time of a given function.
......
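
A hypothetical end-to-end use of the Profiler API shown above (the method names come from this diff; how the Profiler instance is obtained may differ in practice):

profiler = compiled_kernel.get_profiler()     # assumed accessor
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
latency_ms = profiler.do_bench(warmup=25, rep=100, return_mode="mean")
print(f"{latency_ms:.3f} ms")
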
"""Profiler and benchmarking utilities for PyTorch functions."""
from __future__ import annotations
import os
import sys
from typing import Callable, List, Literal, Optional, Union
from typing import Callable, Literal
import torch
......@@ -65,11 +66,11 @@ def do_bench(
rep: float = 100,
_n_warmup: int = 0,
_n_repeat: int = 0,
quantiles: Optional[List[float]] = None,
quantiles: list[float] | None = None,
fast_flush: bool = True,
backend: Literal["event", "cupti"] = "event",
return_mode: Literal["min", "max", "mean", "median"] = "mean",
) -> Union[float, List[float]]:
) -> float | list[float]:
"""Benchmark the runtime of a PyTorch function with L2 cache management.
This function provides accurate GPU kernel timing by:
......@@ -138,9 +139,9 @@ def _bench_with_cuda_events(
fn: Callable,
cache: torch.Tensor,
n_repeat: int,
quantiles: Optional[List[float]],
quantiles: list[float] | None,
return_mode: str,
) -> Union[float, List[float]]:
) -> float | list[float]:
"""Benchmark using CUDA events for timing."""
# Create timing events
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
......
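
The CUDA-event pattern that _bench_with_cuda_events builds on, reduced to a single measurement (a sketch only: no warmup and no L2 cache flush, unlike do_bench):

import torch

def time_once_ms(fn) -> float:
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fn()
    end.record()
    torch.cuda.synchronize()          # wait for `end` to be reached on-device
    return start.elapsed_time(end)    # elapsed GPU time in milliseconds
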
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Dict, Literal
from __future__ import annotations
from typing import Literal
decode_i4_to_f16 = """
template <typename T1, typename T2, bool isSigned = false>
......@@ -1096,7 +1097,7 @@ def get_lop3_intrin_group(
with_zeros: bool = False,
zeros_mode: Literal["original", "rescale", "quantized"] = "original",
storage_scope: str = "local",
) -> Dict[str, str]:
) -> dict[str, str]:
"""
This function is used to get the intrinsic group of the LOP3 operation, which uses fast decoding to avoid conversion overhead.
LOP3 is a type of logic operation that takes three inputs. The intrinsic group refers to the set of
......@@ -1186,9 +1187,9 @@ def get_lop3_intrin_group(
elif out_dtype == "int4":
d4f = "i4s"
else:
raise ValueError("Unsupported target dtype: {}".format(target_dtype))
raise ValueError(f"Unsupported target dtype: {target_dtype}")
source_symbol = "u" if source_format == "uint" else "s"
func_name = "decode_i{}{}_to_{}".format(source_bit, source_symbol, d4f)
func_name = f"decode_i{source_bit}{source_symbol}_to_{d4f}"
if with_scaling:
func_name += "_scale"
if with_zeros:
......
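
A worked example of the function-name construction in this hunk (the d4f value for out_dtype="float16" is assumed; only the "int4" branch is visible above):

source_bit, source_format, d4f = 4, "uint", "f16"
source_symbol = "u" if source_format == "uint" else "s"
func_name = f"decode_i{source_bit}{source_symbol}_to_{d4f}"
assert func_name == "decode_i4u_to_f16"
# with_scaling appends "_scale"; with_zeros appends a further suffix (elided above).
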
from typing import Literal, Dict
from __future__ import annotations
from typing import Literal
# Implementation asm for fp4 to bf16, using twiddling
# Reference: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py#L11-L18
......@@ -54,7 +55,7 @@ def get_mxfp_intrin_group(
source_bit: int = 4,
storage_dtype: Literal["int32", "int8", "uint8"] = "uint8",
use_twiddling: bool = False,
) -> Dict[str, str]:
) -> dict[str, str]:
"""
Return metadata for an MXFP decoding intrinsic: function name and C source string.
......
......@@ -223,7 +223,7 @@ def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str):
e4 = val & tir.const(0x40, "uint16")
prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"),
tir.const(0x4000, "uint16"))
e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | prefix
e_f16 = ((val & tir.const(63, "uint16")) << tir.const(7, "uint16")) | prefix
return tir.reinterpret("float16", s_f16 | e_f16)
......@@ -232,7 +232,7 @@ def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert dtype == "float16"
s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
e4 = val & tir.const(0x40, "uint16")
e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16"))
e_f16 = ((val & tir.const(63, "uint16")) << tir.const(7, "uint16")) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16"))
e_f16 = e_f16 ^ tir.const(0x2000, "uint16")
return tir.reinterpret("float16", s_f16 | e_f16)
......
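
Why the naive e4m3-to-f16 trick above works: e4m3 uses exponent bias 7 and fp16 bias 15, so conversion adds 8 to the exponent; with the fp16 exponent field starting at bit 10, that is +0x2000 (8 << 10) when the e4m3 exponent MSB is clear, and 0x4000 (16 << 10) when it is set. A numpy restatement for one byte (zero and subnormal inputs are not special-cased, matching the naive TIR variant):

import numpy as np

def e4m3_byte_to_f16(val: int) -> np.float16:
    s_f16 = (val >> 7) << 15                          # sign to bit 15
    prefix = 0x2000 if (val & 0x40) == 0 else 0x4000  # re-biased exponent MSBs
    e_f16 = ((val & 63) << 7) | prefix                # exp low bits + mantissa
    return np.array([s_f16 | e_f16], np.uint16).view(np.float16)[0]

assert e4m3_byte_to_f16(0x38) == np.float16(1.0)      # 0x38: exp 7, mantissa 0
assert e4m3_byte_to_f16(0xB8) == np.float16(-1.0)
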
......@@ -9,7 +9,7 @@ from tvm.ir import PrimExpr
@dataclass
class GemmBase(object):
class GemmBase:
gemm_node: Node
def infer_layout(self, target: Target, thread_nums: int):
......
from __future__ import annotations
import numpy as np
from dataclasses import dataclass
from tilelang import tvm
from tvm.tir.stmt_functor import ir_transform
import logging
from typing import Optional
# Configuration for different hardware architectures.
# Each entry contains: (cores per SM, default clock (GHz), FLOPs per cycle, max SM count)
ARCH_CONFIGS = {"80": (128, 1.41, 2, 108), "86": (128, 1.70, 2, 84), "89": (128, 2.52, 2, 128)}
......@@ -168,7 +168,7 @@ class Analyzer:
AnalysisResult: The calculated performance metrics.
"""
def get_peak_tflops(device) -> Optional[float]:
def get_peak_tflops(device) -> float | None:
"""
Get the peak TFLOPS for the target device.
Args:
......
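
Given the tuple layout documented above, a plausible way ARCH_CONFIGS feeds a peak-TFLOPS estimate (a sketch of the arithmetic, not the verified implementation of get_peak_tflops):

ARCH_CONFIGS = {"80": (128, 1.41, 2, 108), "86": (128, 1.70, 2, 84)}

def peak_tflops(arch: str) -> float:
    cores_per_sm, clock_ghz, flops_per_cycle, max_sm = ARCH_CONFIGS[arch]
    # cores * SMs * GHz gives G-ops/s; * FLOPs/cycle -> GFLOPS; / 1e3 -> TFLOPS.
    return cores_per_sm * max_sm * clock_ghz * flops_per_cycle / 1e3

assert round(peak_tflops("80"), 1) == 39.0   # 128 * 108 * 1.41 * 2 / 1000
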
from __future__ import annotations
from tvm.tir import (BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm)
from tvm.tir.stmt_functor import ir_transform, post_order_visit
from tvm.tir.transform import prim_func_pass
from typing import Tuple, List, Dict
def AddWrapperForSingleBufStore():
......@@ -42,7 +42,7 @@ def AddWrapperForSingleBufStore():
post_order_visit(operation, visit_variable)
return used_variables
def collect_buffer_accesses(statement) -> Tuple[List[Buffer], List[Buffer]]:
def collect_buffer_accesses(statement) -> tuple[list[Buffer], list[Buffer]]:
"""
Categorizes buffers accessed in the statement by their scope.
......@@ -69,7 +69,7 @@ def AddWrapperForSingleBufStore():
local_buffers.append(buffer)
return local_buffers, fragment_buffers
def collect_buffer_indices(statement) -> Dict[Buffer, List[int]]:
def collect_buffer_indices(statement) -> dict[Buffer, list[int]]:
"""
Maps each buffer to its access indices.
......
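
The split that collect_buffer_accesses performs presumably keys off each buffer's storage scope (is_fragment-style checks appear elsewhere in tilelang); a generic sketch of that categorization, hypothetical rather than the pass's actual code:

def categorize_by_scope(buffers):
    local_buffers, fragment_buffers = [], []
    for buf in buffers:
        # A tvm Buffer exposes its storage scope as a string, e.g. "local"
        # or "local.fragment" (scope name assumed here).
        if buf.scope() == "local.fragment":
            fragment_buffers.append(buf)
        else:
            local_buffers.append(buf)
    return local_buffers, fragment_buffers
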
from __future__ import annotations
from tilelang import tvm as tvm
from tvm import IRModule
from tvm.tir import PrimFunc
from typing import Union, Callable
from typing import Callable
from . import _ffi_api
......@@ -27,8 +28,7 @@ def Simplify(simplify_arguments: bool = False):
return _ffi_api.Simplify(simplify_arguments) # type: ignore
def _Simplify(stmt: Union[PrimFunc, IRModule],
inline_let: bool = False) -> Union[PrimFunc, IRModule]:
def _Simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc | IRModule:
if isinstance(stmt, PrimFunc):
if inline_let:
mod = LetInline()(IRModule.from_expr(stmt))
......@@ -53,13 +53,12 @@ def _Simplify(stmt: Union[PrimFunc, IRModule],
def simplify_prim_func(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
stmt: Union[PrimFunc, IRModule] = (func)(*args, **kwargs)
stmt: PrimFunc | IRModule = (func)(*args, **kwargs)
return _Simplify(stmt)
return wrapper
def apply_simplify(stmt: Union[PrimFunc, IRModule],
inline_let: bool = False) -> Union[PrimFunc, IRModule]:
def apply_simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc | IRModule:
"""Apply Simplify pass to a PrimFunc or IRModule."""
return _Simplify(stmt, inline_let)
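
A hypothetical use of the decorator above: simplify_prim_func wraps any builder that returns a PrimFunc or IRModule and pipes the result through _Simplify.

@simplify_prim_func
def build_kernel() -> PrimFunc:
    ...   # construct and return a PrimFunc (e.g. via tvm.script)

simplified = build_kernel()   # same result as apply_simplify(<raw PrimFunc>)
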
from __future__ import annotations
from tvm.tir import Buffer
from typing import List, Optional
from functools import reduce
from tvm import IRModule
from tvm.tir import PrimFunc
......@@ -85,7 +85,7 @@ def get_buffer_elems(buffer: Buffer) -> int:
return reduce(lambda x, y: x * y, buffer.shape)
def array_reduce(array: List[int]) -> int:
def array_reduce(array: list[int]) -> int:
"""
Reduce an array of integers to a single integer.
......@@ -121,7 +121,7 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
return func
def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> Optional[tir.BufferRegion]:
def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> tir.BufferRegion | None:
"""
Get the buffer region from a buffer load.
......
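
array_reduce above is a multiplicative fold; its one-line behavior, worked out:

from functools import reduce
# ((2 * 3) * 4) == 24 -- the same product get_buffer_elems takes over
# buffer.shape to count elements.
assert reduce(lambda x, y: x * y, [2, 3, 4]) == 24
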
from __future__ import annotations
import os
import torch
import warnings
from typing import Optional, Tuple
from tilelang.contrib import nvcc
from torch.utils.cpp_extension import load, _import_module_from_library
from tilelang import env
......@@ -44,7 +44,7 @@ def _get_cached_lib():
def compress_sm90(A: torch.Tensor, block_k: int,
transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
transposed: bool) -> tuple[torch.Tensor, torch.Tensor]:
if block_k > 128:
block_k = 128
# Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
......@@ -56,7 +56,7 @@ def compress_sm90(A: torch.Tensor, block_k: int,
return compress_lib.compress_sm90(A, block_k, transposed)
def compress_sm80(A: torch.Tensor, transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torch.Tensor]:
try:
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
except ImportError as err:
......@@ -75,8 +75,8 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> Tuple[torch.Tensor, torc
def compress(A: torch.Tensor,
transposed: bool,
arch: Optional[str] = None,
**kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
arch: str | None = None,
**kwargs) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compress a tensor using the appropriate method based on the CUDA architecture.
"""
......
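
A hypothetical call to the dispatcher above; compress picks compress_sm90 or compress_sm80 from the detected (or explicitly supplied) compute capability and returns the compressed tensor plus its sparsity metadata:

A = torch.randn(256, 256, dtype=torch.float16, device="cuda")
# block_k is forwarded to the sm90 path via **kwargs (assumption).
A_compressed, metadata = compress(A, transposed=False, block_k=128)
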
from __future__ import annotations
from platform import mac_ver
from typing import Dict, Literal, Union
from typing import Literal
from tilelang import tvm as tvm
from tilelang import _ffi_api
from tvm.target import Target
from tvm.contrib import rocm
from tilelang.contrib import nvcc
SUPPORTED_TARGETS: Dict[str, str] = {
SUPPORTED_TARGETS: dict[str, str] = {
"auto": "Auto-detect CUDA/HIP/Metal based on availability.",
"cuda": "CUDA GPU target (supports options such as `cuda -arch=sm_80`).",
"hip": "ROCm HIP target (supports options like `hip -mcpu=gfx90a`).",
......@@ -17,7 +18,7 @@ SUPPORTED_TARGETS: Dict[str, str] = {
}
def describe_supported_targets() -> Dict[str, str]:
def describe_supported_targets() -> dict[str, str]:
"""
Return a mapping of supported target names to usage descriptions.
"""
......@@ -58,8 +59,8 @@ def check_metal_availability() -> bool:
return arch == 'arm64'
def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
return_object: bool = False) -> Union[str, Target]:
def determine_target(target: str | Target | Literal["auto"] = "auto",
return_object: bool = False) -> str | Target:
"""
Determine the appropriate target for compilation (CUDA, HIP, or manual selection).
......@@ -76,7 +77,7 @@ def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
AssertionError: If the target is invalid.
"""
return_var: Union[str, Target] = target
return_var: str | Target = target
if target == "auto":
target = tvm.target.Target.current(allow_none=True)
......
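
Hypothetical usage of determine_target: with "auto" it probes CUDA, HIP, then Metal availability; return_object=True yields a tvm.target.Target instead of a string.

target = determine_target("auto")                        # e.g. "cuda"
target_obj = determine_target("cuda", return_object=True)
assert isinstance(target_obj, Target)                    # assumed behavior
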
......@@ -3,7 +3,6 @@ from __future__ import annotations
import os
import platform
import subprocess
from typing import Optional
from pathlib import Path
ROOT = Path(__file__).parent
......@@ -17,13 +16,12 @@ def _read_cmake_bool(i: str | None, default=False):
return i.lower() not in ('0', 'false', 'off', 'no', 'n', '')
def get_git_commit_id() -> Optional[str]:
def get_git_commit_id() -> str | None:
"""Get the current git commit hash by running git in the current file's directory."""
r = subprocess.run(['git', 'rev-parse', 'HEAD'],
cwd=ROOT,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
capture_output=True,
encoding='utf-8')
if r.returncode == 0:
return r.stdout.strip()
......
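
The subprocess change in this hunk is ruff's replace-stdout-stderr rule (UP022): since Python 3.7, capture_output=True is exact shorthand for passing stdout=subprocess.PIPE and stderr=subprocess.PIPE.

import subprocess

r = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, encoding="utf-8")
if r.returncode == 0:
    print(r.stdout.strip())   # captured because capture_output=True set both PIPEs
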