Unverified commit 29051439, authored by Lei Wang, committed by GitHub
Browse files

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
......@@ -31,7 +31,6 @@ class GemmPrimitiveMMA(GemmBaseParams):
C: tir.Buffer,
mma_emitter: TensorCoreIntrinEmitter,
) -> tir.PrimExpr:
in_dtype = self.in_dtype
warp_cols = mma_emitter.warp_cols
local_size_b = mma_emitter.local_size_b
......@@ -53,21 +52,24 @@ class GemmPrimitiveMMA(GemmBaseParams):
if a_is_fragment:
# Annotate layout for A_local if it is a fragment.
T.annotate_layout({
T.annotate_layout(
{
A_local: mma_emitter.make_mma_load_layout(A_local, "A"),
})
}
)
if c_is_fragment:
# Annotate layout for C_local if it is a fragment.
T.annotate_layout({
T.annotate_layout(
{
C_local: mma_emitter.make_mma_store_layout(C_local),
})
}
)
# Make default swizzle layout for shared memory
# T.annotate_layout({
# B_shared: make_mma_swizzle_layout(B_shared),
# })
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
......@@ -146,9 +148,11 @@ class GemmPrimitiveMMA(GemmBaseParams):
if c_is_fragment:
# Annotate layout for C_local if it is a fragment.
T.annotate_layout({
T.annotate_layout(
{
C_local: mma_emitter.make_mma_store_layout(C_local),
})
}
)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
......
"""The profiler and convert to torch utils"""
from __future__ import annotations
from typing import Callable, Any, Literal
from functools import partial
......@@ -45,8 +46,7 @@ class Profiler:
result_idx = []
elif isinstance(result_idx, int):
if result_idx > len(params) or result_idx < -len(params):
raise ValueError(
f"result_idx should be an integer between {-len(params)} and {len(params) - 1}")
raise ValueError(f"result_idx should be an integer between {-len(params)} and {len(params) - 1}")
if result_idx < 0:
result_idx = len(params) + result_idx
result_idx = [result_idx]
......@@ -113,8 +113,7 @@ class Profiler:
ref_tensors = ins + ref_outs
lib_tensors = ins + lib_outs
assert len(lib_tensors) == len(
ref_tensors), "len(lib_tensors) not equals to len(ref_tensors) !"
assert len(lib_tensors) == len(ref_tensors), "len(lib_tensors) not equals to len(ref_tensors) !"
# torch.set_printoptions(edgeitems=torch.inf)
for lhs, rhs in zip(lib_tensors, ref_tensors):
# close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol)
......@@ -252,10 +251,9 @@ class Profiler:
)
elif profiler == "tvm":
assert func is not None, "func should not be None"
assert isinstance(
func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}"
assert isinstance(func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}"
ins = (self._get_inputs(with_output=True) if input_tensors is None else input_tensors)
ins = self._get_inputs(with_output=True) if input_tensors is None else input_tensors
target = "cuda"
with suppress(Exception):
......@@ -264,8 +262,7 @@ class Profiler:
assert target in ["cuda", "hip"], f"Unknown target: {target}"
device = tvm.cuda(0) if target == "cuda" else tvm.rocm(0)
time_evaluator = self.mod.time_evaluator(
self.mod.entry_name, device, number=rep, repeat=n_repeat)
time_evaluator = self.mod.time_evaluator(self.mod.entry_name, device, number=rep, repeat=n_repeat)
# Transform Latency to ms
return time_evaluator(*ins).mean * 1e3
else:
......
"""Profiler and benchmarking utilities for PyTorch functions."""
from __future__ import annotations
import os
......@@ -16,8 +17,8 @@ class suppress_stdout_stderr:
def __enter__(self):
# Open null device files
self.outnull_file = open(os.devnull, 'w')
self.errnull_file = open(os.devnull, 'w')
self.outnull_file = open(os.devnull, "w")
self.errnull_file = open(os.devnull, "w")
# Save original file descriptors
self.old_stdout_fileno_undup = sys.stdout.fileno()
......@@ -56,7 +57,7 @@ class suppress_stdout_stderr:
IS_CUDA = torch.cuda.is_available()
device = 'cuda:0' if IS_CUDA else 'mps:0'
device = "cuda:0" if IS_CUDA else "mps:0"
Event = torch.cuda.Event if IS_CUDA else torch.mps.Event
......@@ -93,8 +94,7 @@ def do_bench(
Returns:
Runtime in milliseconds (float) or list of quantile values if quantiles specified
"""
assert return_mode in ["min", "max", "mean", "median"], \
f"Invalid return_mode: {return_mode}"
assert return_mode in ["min", "max", "mean", "median"], f"Invalid return_mode: {return_mode}"
# Initial function call and synchronization
fn()
......
......@@ -1130,16 +1130,13 @@ def get_lop3_intrin_group(
Dict[str, str]
A dictionary mapping the names of the intrinsics to their corresponding implementations.
"""
assert out_dtype in [
"float16", "int8", "int4"
], (f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' .")
assert out_dtype in ["float16", "int8", "int4"], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' ."
dtype_mapping = {"float16": "f16", "int4": "i4", "int8": "i8", "int32": "i32"}
target_dtype = dtype_mapping[out_dtype]
if source_format not in ["int", "uint"]:
raise ValueError(
f"Invalid source_format. Expected 'int' or 'uint', but got {source_format}.")
raise ValueError(f"Invalid source_format. Expected 'int' or 'uint', but got {source_format}.")
if with_zeros and source_format == "int":
raise ValueError(f"Zeros are not supported for signed integers, but got {source_format}")
......
......@@ -80,13 +80,9 @@ def get_mxfp_intrin_group(
AssertionError: if out_dtype, source_format, or storage_dtype are not supported.
KeyError: if the constructed key does not match any available C source implementation.
"""
assert out_dtype in ["float16", "bfloat16"
], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'."
assert source_format in ["int", "uint"
], f"Invalid source_format: {source_format}. Expected 'int' or 'uint'."
assert storage_dtype in [
"int32", "int8", "uint8"
], f"Invalid storage_dtype: {storage_dtype}. Expected 'int32' or 'int8' or 'uint8'."
assert out_dtype in ["float16", "bfloat16"], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'."
assert source_format in ["int", "uint"], f"Invalid source_format: {source_format}. Expected 'int' or 'uint'."
assert storage_dtype in ["int32", "int8", "uint8"], f"Invalid storage_dtype: {storage_dtype}. Expected 'int32' or 'int8' or 'uint8'."
dtype_map = {"float16": "f16", "bfloat16": "bf16"}
key = f"fp{source_bit}_to_{dtype_map[out_dtype]}"
......
def gen_quant4(k, n, groupsize=-1):
import torch
import torch.nn as nn
maxq = 2**4
w = torch.randn((k, n), dtype=torch.half, device="cpu")
......@@ -48,6 +49,7 @@ def gen_quant4(k, n, groupsize=-1):
def general_compress(lowprecision_weight, source_bits=4, storage_dtype=None):
import torch
if storage_dtype is None:
storage_dtype = torch.int8
elems_per_byte = 8 // source_bits
......@@ -56,11 +58,11 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=None):
int8_weight = torch.zeros(
(*lowprecision_weight.shape[:-1], lowprecision_weight.shape[-1] // elems_per_byte),
dtype=torch.int8,
device=lowprecision_weight.device)
device=lowprecision_weight.device,
)
for j in range(lowprecision_weight.shape[-1] // elems_per_byte):
for k in range(elems_per_byte):
int8_weight[..., j] |= (lowprecision_weight[..., j * elems_per_byte + k] <<
(source_bits * k)).to(torch.int8)
int8_weight[..., j] |= (lowprecision_weight[..., j * elems_per_byte + k] << (source_bits * k)).to(torch.int8)
return int8_weight.to(storage_dtype)
......@@ -82,6 +84,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
interleave_weight(qweight, 4, "float16")
"""
import torch
assert target_dtype in ["float16", "int8"]
# reinterpret the data type of qweight to int32
qweight = qweight.view(torch.int32)
......
......@@ -5,20 +5,19 @@ import random
import torch
import numpy as np
from tilelang.contrib import nvcc
from tvm.testing.utils import (requires_cuda, requires_package, requires_llvm, requires_metal,
requires_rocm, _compose)
from tvm.testing.utils import requires_cuda, requires_package, requires_llvm, requires_metal, requires_rocm, _compose
from tilelang.utils.tensor import torch_assert_close as torch_assert_close
__all__ = [
'requires_package',
'requires_cuda',
'requires_metal',
'requires_rocm',
'requires_llvm',
'main',
'requires_cuda_compute_version',
] + [f'requires_cuda_compute_version_{op}' for op in ('ge', 'gt', 'le', 'lt', 'eq')]
"requires_package",
"requires_cuda",
"requires_metal",
"requires_rocm",
"requires_llvm",
"main",
"requires_cuda_compute_version",
] + [f"requires_cuda_compute_version_{op}" for op in ("ge", "gt", "le", "lt", "eq")]
# pytest.main() wrapper to allow running single test file
......
......@@ -23,8 +23,7 @@ def gemm_py_infer_layout(gemm_py: GemmMMA, target: Target, thread_bounds: Range)
@tvm_ffi.register_global_func("tl.gemm_py.lower")
def gemm_py_lower(gemm_py: GemmMMA, layout_map, target: Target, thread_bounds: Range,
thread_var: tir.Var):
def gemm_py_lower(gemm_py: GemmMMA, layout_map, target: Target, thread_bounds: Range, thread_var: tir.Var):
thread_nums = thread_bounds.extent
stmt = gemm_py.lower(layout_map, target, thread_nums, thread_var)
return stmt
......
from .gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mfma_macro_generator import (
MatrixCoreIntrinEmitter,)
MatrixCoreIntrinEmitter,
)
from tilelang.utils.language import is_shared, is_fragment, is_full_region
from tilelang import tvm as tvm
from tvm.target import Target
......@@ -11,10 +12,8 @@ from tilelang.transform.simplify import _Simplify
class GemmMFMA(GemmBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mfma_emitter = MatrixCoreIntrinEmitter(
......@@ -56,12 +55,10 @@ class GemmMFMA(GemmBase):
self.C: mfma_emitter.make_mfma_store_layout(self.C),
}
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mfma_emitter = MatrixCoreIntrinEmitter(
......@@ -153,7 +150,6 @@ class GemmMFMA(GemmBase):
T.clear(C_buf)
for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
# Load A into fragment
mfma_emitter.ldmatrix_a(
A_local,
......@@ -183,7 +179,6 @@ class GemmMFMA(GemmBase):
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
# Load B into fragment
mfma_emitter.ldmatrix_b(
B_local,
......@@ -217,8 +212,7 @@ class GemmMFMA(GemmBase):
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
......
from .gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from tilelang.utils.language import is_shared, is_fragment, is_full_region
from tilelang import tvm as tvm
from tvm.target import Target
......@@ -11,10 +12,8 @@ from tilelang.transform.simplify import _Simplify
class GemmMMA(GemmBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
......@@ -54,12 +53,10 @@ class GemmMMA(GemmBase):
self.C: mma_emitter.make_mma_store_layout(self.C),
}
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
......@@ -177,7 +174,6 @@ class GemmMMA(GemmBase):
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
......@@ -211,8 +207,7 @@ class GemmMMA(GemmBase):
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rrr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
......
......@@ -2,7 +2,8 @@
from .gemm_base import GemmBase
from tilelang.layout import make_volta_swizzled_layout
from tilelang.intrinsics.mma_sm70_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from tilelang.utils.language import is_shared, is_fragment, is_full_region
from tilelang import tvm as tvm
from tvm.target import Target
......@@ -12,10 +13,8 @@ from tilelang.transform.simplify import _Simplify
class GemmMMASm70(GemmBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
......@@ -45,12 +44,10 @@ class GemmMMASm70(GemmBase):
self.C: mma_emitter.make_mma_store_layout(self.C),
}
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
......@@ -140,7 +137,6 @@ class GemmMMASm70(GemmBase):
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
......@@ -155,8 +151,7 @@ class GemmMMASm70(GemmBase):
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
......
from .gemm_base import GemmBase
from tilelang.layout import make_tcgen05mma_swizzled_layout
from tilelang.intrinsics.tcgen05_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from tilelang import language as T
from tilelang.transform.simplify import _Simplify
from tvm import tir
......@@ -18,10 +19,8 @@ _FLOAT8_DTYPES = {
class GemmTCGEN5(GemmBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
True)
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
......@@ -40,27 +39,20 @@ class GemmTCGEN5(GemmBase):
b_is_k_major = self.trans_B
if self.is_gemm_ss():
a_continuity = self.M if a_is_k_major else 4 * self.K // m_warp
b_continuity = self.K if b_is_k_major else self.N // n_warp
return {
# WGMMA does not support padding
self.A:
make_tcgen05mma_swizzled_layout(
self.A, continuity=a_continuity, k_major=a_is_k_major),
self.B:
make_tcgen05mma_swizzled_layout(
self.B, continuity=b_continuity, k_major=b_is_k_major),
self.C:
mma_emitter.make_mma_store_layout(self.C),
self.A: make_tcgen05mma_swizzled_layout(self.A, continuity=a_continuity, k_major=a_is_k_major),
self.B: make_tcgen05mma_swizzled_layout(self.B, continuity=b_continuity, k_major=b_is_k_major),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
# No special swizzle requirement; rely on existing layout.
return {}
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
True)
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
......@@ -82,11 +74,9 @@ class GemmTCGEN5(GemmBase):
mma_emitter._assign_b_shared_layout(layout_map[self.B])
if not self.is_gemm_ss():
raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got "
f"A scope {self.A.scope()}, B scope {self.B.scope()}")
raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got A scope {self.A.scope()}, B scope {self.B.scope()}")
atom_m, atom_n, atom_k, enable_ws, enable_2cta = mma_emitter.get_tcgen5_mma_meta(
self.M, self.N, self.K)
atom_m, atom_n, atom_k, enable_ws, enable_2cta = mma_emitter.get_tcgen5_mma_meta(self.M, self.N, self.K)
if self.A.scope() not in {"shared", "shared.dyn", "shared.tmem"}:
raise ValueError(f"Unsupported A scope for TCGEN5MMA: {self.A.scope()}")
......@@ -108,7 +98,7 @@ class GemmTCGEN5(GemmBase):
raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access")
accum_dtype = str(self.C.dtype)
if accum_dtype not in ["float32", 'float16']:
if accum_dtype not in ["float32", "float16"]:
raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}")
A_shared = self.ARegion
......
from .gemm_base import GemmBase
from tilelang.layout import make_wgmma_swizzled_layout
from tilelang.intrinsics.wgmma_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from tilelang.utils.language import is_shared, is_fragment
from tilelang import tvm as tvm
from tvm.target import Target
......@@ -11,10 +12,8 @@ from tilelang.transform.simplify import _Simplify
class GemmWGMMA(GemmBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
True)
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = TensorCoreIntrinEmitter(
......@@ -38,33 +37,22 @@ class GemmWGMMA(GemmBase):
return {
# WGMMA does not support padding
self.A:
make_wgmma_swizzled_layout(
self.A, continuity=a_continuity, k_major=a_is_k_major),
self.B:
make_wgmma_swizzled_layout(
self.B, continuity=b_continuity, k_major=b_is_k_major),
self.C:
mma_emitter.make_mma_store_layout(self.C),
self.A: make_wgmma_swizzled_layout(self.A, continuity=a_continuity, k_major=a_is_k_major),
self.B: make_wgmma_swizzled_layout(self.B, continuity=b_continuity, k_major=b_is_k_major),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
elif self.is_gemm_rs():
b_continuity = self.N if b_is_k_major else 4 * self.K // n_warp
return {
self.A:
mma_emitter.make_mma_load_layout(self.A, matrix="A"),
self.B:
make_wgmma_swizzled_layout(
self.B, continuity=b_continuity, k_major=b_is_k_major),
self.C:
mma_emitter.make_mma_store_layout(self.C),
self.A: mma_emitter.make_mma_load_layout(self.A, matrix="A"),
self.B: make_wgmma_swizzled_layout(self.B, continuity=b_continuity, k_major=b_is_k_major),
self.C: mma_emitter.make_mma_store_layout(self.C),
}
else:
raise ValueError(
f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}")
raise ValueError(f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}")
def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
True)
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, True)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
......@@ -133,8 +121,7 @@ class GemmWGMMA(GemmBase):
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
raise ValueError(
f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}")
raise ValueError(f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
......
from tilelang import tvm as tvm
from tvm import tir
from tilelang.utils.target import (
target_is_cuda,)
target_is_cuda,
)
from tvm.target import Target
from tvm.ir.base import Node
from tvm.ir import Range
......@@ -18,8 +19,7 @@ def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds
@tvm_ffi.register_global_func("tl.gemm_sp_py.lower")
def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range,
thread_var: tir.Var):
def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range, thread_var: tir.Var):
thread_nums = thread_bounds.extent
stmt = gemm_sp_py.lower(target, thread_nums, thread_var)
return stmt
......
......@@ -10,10 +10,8 @@ from tilelang.transform.simplify import _Simplify
class GemmSPMMA(GemmSPBase):
def infer_layout(self, target: Target, thread_nums: int):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = SparseTensorCoreIntrinEmitter(
......@@ -55,12 +53,10 @@ class GemmSPMMA(GemmSPBase):
self.C: mma_emitter.make_mma_store_layout(self.C),
}
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False)
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, False)
warp_row_tiles = int(self.M // m_warp)
warp_col_tiles = int(self.N // n_warp)
mma_emitter = SparseTensorCoreIntrinEmitter(
......@@ -146,7 +142,6 @@ class GemmSPMMA(GemmSPBase):
E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)
for ki in T.serial(0, (self.K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
......@@ -231,8 +226,7 @@ class GemmSPMMA(GemmSPBase):
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rrr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def is_gemm_ss(self) -> bool:
return is_shared(self.A) and is_shared(self.B)
......
......@@ -4,6 +4,7 @@ from dataclasses import dataclass
from tilelang import tvm
from tvm.tir.stmt_functor import ir_transform
import logging
# Configuration for different hardware architectures.
# Each entry contains: (cores per SM, default clock (GHz), FLOPs per cycle, max SM count)
ARCH_CONFIGS = {"80": (128, 1.41, 2, 108), "86": (128, 1.70, 2, 84), "89": (128, 2.52, 2, 128)}
......@@ -23,6 +24,7 @@ class AnalysisResult:
tflops: Achieved TFLOPS (trillions of FLOPs per second).
bandwidth_GBps: Achieved memory bandwidth in GB/s.
"""
total_flops: int
total_global_bytes: int
estimated_time: float
......@@ -81,7 +83,7 @@ class Analyzer:
# Account for loop and block dimensions
loop_product = 1
for extent in self.loop_stack:
loop_product *= extent.value if hasattr(extent, 'value') else extent
loop_product *= extent.value if hasattr(extent, "value") else extent
total_blocks = self.block_counts["blockIdx.x"] * self.block_counts["blockIdx.y"]
total_bytes = bytes_transferred * loop_product * total_blocks
self.total_global_bytes += total_bytes
......@@ -100,7 +102,7 @@ class Analyzer:
# Account for loop and block dimensions
loop_product = 1
for extent in self.loop_stack:
loop_product *= extent.value if hasattr(extent, 'value') else extent
loop_product *= extent.value if hasattr(extent, "value") else extent
total_blocks = self.block_counts["blockIdx.x"] * self.block_counts["blockIdx.y"]
self.total_flops += flops_per_call * loop_product * total_blocks
......@@ -127,8 +129,7 @@ class Analyzer:
iter_var = stmt.node
thread_tag = iter_var.thread_tag
if thread_tag in self.block_counts:
extent = stmt.value.value if hasattr(stmt.value,
'value') else stmt.value
extent = stmt.value.value if hasattr(stmt.value, "value") else stmt.value
self.block_counts[thread_tag] = extent
elif isinstance(stmt, tvm.tir.For):
# Push loop extent onto the stack
......@@ -178,9 +179,7 @@ class Analyzer:
"""
arch_key = device.compute_capability[:2]
if arch_key not in ARCH_CONFIGS:
logger.info(
f"Unsupported compute capability: {device.compute_capability}, theoretical peak tflops will be None"
)
logger.info(f"Unsupported compute capability: {device.compute_capability}, theoretical peak tflops will be None")
return None
cores_per_sm, default_clock, flops_per_cycle, compute_max_core = ARCH_CONFIGS[arch_key]
......@@ -203,7 +202,8 @@ class Analyzer:
total_global_bytes=self.total_global_bytes,
estimated_time=estimated_time,
expected_tflops=peak_tflops,
expected_bandwidth_GBps=bandwidth_GBps)
expected_bandwidth_GBps=bandwidth_GBps,
)
@classmethod
def analysis(cls, fn, device):
......
......@@ -2,12 +2,14 @@ from __future__ import annotations
import tilelang.language as T
def plot_layout(layout: T.Fragment,
def plot_layout(
layout: T.Fragment,
save_directory="./tmp",
name: str = "layout",
colormap: str = "RdPu",
verbose: bool = False,
formats: str | list[str] = "png") -> None:
formats: str | list[str] = "png",
) -> None:
"""
Plot the layout of a buffer.
......@@ -90,11 +92,13 @@ def plot_layout(layout: T.Fragment,
# Warn if the number of threads is less than the warp size
if num_threads < warp_size:
import warnings
warnings.warn(
f"Layout visualization has {num_threads} threads, which is less than the warp size ({warp_size}). "
f"For the best viewing experience, it is recommended to have at least {warp_size} threads.",
UserWarning,
stacklevel=2)
stacklevel=2,
)
spectral_camp = plt.get_cmap("hsv", warp_size * 6)
for i in range(min(warp_size, num_threads)):
......@@ -118,12 +122,7 @@ def plot_layout(layout: T.Fragment,
color = colors[thread_ids[0]] # Select color based on thread ID
# Create a rectangle patch for visualization
rect = patches.Rectangle((j, i),
1,
1,
linewidth=0.5,
edgecolor='black',
facecolor=color)
rect = patches.Rectangle((j, i), 1, 1, linewidth=0.5, edgecolor="black", facecolor=color)
ax.add_patch(rect) # Add the rectangle to the plot
# Add text annotations inside the rectangles
......@@ -139,41 +138,19 @@ def plot_layout(layout: T.Fragment,
thread_fontsize = min(font_size, font_size * (4 / len(thread_str)))
# Add thread ID text with adjusted font size
ax.text(
j + 0.5,
i + 0.3,
thread_str,
ha='center',
va='center',
color='black',
fontsize=thread_fontsize)
ax.text(j + 0.5, i + 0.3, thread_str, ha="center", va="center", color="black", fontsize=thread_fontsize)
# Add local ID text with original font size
ax.text(
j + 0.5,
i + 0.7,
f"L{local_id}",
ha='center',
va='center',
color='black',
fontsize=font_size)
ax.text(j + 0.5, i + 0.7, f"L{local_id}", ha="center", va="center", color="black", fontsize=font_size)
# Add row labels to the left side of the plot
for i in range(nrows):
text = f"row {i}"
ax.text(-0.75, i + 0.5, text, ha='center', va='center', color='black', fontsize=font_size)
ax.text(-0.75, i + 0.5, text, ha="center", va="center", color="black", fontsize=font_size)
# Add column labels at the top of the plot
for j in range(ncols):
text = f"col {j}"
ax.text(
j + 0.5,
-0.5,
text,
ha='center',
va='center',
color='black',
fontsize=font_size,
rotation=45)
ax.text(j + 0.5, -0.5, text, ha="center", va="center", color="black", fontsize=font_size, rotation=45)
# Set the plot limits
ax.set_xlim(0, ncols)
......@@ -189,17 +166,15 @@ def plot_layout(layout: T.Fragment,
legend_x = 1.0 + (0.5 / fig_width) # Adjust x position based on figure width
legend_y = 1.0 + (1.7 / fig_height) # Adjust y position based on figure height
legend_patches = [
patches.Patch(color='black', label="T: Thread ID"),
patches.Patch(color='black', label="L: Local ID")
]
legend_patches = [patches.Patch(color="black", label="T: Thread ID"), patches.Patch(color="black", label="L: Local ID")]
ax.legend(
handles=legend_patches,
loc="upper right",
fontsize=font_size - 4,
frameon=False,
bbox_to_anchor=(legend_x, legend_y), # Dynamic position
ncols=2)
ncols=2,
)
# Create the output directory if it does not exist
tmp_directory = pathlib.Path(save_directory)
......@@ -211,28 +186,29 @@ def plot_layout(layout: T.Fragment,
if isinstance(formats, str):
formats_str = formats.strip().lower()
if formats_str == 'all':
formats_list = ['pdf', 'png', 'svg']
if formats_str == "all":
formats_list = ["pdf", "png", "svg"]
elif "," in formats_str:
formats_list = [f.strip() for f in formats_str.split(',')]
formats_list = [f.strip() for f in formats_str.split(",")]
else:
formats_list = [formats_str]
else:
raise TypeError(f"Expected str, but got {type(formats).__name__}. "
f"Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'.")
raise TypeError(
f"Expected str, but got {type(formats).__name__}. Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'."
)
# Save the figure
if 'pdf' in formats_list:
if "pdf" in formats_list:
pdf_path = tmp_directory / f"{name}.pdf"
plt.savefig(pdf_path, bbox_inches="tight")
print(f"Saved pdf format into {pdf_path}")
if 'png' in formats_list:
if "png" in formats_list:
png_path = tmp_directory / f"{name}.png"
plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255)
print(f"Saved png format into {png_path}")
if 'svg' in formats_list:
if "svg" in formats_list:
svg_path = tmp_directory / f"{name}.svg"
plt.savefig(svg_path, bbox_inches="tight", format="svg")
print(f"Saved svg format into {svg_path}")
......@@ -110,8 +110,7 @@ def LowerHopperIntrin():
fpass : tvm.transform.Pass
The result pass
"""
return (_ffi_api.LowerHopperIntrin() if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f
) # type: ignore
return _ffi_api.LowerHopperIntrin() if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f # type: ignore
def WarpSpecializedPipeline():
......@@ -365,8 +364,7 @@ def FlattenBuffer():
def EliminateStorageSyncForMBarrier():
"""EliminateStorageSyncForMBarrier
"""
"""EliminateStorageSyncForMBarrier"""
return _ffi_api.EliminateStorageSyncForMBarrier() # type: ignore
......@@ -378,19 +376,16 @@ def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False, align_by
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge,
align_bytes) # type: ignore
return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge, align_bytes) # type: ignore
def LowerL2Persistent():
"""LowerL2Persistent
"""
"""LowerL2Persistent"""
return _ffi_api.LowerL2Persistent() # type: ignore
def PersistThreadblock():
"""PersistThreadblock
"""
"""PersistThreadblock"""
return _ffi_api.PersistThreadblock() # type: ignore
......@@ -409,8 +404,7 @@ def AlignDynamicSharedMemoryAllocations(align_bytes: int = 16):
def LowerSharedBarrier():
"""LowerSharedBarrier
"""
"""LowerSharedBarrier"""
return _ffi_api.LowerSharedBarrier() # type: ignore
......@@ -437,20 +431,17 @@ def StorageRewrite():
def LowerOpaqueBlock():
"""LowerOpaqueBlock
"""
"""LowerOpaqueBlock"""
return _ffi_api.LowerOpaqueBlock() # type: ignore
def LowerThreadAllreduce():
"""LowerThreadAllreduce
"""
"""LowerThreadAllreduce"""
return _ffi_api.LowerThreadAllreduce() # type: ignore
def LowerIntrin():
"""LowerIntrin
"""
"""LowerIntrin"""
return _ffi_api.LowerIntrin() # type: ignore
......@@ -468,8 +459,7 @@ def LowerDeviceKernelLaunch():
def LowerSharedTmem():
"""LowerSharedTmem
"""
"""LowerSharedTmem"""
return _ffi_api.LowerSharedTmem() # type: ignore
......
from tvm.tir import (BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm)
from tvm.tir import BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm
from tvm.tir.stmt_functor import ir_transform, post_order_visit
from tvm.tir.transform import prim_func_pass
......@@ -97,7 +97,7 @@ def AddWrapperForSingleBufStore():
Returns:
True if the loop is a tile operation (parallel or has num_stages annotation)
"""
return loop.kind == ForKind.PARALLEL or 'num_stages' in loop.annotations
return loop.kind == ForKind.PARALLEL or "num_stages" in loop.annotations
def pre_visit(statement):
"""
......@@ -105,7 +105,7 @@ def AddWrapperForSingleBufStore():
"""
nonlocal tile_operation_depth
if isinstance(statement, AttrStmt) and statement.attr_key == 'thread_extent':
if isinstance(statement, AttrStmt) and statement.attr_key == "thread_extent":
thread_binding_vars.add(statement.node.var)
elif isinstance(statement, For) and is_tile_operation_loop(statement):
tile_operation_depth += 1
......@@ -139,7 +139,8 @@ def AddWrapperForSingleBufStore():
if isinstance(index, IntImm) and index != 0:
raise ValueError(
f"Fragment buffer access with non-zero index [{index}] is not supported. "
"Only fragment[0] access is allowed.")
"Only fragment[0] access is allowed."
)
# Wrap fragment[0] access with T.Parallel loop
return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, statement)
......
......@@ -5,6 +5,7 @@ from enum import Enum
class PassConfigKey(str, Enum):
"""Pass configuration keys for TileLang compiler."""
# TileLang specific configs
TL_SIMPLIFY = "tl.Simplify"
"""Enable/disable TileLang simplification passes. Default: True"""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment