Unverified commit 29051439, authored by Lei Wang, committed by GitHub

[Lint] Phase out Yapf format and embrace ruff format (#1417)

parent e84b24bc
@@ -51,7 +51,6 @@ def _Simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc |
# Decorator to simplify the output of a function
def simplify_prim_func(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
stmt: PrimFunc | IRModule = (func)(*args, **kwargs)
return _Simplify(stmt)
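A minimal usage sketch of the decorator, assuming the `tvm.script` builder API; the kernel body and names are illustrative:

    from tvm.script import tir as T

    @simplify_prim_func
    def make_kernel():
        @T.prim_func
        def kernel(A: T.Buffer((128,), "float32")):
            for i in T.serial(128):
                A[i] = A[i] * (2.0 - 1.0)  # Simplify folds this back to A[i]
        return kernel

    simplified = make_kernel()  # PrimFunc after TVM's Simplify pass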
......
def deprecated_warning(method_name: str, new_method_name: str, phaseout_version: str = None):
"""A function to indicate that a method is deprecated
"""
"""A function to indicate that a method is deprecated"""
import warnings # pylint: disable=import-outside-toplevel, import-error
warnings.warn(
f"{method_name} is deprecated, use {new_method_name} instead" +
(f" and will be removed in {phaseout_version}" if phaseout_version else ""),
f"{method_name} is deprecated, use {new_method_name} instead"
+ (f" and will be removed in {phaseout_version}" if phaseout_version else ""),
DeprecationWarning,
stacklevel=2,
)
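A usage sketch; `old_api`, `new_api`, and the version string are illustrative stand-ins:

    def new_api(x):
        return x + 1

    def old_api(x):
        # Warns: "old_api is deprecated, use new_api instead and will be removed
        # in v0.2.0", attributed to old_api's caller via stacklevel=2.
        deprecated_warning("old_api", "new_api", phaseout_version="v0.2.0")
        return new_api(x)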
@@ -30,7 +29,6 @@ def deprecated(
import functools # pylint: disable=import-outside-toplevel
def _deprecate(func):
@functools.wraps(func)
def _wrapper(*args, **kwargs):
deprecated_warning(method_name, new_method_name, phaseout_version)
......
@@ -24,8 +24,7 @@ def _get_buffer(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) ->
elif isinstance(buffer_or_load_or_region, (tir.BufferLoad, tir.BufferRegion)):
return buffer_or_load_or_region.buffer
else:
-        raise TypeError(
-            f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}")
+        raise TypeError(f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}")
def is_global(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
@@ -153,14 +152,12 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
"""
if not isinstance(ir_module, IRModule):
raise ValueError("Not supported type: ", type(ir_module))
-    assert len(ir_module.get_global_vars()) == 1, (
-        "The optimized module should only have one global variable for default schedule.")
+    assert len(ir_module.get_global_vars()) == 1, "The optimized module should only have one global variable for default schedule."
func = list(ir_module.functions.values())[0]
return func
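A sketch of the expected input, assuming a single-function module built with `tvm.IRModule`:

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def main(A: T.Buffer((8,), "float32")):
        for i in T.serial(8):
            A[i] = 0.0

    mod = tvm.IRModule({"main": main})
    func = retrieve_func_from_module(mod)  # returns the lone PrimFunc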
-def get_buffer_region_from_load(buffer_load: tir.BufferLoad,
-                                extents: list[PrimExpr] | None = None) -> tir.BufferRegion | None:
+def get_buffer_region_from_load(buffer_load: tir.BufferLoad, extents: list[PrimExpr] | None = None) -> tir.BufferRegion | None:
"""
Get the buffer region from a buffer load.
@@ -193,9 +190,9 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad,
return None
-def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion | tir.Var,
-                     access_type: str = "rw",
-                     extents: list[PrimExpr] | None = None) -> PrimExpr | BufferRegion:
+def to_buffer_region(
+    obj: Buffer | BufferLoad | BufferRegion | tir.Var, access_type: str = "rw", extents: list[PrimExpr] | None = None
+) -> PrimExpr | BufferRegion:
"""
Convert to/from the tl.region representation.
@@ -203,6 +200,7 @@ def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion | tir.Var,
- tl.region Call -> returns the decoded BufferRegion for analysis
"""
from tilelang.language.frame import has_let_value, get_let_value
if isinstance(obj, tir.Var) and has_let_value(obj):
obj = get_let_value(obj)
# Encode into tl.region call (when extents is provided), otherwise return BufferRegion for analysis
@@ -279,8 +277,7 @@ def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
return strides
-def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
-                                   access_type: str = "r") -> PrimExpr:
+def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion, access_type: str = "r") -> PrimExpr:
if isinstance(buffer_or_load_or_region, Buffer):
return buffer_or_load_or_region.access_ptr(access_type)
elif isinstance(buffer_or_load_or_region, BufferLoad):
......
@@ -15,7 +15,7 @@ os.makedirs(_CACHE_DIR, exist_ok=True)
def _get_cached_lib():
-    name = 'compress_lib'
+    name = "compress_lib"
if os.path.exists(os.path.join(_CACHE_DIR, f"{name}.so")):
try:
@@ -32,24 +32,22 @@ def _get_cached_lib():
name=name,
sources=[compress_util],
extra_cuda_cflags=[
-            '-O2',
-            '-std=c++17',
-            '-lineinfo',
-            f'-I{env.CUTLASS_INCLUDE_DIR}',
-            f'-I{env.CUTLASS_INCLUDE_DIR}/../tools/util/include',
-            '-arch=sm_90',
+            "-O2",
+            "-std=c++17",
+            "-lineinfo",
+            f"-I{env.CUTLASS_INCLUDE_DIR}",
+            f"-I{env.CUTLASS_INCLUDE_DIR}/../tools/util/include",
+            "-arch=sm_90",
],
build_directory=_CACHE_DIR,
)
-def compress_sm90(A: torch.Tensor, block_k: int,
-                  transposed: bool) -> tuple[torch.Tensor, torch.Tensor]:
+def compress_sm90(A: torch.Tensor, block_k: int, transposed: bool) -> tuple[torch.Tensor, torch.Tensor]:
if block_k > 128:
block_k = 128
# Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
-        warnings.warn(
-            f"block_k {block_k} is too large, set to 128 for sm90 compression.", stacklevel=2)
+        warnings.warn(f"block_k {block_k} is too large, set to 128 for sm90 compression.", stacklevel=2)
# Load the library (will use cache if available)
compress_lib = _get_cached_lib()
@@ -60,8 +58,9 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torc
try:
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
except ImportError as err:
raise ImportError("SparseSemiStructuredTensor is not available in this version of PyTorch. "
"Please install a compatible version.") from err
raise ImportError(
"SparseSemiStructuredTensor is not available in this version of PyTorch. Please install a compatible version."
) from err
orig_val = SparseSemiStructuredTensor._FORCE_CUTLASS
try:
SparseSemiStructuredTensor._FORCE_CUTLASS = True
@@ -73,10 +72,7 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torc
SparseSemiStructuredTensor._FORCE_CUTLASS = orig_val
-def compress(A: torch.Tensor,
-             transposed: bool,
-             arch: str | None = None,
-             **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
+def compress(A: torch.Tensor, transposed: bool, arch: str | None = None, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compress a tensor using the appropriate method based on the CUDA architecture.
"""
@@ -101,11 +97,10 @@ def compress(A: torch.Tensor,
A_sp = A_sp.t().contiguous()
return A_sp, E
else:
raise ValueError(f"Unsupported CUDA compute version: {compute_version}. "
"Supported versions are sm_80 and sm_90.")
raise ValueError(f"Unsupported CUDA compute version: {compute_version}. Supported versions are sm_80 and sm_90.")
-def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transposed: bool = False):
+def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device="cuda", transposed: bool = False):
"""
Generate a random semi-sparse tensor. The generated tensor will have 2:4 sparsity along the K dimension.
Args:
@@ -127,13 +122,7 @@ def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transp
return tensor.to(dtype) # dtype like float8 might not have randn kernel
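A quick property check of the 2:4 pattern (at most two nonzeros in every contiguous group of four along K):

    t = randn_semi_sparse(64, 64, dtype=torch.float16, device="cuda")
    nonzeros_per_group = (t.reshape(64, -1, 4) != 0).sum(dim=-1)
    assert bool((nonzeros_per_group <= 2).all())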
-def randint_semi_sparse(M: int,
-                        K: int,
-                        low: int,
-                        high: int,
-                        dtype=torch.int32,
-                        device='cuda',
-                        transposed: bool = False):
+def randint_semi_sparse(M: int, K: int, low: int, high: int, dtype=torch.int32, device="cuda", transposed: bool = False):
"""
Generate a random semi-sparse integer tensor. The generated tensor will have 2:4 sparsity along the K dimension.
Args:
@@ -157,11 +146,7 @@ def randint_semi_sparse(M: int,
return tensor
-def arange_semi_sparse(M: int,
-                       K: int,
-                       dtype=torch.float16,
-                       device='cuda',
-                       transposed: bool = False):
+def arange_semi_sparse(M: int, K: int, dtype=torch.float16, device="cuda", transposed: bool = False):
"""
Generate a semi-sparse tensor with values from 0 to M*K-1. The generated tensor will have 2:4 sparsity along the K dimension.
Args:
......
@@ -56,11 +56,10 @@ def check_metal_availability() -> bool:
if not mac_release:
return False
# todo: check torch version?
-    return arch == 'arm64'
+    return arch == "arm64"
-def determine_target(target: str | Target | Literal["auto"] = "auto",
-                     return_object: bool = False) -> str | Target:
+def determine_target(target: str | Target | Literal["auto"] = "auto", return_object: bool = False) -> str | Target:
"""
Determine the appropriate target for compilation (CUDA, HIP, or manual selection).
......
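A usage sketch; what "auto" resolves to depends on the host (CUDA, HIP, or Metal), so the results shown are illustrative:

    target = determine_target("auto")                          # e.g. a CUDA target on an NVIDIA box
    target_obj = determine_target("cuda", return_object=True)  # tvm.target.Target instance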
"""The profiler and convert to torch utils"""
from enum import Enum
import torch
from tvm import tir
@@ -17,7 +18,7 @@ def is_float8_dtype(dtype: torch.dtype) -> bool:
def fp8_remove_negative_zeros_(tensor: torch.Tensor):
assert is_float8_dtype(tensor.dtype), "Input tensor must be of float8 dtype"
bits = tensor.view(torch.uint8)
-    zeros_mask = (tensor == 0)
+    zeros_mask = tensor == 0
bits[zeros_mask] = 0x00
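In float8, -0.0 differs from +0.0 only in the sign bit, while `tensor == 0` matches both, so writing 0x00 through the uint8 view canonicalizes every zero. A small sketch, assuming torch >= 2.1 for float8_e4m3fn:

    x = torch.tensor([-0.0, 0.0, 1.0]).to(torch.float8_e4m3fn)
    assert x.view(torch.uint8)[0].item() == 0x80  # sign bit set for -0.0
    fp8_remove_negative_zeros_(x)
    assert x.view(torch.uint8)[0].item() == 0x00  # canonical +0.0 bit pattern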
@@ -33,26 +34,21 @@ class TensorSupplyType(Enum):
def map_torch_type(intype: str) -> torch.dtype:
if intype == "float8_e4m3":
assert hasattr(torch, "float8_e4m3fn"), \
"torch.float8_e4m3fn is not supported in this version of torch" \
"Please upgrade torch >= 2.1.0"
assert hasattr(torch, "float8_e4m3fn"), "torch.float8_e4m3fn is not supported in this version of torchPlease upgrade torch >= 2.1.0"
return torch.float8_e4m3fn
elif intype == "float8_e5m2":
assert hasattr(torch, "float8_e5m2"), \
"torch.float8_e5m2 is not supported in this version of torch" \
"Please upgrade torch >= 2.1.0"
assert hasattr(torch, "float8_e5m2"), "torch.float8_e5m2 is not supported in this version of torchPlease upgrade torch >= 2.1.0"
return torch.float8_e5m2
elif intype == "e4m3fnuz_float8":
assert hasattr(torch, "float8_e4m3fnuz"), \
"torch.float8_e4m3fnuz is not supported in this version of torch" \
"Please upgrade torch >= 2.2.0"
assert hasattr(torch, "float8_e4m3fnuz"), (
"torch.float8_e4m3fnuz is not supported in this version of torchPlease upgrade torch >= 2.2.0"
)
return torch.float8_e4m3fnuz
else:
return getattr(torch, intype)
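Usage sketch:

    assert map_torch_type("float32") is torch.float32            # plain names fall through to getattr
    assert map_torch_type("float8_e4m3") is torch.float8_e4m3fn  # needs torch >= 2.1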
def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
from tilelang.engine.param import KernelParam
from .device import get_current_device
@@ -63,7 +59,8 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
if hasattr(param, "shape") and not param.shape:
raise ValueError(
f"TensorType must have a shape, but got {type(param)}, "
"likely you are trying to generate a random tensor with a dynamic symbolic shape.")
"likely you are trying to generate a random tensor with a dynamic symbolic shape."
)
# Check if with dynamic symbolic shape
for shape in param.shape:
@@ -81,8 +78,7 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
if is_unsigned:
return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
elif is_float8:
-            return torch.randint(
-                low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype)
+            return torch.randint(low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype)
elif is_boolean:
return torch.randint(low=0, high=2, size=shape, device=device, dtype=dtype)
elif dtype in {torch.float16, torch.float32, torch.bfloat16}:
@@ -103,18 +99,15 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
if is_unsigned:
return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
elif is_float8:
-            return torch.randint(
-                low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype)
+            return torch.randint(low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype)
elif is_boolean:
return torch.randint(low=0, high=2, size=shape, device=device, dtype=dtype)
else:
return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype)
elif supply_type == TensorSupplyType.Uniform:
-        return torch.empty(
-            *shape, device=device, dtype=torch.float32).uniform_(-1.0, 1.0).to(dtype)
+        return torch.empty(*shape, device=device, dtype=torch.float32).uniform_(-1.0, 1.0).to(dtype)
elif supply_type == TensorSupplyType.Normal:
-        return torch.empty(
-            *shape, device=device, dtype=torch.float32).normal_(-1.0, 1.0).to(dtype)
+        return torch.empty(*shape, device=device, dtype=torch.float32).normal_(-1.0, 1.0).to(dtype)
elif supply_type == TensorSupplyType.Randn:
return torch.randn(*shape, device=device).to(dtype)
elif supply_type == TensorSupplyType.Zero:
@@ -150,9 +143,7 @@ def _compare_attributes(
"""
def raise_mismatch_error(attribute_name: str, actual_value, expected_value):
-        raise AssertionError(
-            f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}."
-        )
+        raise AssertionError(f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.")
if actual.shape != expected.shape:
raise_mismatch_error("shape", actual.shape, expected.shape)
@@ -163,7 +154,7 @@
if actual.layout != expected.layout:
if check_layout:
raise_mismatch_error("layout", actual.layout, expected.layout)
-    elif (actual.layout == torch.strided and check_stride and actual.stride() != expected.stride()):
+    elif actual.layout == torch.strided and check_stride and actual.stride() != expected.stride():
raise_mismatch_error("stride()", actual.stride(), expected.stride())
if check_device and actual.device != expected.device:
raise_mismatch_error("device", actual.device, expected.device)
@@ -171,8 +162,7 @@
raise_mismatch_error("dtype", actual.dtype, expected.dtype)
-def _equalize_attributes(actual: torch.Tensor,
-                         expected: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+def _equalize_attributes(actual: torch.Tensor, expected: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Equalizes some attributes of two tensors for value comparison.
If ``actual`` and ``expected`` are ...
    - ... not on the same :attr:`~torch.Tensor.device`, they are moved to CPU memory.
@@ -210,7 +200,7 @@ def _equalize_attributes(actual: torch.Tensor,
if actual.layout != expected.layout:
# These checks are needed, since Tensor.to_dense() fails on tensors that are already strided
actual = actual.to_dense() if actual.layout != torch.strided else actual
-        expected = (expected.to_dense() if expected.layout != torch.strided else expected)
+        expected = expected.to_dense() if expected.layout != torch.strided else expected
return actual, expected
@@ -254,12 +244,8 @@ def torch_assert_close(
"""
_compare_attributes(
-        tensor_a,
-        tensor_b,
-        check_device=check_device,
-        check_dtype=check_dtype,
-        check_layout=check_layout,
-        check_stride=check_stride)
+        tensor_a, tensor_b, check_device=check_device, check_dtype=check_dtype, check_layout=check_layout, check_stride=check_stride
+    )
tensor_a, tensor_b = _equalize_attributes(tensor_a, tensor_b)
mismatched = ~torch.isclose(tensor_a, tensor_b, rtol=rtol, atol=atol, equal_nan=equal_nan)
@@ -276,8 +262,7 @@ def torch_assert_close(
# Print debug information about the mismatch
if verbose:
print(f"Number of mismatched elements: {num_mismatched} / {total_elements} "
f"(allowed: {max_allowed_mismatched})")
print(f"Number of mismatched elements: {num_mismatched} / {total_elements} (allowed: {max_allowed_mismatched})")
# If there are mismatched elements, print the first mismatch
if num_mismatched > 0:
@@ -289,9 +274,9 @@
b_val = tensor_b.reshape(-1)[flat_idx].item()
abs_diff = abs(a_val - b_val)
rel_diff = abs_diff / (abs(b_val) + 1e-12)
mismatch_info = (f"\nFirst mismatch at index {idx}: "
f"lhs={a_val:.6f}, rhs={b_val:.6f}, "
f"abs_diff={abs_diff:.6f}, rel_diff={rel_diff:.6f}")
mismatch_info = (
f"\nFirst mismatch at index {idx}: lhs={a_val:.6f}, rhs={b_val:.6f}, abs_diff={abs_diff:.6f}, rel_diff={rel_diff:.6f}"
)
else:
mismatch_info = ""
@@ -304,6 +289,7 @@
f"\nGreatest absolute difference: {diff.max().item()}, "
f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}"
f"\n{base_name}: {tensor_a}"
f"\n{ref_name}: {tensor_b}")
f"\n{ref_name}: {tensor_b}"
)
else:
return True
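A usage sketch with illustrative tolerances; per the code above, the function returns True when all elements are within tolerance:

    out = torch.randn(128, 128, dtype=torch.float16, device="cuda")
    ref = out.clone()
    assert torch_assert_close(out, ref, rtol=1e-2, atol=1e-2)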
@@ -8,29 +8,26 @@ from functools import lru_cache
ROOT = Path(__file__).parent
-base_version = (ROOT / 'VERSION').read_text().strip()
+base_version = (ROOT / "VERSION").read_text().strip()
# When installing an sdist,
# the installed version needs to match the sdist version,
# so pip will complain when we install `tilelang-0.1.6.post2+gitxxxx.tar.gz`.
# To work around that, when building an sdist,
# we do not add a version label and use a file to store the git hash instead.
-git_pin = ROOT / '.git_commit.txt'
+git_pin = ROOT / ".git_commit.txt"
def _read_cmake_bool(i: str | None, default=False):
if i is None:
return default
-    return i.lower() not in ('0', 'false', 'off', 'no', 'n', '')
+    return i.lower() not in ("0", "false", "off", "no", "n", "")
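The helper mirrors CMake's notion of falsy strings; a quick check:

    assert _read_cmake_bool("ON") is True
    assert _read_cmake_bool("off") is False
    assert _read_cmake_bool("") is False             # empty string reads as false
    assert _read_cmake_bool(None, default=True) is True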
@lru_cache(maxsize=1)
def get_git_commit_id() -> str | None:
"""Get the current git commit hash by running git in the current file's directory."""
-    r = subprocess.run(['git', 'rev-parse', 'HEAD'],
-                       cwd=ROOT,
-                       capture_output=True,
-                       encoding='utf-8')
+    r = subprocess.run(["git", "rev-parse", "HEAD"], cwd=ROOT, capture_output=True, encoding="utf-8")
if r.returncode == 0:
_git = r.stdout.strip()
git_pin.write_text(_git)
@@ -41,51 +38,48 @@ def get_git_commit_id() -> str | None:
return None
-def dynamic_metadata(
-    field: str,
-    settings: dict[str, object] | None = None,
-) -> str:
-    assert field == 'version'
+def dynamic_metadata(field: str, settings: dict[str, object] | None = None) -> str:
+    assert field == "version"
version = base_version
# generate git version for sdist
get_git_commit_id()
-    if not _read_cmake_bool(os.environ.get('NO_VERSION_LABEL')):
+    if not _read_cmake_bool(os.environ.get("NO_VERSION_LABEL")):
exts = []
backend = None
-        if _read_cmake_bool(os.environ.get('NO_TOOLCHAIN_VERSION')):
+        if _read_cmake_bool(os.environ.get("NO_TOOLCHAIN_VERSION")):
pass
-        elif platform.system() == 'Darwin':
+        elif platform.system() == "Darwin":
# only on macosx_11_0_arm64, not necessary
# backend = 'metal'
pass
-        elif _read_cmake_bool(os.environ.get('USE_ROCM', '')):
-            backend = 'rocm'
-        elif 'USE_CUDA' in os.environ and not _read_cmake_bool(os.environ.get('USE_CUDA')):
-            backend = 'cpu'
+        elif _read_cmake_bool(os.environ.get("USE_ROCM", "")):
+            backend = "rocm"
+        elif "USE_CUDA" in os.environ and not _read_cmake_bool(os.environ.get("USE_CUDA")):
+            backend = "cpu"
else: # cuda
# Read nvcc version from env.
# This is not exactly how it should be,
# but works for now if building in a nvidia/cuda image.
-            if cuda_version := os.environ.get('CUDA_VERSION'):
-                major, minor, *_ = cuda_version.split('.')
-                backend = f'cu{major}{minor}'
+            if cuda_version := os.environ.get("CUDA_VERSION"):
+                major, minor, *_ = cuda_version.split(".")
+                backend = f"cu{major}{minor}"
else:
-                backend = 'cuda'
+                backend = "cuda"
if backend:
exts.append(backend)
-        if _read_cmake_bool(os.environ.get('NO_GIT_VERSION')):
+        if _read_cmake_bool(os.environ.get("NO_GIT_VERSION")):
pass
elif git_hash := get_git_commit_id():
-            exts.append(f'git{git_hash[:8]}')
+            exts.append(f"git{git_hash[:8]}")
else:
-            exts.append('gitunknown')
+            exts.append("gitunknown")
if exts:
-        version += '+' + '.'.join(exts)
+        version += "+" + ".".join(exts)
return version
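Illustrative version strings this can produce, assuming base version 0.1.6 (exact values depend on the build environment; the git hash is a placeholder):

    # NO_VERSION_LABEL=1                     -> "0.1.6"
    # CUDA_VERSION=12.1 in a CUDA image      -> "0.1.6+cu121.git1a2b3c4d"
    # USE_ROCM=1, no git metadata available  -> "0.1.6+rocm.gitunknown"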
......