"...composable_kernel_onnxruntime.git" did not exist on "6790b8f3cc4fd32a9d9a43c6c9d80b826d969980"
Unverified commit 29051439, authored by Lei Wang and committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
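The migration is mechanical: yapf's column-aligned wrapping gives way to ruff's Black-style layout (double quotes, a 140-character line limit, magic trailing commas). A hypothetical before/after sketch of the kind of rewrite applied throughout the diff below (make_kernel is an illustrative name, not a function from this repository):

# Before (yapf, column_limit = 100): each argument aligned under the opening paren.
def make_kernel(M,
                N,
                K,
                dtype="float16"):
    return M * N * K

# After (ruff format, line-length = 140): the signature collapses onto one line.
def make_kernel(M, N, K, dtype="float16"):
    return M * N * K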
"""Reproduce: device_id mismatch (requires >=2 CUDA devices). """Reproduce: device_id mismatch (requires >=2 CUDA devices)."""
"""
import torch import torch
from common import build_matmul_kernel from common import build_matmul_kernel
......
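For reference, the same class of failure can be reproduced without the repo's common.build_matmul_kernel helper; a minimal standalone sketch, assuming at least two visible CUDA devices:

import torch

assert torch.cuda.device_count() >= 2, "requires >= 2 CUDA devices"
# Allocate the operands on two different devices, then feed them to one op.
a = torch.randn(256, 256, device="cuda:0", dtype=torch.float16)
b = torch.randn(256, 256, device="cuda:1", dtype=torch.float16)
torch.matmul(a, b)  # raises RuntimeError: expected all tensors to be on the same device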
...@@ -7,6 +7,7 @@ or a host-side non-NULL pointer check.
Note: Constructing a true DLTensor with NULL data in PyTorch is not typical; this script
demonstrates passing None, which still reproduces the intended class of failure.
"""
import torch
from common import build_matmul_kernel
......
"""Reproduce: scalar parameter type mismatch (int/bool). """Reproduce: scalar parameter type mismatch (int/bool)."""
"""
from common import build_scalar_check_kernel from common import build_scalar_check_kernel
......
...@@ -3,20 +3,12 @@ import tilelang.language as T
import torch
def make_matmul_prim(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...@@ -42,7 +34,6 @@ def build_matmul_kernel(M=1024, N=1024, K=1024, target="cuda"):
def build_scalar_check_kernel(target="cuda"):
    @T.prim_func
    def scalar_check(x: T.int32, flag: T.bool()):
        T.evaluate(0)
......
...@@ -37,7 +37,7 @@ OP_NAMES: Dict[int, str] = {
    6: "sqrt",
    7: "tanh",
    8: "rsqrt",
    9: "inv_sqrt",
}
# Block sizes for kernels
...@@ -49,8 +49,7 @@ TILELANG_THREADS = 128
def parse_arguments() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Precision comparison tool for various CUDA implementations")
    parser.add_argument("--n", type=int, default=1000000, help="Number of elements to test")
    parser.add_argument("--low", type=float, default=-4.0, help="Lower bound for random values")
    parser.add_argument("--high", type=float, default=4.0, help="Upper bound for random values")
...@@ -67,7 +66,7 @@ def initialize_cuda() -> torch.nn.Module:
    return load(
        name="cuda_ops",
        sources=["cuda_ops.cu"],
        extra_cuda_cflags=[],  # No fast_math flags
    )
...@@ -149,8 +148,7 @@ def triton_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, BLOCK_S
@triton.jit
def triton_libdevice_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """LibDevice Triton kernel for unary operations."""
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
...@@ -188,13 +186,10 @@ def make_tilelang_unary_kernel(M: int, N: int, op_id: int, use_fastmath: bool =
    @T.prim_func
    def tilelang_unary_kernel(
        A: T.Tensor((M, N), "float32"),
        B: T.Tensor((M, N), "float32"),
    ):
        with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by):
            for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
                row = by * TILELANG_BLOCK_M + i
                col = bx * TILELANG_BLOCK_N + j
...@@ -229,14 +224,11 @@ def make_tilelang_binary_kernel(M: int, N: int):
    @T.prim_func
    def tilelang_binary_kernel(
        A: T.Tensor((M, N), "float32"),
        B: T.Tensor((M, N), "float32"),
        C: T.Tensor((M, N), "float32"),
    ):
        with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by):
            for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
                row = by * TILELANG_BLOCK_M + i
                col = bx * TILELANG_BLOCK_N + j
...@@ -247,10 +239,7 @@ def make_tilelang_binary_kernel(M: int, N: int):
    return tilelang_binary_kernel
def tilelang_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None, use_fastmath: bool = False) -> torch.Tensor:
    """TileLang operation interface."""
    assert x.is_cuda
...@@ -272,7 +261,8 @@ def tilelang_op(x: torch.Tensor,
            target="cuda",
            pass_configs={
                tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: use_fastmath,
            },
        )
        out = kernel(x, y)
    else:  # Unary operation
        kernel_func = make_tilelang_unary_kernel(M, N, op_id, use_fastmath)
...@@ -282,7 +272,8 @@ def tilelang_op(x: torch.Tensor,
            target="cuda",
            pass_configs={
                tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: use_fastmath,
            },
        )
        out = kernel(x)
    # Restore original shape
...@@ -293,7 +284,7 @@ def triton_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) ->
    """Standard Triton operation interface."""
    assert x.is_cuda
    out = torch.empty_like(x)
    grid = lambda meta: ((x.numel() + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)
    if op_id == 0:  # Division - binary operation
        assert y is not None, "Division operation requires second operand"
...@@ -304,13 +295,11 @@ def triton_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) ->
    return out
def triton_libdevice_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> torch.Tensor:
    """LibDevice Triton operation interface."""
    assert x.is_cuda
    out = torch.empty_like(x)
    grid = lambda meta: ((x.numel() + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)
    if op_id == 0:  # Division - binary operation
        assert y is not None, "Division operation requires second operand"
...@@ -321,9 +310,7 @@ def triton_libdevice_op(x: torch.Tensor,
    return out
def get_pytorch_reference(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Get PyTorch reference implementation for the given operation."""
    if op_id == 0:
        assert y is not None, "Division requires second operand"
...@@ -362,8 +349,10 @@ def summarize_error(tag: str, output: Optional[torch.Tensor], reference: torch.T
    abs_err = (output_double - reference_double).abs()
    rel_err = abs_err / (reference_double.abs().clamp_min(1e-30))
    print(
        f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, "
        f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}"
    )
# Precision comparison function
...@@ -407,9 +396,7 @@ def compare(op_id: int, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> No
            results[name] = None
    # Print comparison header
    print(f"{'Implementation':<32} {'Max Abs Error':<19} {'Mean Abs Error':<20} {'Max Rel Error':<19} {'Mean Rel Error'}")
    print("-" * 90)
    # Compare all implementations against double precision reference
...@@ -427,8 +414,7 @@ def compare(op_id: int, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> No
        summarize_error(tag, output, ref_double)
def generate_test_data(op_id: int, n: int, device: torch.device, low: float, high: float) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Generate appropriate test data for each operation."""
    if op_id == 0:  # Division
        x = torch.empty(n, device=device).uniform_(low, high)
...@@ -450,9 +436,7 @@ def generate_test_data(op_id: int, n: int, device: torch.device, low: float,
def main() -> None:
    """Main execution function."""
    print("Precision comparison between CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang")
    print("=" * 90)
    for op_id in range(len(OP_NAMES)):
......
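For context, the error metric reported by summarize_error above can be reproduced standalone: cast both the result under test and a float64 reference to double, then take absolute and relative differences. A minimal CPU sketch (the script itself runs the same comparison on CUDA outputs):

import torch

x = torch.empty(1_000_000).uniform_(-4.0, 4.0)
out = torch.tanh(x)           # float32 result under test
ref = torch.tanh(x.double())  # float64 reference
abs_err = (out.double() - ref).abs()
rel_err = abs_err / ref.abs().clamp_min(1e-30)
print(f"max abs: {abs_err.max():.3e}, max rel: {rel_err.max():.3e}")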
...@@ -10,39 +10,32 @@ env["TILELANG_CLEAR_CACHE"] = "1"
def parse_output(output):
    data = {}
    for line in output.split("\n"):
        line = line.strip()
        if line.startswith("Latency:"):
            match = re.search(r"Latency: ([\d.]+)", line)
            data["latency"] = match.group(1) if match else "N/A"
        elif line.startswith("TFlops:"):
            match = re.search(r"TFlops: ([\d.]+)", line)
            data["best_tflops"] = match.group(1) if match else "N/A"
        elif line.startswith("Config:"):
            data["config"] = line.split("Config: ")[-1]
        elif line.startswith("Reference TFlops:"):
            match = re.search(r"Reference TFlops: ([\d.]+)", line)
            data["ref_tflops"] = match.group(1) if match else "N/A"
    return data
output_v1 = subprocess.run(["./tl/bin/python", "./maint/scripts/performance.py"], capture_output=True, text=True, env=env).stdout
data_v1 = parse_output(output_v1)
output_v2 = subprocess.run(["./tll/bin/python", "./maint/scripts/performance.py"], capture_output=True, text=True, env=env).stdout
data_v2 = parse_output(output_v2)
table = [
    ["original", data_v1["latency"], data_v1["best_tflops"], data_v1["ref_tflops"], data_v1["config"]],
    ["current", data_v2["latency"], data_v2["best_tflops"], data_v2["ref_tflops"], data_v2["config"]],
]
headers = ["version", "Best Latency (s)", "Best TFlops", "Reference TFlops", "Best Config"]
......
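As a quick sanity check, parse_output above can be exercised on a captured stdout sample; the sample text here is hypothetical, but the keys match the ones the parser produces:

sample = "Latency: 0.123\nTFlops: 45.6\nConfig: {'block_M': 128}\nReference TFlops: 44.9\n"
print(parse_output(sample))
# {'latency': '0.123', 'best_tflops': '45.6', 'config': "{'block_M': 128}", 'ref_tflops': '44.9'}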
...@@ -8,19 +8,20 @@ def ref_program(A, B):
def get_configs():
    configs = [
        {
            "block_M": 128,
            "block_N": 128,
            "block_K": 64,
            "num_stages": 2,
            "thread_num": 256,
            "enable_rasteration": True,  # keep param name for backward-compat
        }
    ]
    return configs
def run(M, N, K):
    def kernel(
        block_M=None,
        block_N=None,
...@@ -34,12 +35,11 @@ def run(M, N, K):
        @T.prim_func
        def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                B_shared = T.alloc_shared((block_N, block_K), dtype)
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -60,12 +60,16 @@ def run(M, N, K):
        return main
    autotuner = (
        AutoTuner.from_kernel(kernel=kernel, configs=get_configs())
        .set_compile_args(
            out_idx=[-1],
            target="auto",
        )
        .set_profile_args(
            ref_prog=ref_program,
        )
    )
    return autotuner.run(warmup=3, rep=20)
......
...@@ -122,10 +122,7 @@ tilelang = "tilelang"
"tilelang/3rdparty/composable_kernel/include" = "3rdparty/composable_kernel/include"
"tilelang/3rdparty/composable_kernel/library" = "3rdparty/composable_kernel/library"
-[tool.yapf]
-based_on_style = "yapf"
-column_limit = 100
-indent_width = 4
[tool.codespell]
ignore-words = "docs/spelling_wordlist.txt"
...@@ -138,7 +135,7 @@ skip = [
[tool.ruff]
target-version = "py39"
-line-length = 100
+line-length = 140
output-format = "full"
exclude = [
...@@ -146,6 +143,14 @@ exclude = [
    "examples/deepseek_v32/inference",
]
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+skip-magic-trailing-comma = false
+line-ending = "auto"
+docstring-code-format = false
+docstring-code-line-length = "dynamic"
[tool.ruff.lint.per-file-ignores]
# Do not upgrade type hint in testing and examples.
# See https://github.com/tile-ai/tilelang/issues/1079 for more information.
......
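The skip-magic-trailing-comma = false setting is why many call sites in this diff stay expanded even though they would now fit on one line: a trailing comma in a call or literal tells ruff format to keep one element per line, while the same code without the comma is collapsed. A small illustration with hypothetical code, assuming default ruff format semantics:

# No trailing comma: collapsed onto one line (it fits within line-length = 140).
config = {"block_M": 128, "block_N": 128, "block_K": 64}

# Magic trailing comma: kept expanded, one item per line.
pass_configs = {
    "TL_ENABLE_FAST_MATH": False,
}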
...@@ -4,4 +4,3 @@ clang-format==21.1.2
clang-tidy==21.1.1
codespell[toml]==2.4.1
ruff==0.14.3
-yapf==0.43.0
...@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
        "warnings",
        "error",
    }
    if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0:
        terminalreporter.write_sep(
            "!",
            (f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
        )
        pytest.exit("No tests were collected.", returncode=5)
...@@ -4,7 +4,8 @@ from tilelang import tvm as tvm
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mfma_macro_generator import (
    MatrixCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
tilelang.testing.set_random_seed(0)
...@@ -22,7 +23,6 @@ def tl_matmul(
    b_transposed=True,
    k_pack=1,
):
    micro_size_x = micro_size_y = micro_size_k = 16
    if in_dtype in {"float8_e4m3fnuz", "int8"}:
...@@ -78,12 +78,11 @@ def tl_matmul(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
...@@ -91,10 +90,12 @@ def tl_matmul(
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
            T.annotate_layout(
                {
                    A_shared: make_swizzle_layout(A_shared),
                    B_shared: make_swizzle_layout(B_shared),
                }
            )
            # Improve L2 Cache
            T.use_swizzle(panel_size=10)
...@@ -102,7 +103,6 @@ def tl_matmul(
            T.clear(C_local)
            for ko in T.Pipelined((K // block_K), num_stages=0):
                # Load A into shared memory
                if a_transposed:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
...@@ -116,7 +116,6 @@ def tl_matmul(
                    T.copy(B[ko * block_K, bx * block_N], B_shared)
                for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):
                    # Load A into fragment
                    mfma_emitter.ldmatrix_a(
                        A_local,
...@@ -160,17 +159,8 @@ def tl_matmul(
    return main
def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32", a_transposed=False, b_transposed=True, k_pack=1):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack)
    print(matmul)
    kernel = tilelang.compile(matmul)
    src_code = kernel.get_kernel_source()
...@@ -201,16 +191,13 @@ def assert_tl_matmul_correctness(M,
    if a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
    elif not a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    else:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
...@@ -228,16 +215,13 @@ def test_assert_tl_matmul():
    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
    assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16")
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32")
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2)
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False)
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2)
if __name__ == "__main__":
......
...@@ -23,7 +23,6 @@ def tl_matmul(
    b_preshuffle=False,
    b_g2l_load=False,
):
    micro_size_x = micro_size_y = micro_size_k = 16
    if in_dtype in {"float8_e4m3fnuz", "int8"}:
...@@ -53,18 +52,21 @@ def tl_matmul(
    A_shape = (K, M) if a_transposed else (M, K)
    if b_preshuffle:
        B_shape = (
            (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k)
            if b_transposed
            else (K // pack_size_k, N // micro_size_y, pack_size_k, micro_size_y)
        )
    else:
        B_shape = (N, K) if b_transposed else (K, N)
    A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
    if b_preshuffle:
        B_shared_shape = (
            (block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k)
            if b_transposed
            else (block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y)
        )
    else:
        B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
...@@ -94,21 +96,22 @@ def tl_matmul(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
            T.annotate_layout(
                {
                    A_shared: make_swizzle_layout(A_shared),
                }
            )
            num_ko = K // block_K
            num_ki = block_K // (k_pack * micro_size_k)
...@@ -119,7 +122,6 @@ def tl_matmul(
            T.clear(C_local)
            for ko in T.Pipelined(num_ko, num_stages=0):
                # Load A into shared memory
                if a_transposed:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
...@@ -129,20 +131,13 @@ def tl_matmul(
                # Load B into shared memory
                if b_g2l_load is False:
                    if b_transposed:
                        for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k):
                            B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j, ko * block_K // pack_size_k + k, jj, kk]
                    else:
                        for k, j, kk, jj in T.Parallel(block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y):
                            B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k, bx * block_N // micro_size_y + j, kk, jj]
                for ki in T.serial(0, num_ki):
                    # Load A S2L
                    mfma_emitter.ldmatrix_a(
                        A_local,
...@@ -176,10 +171,10 @@ def tl_matmul(
def shuffle_weight(
    x: torch.Tensor,
    layout=(16, 32),
    k_pack=1,
    is_transpose=False,
) -> torch.Tensor:
    IN, IK = layout
    BK = IK * k_pack
...@@ -194,19 +189,20 @@ def shuffle_weight(
    return x.contiguous()
def assert_tl_matmul_correctness(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype="float32",
    a_transposed=False,
    b_transposed=True,
    k_pack=1,
    b_preshuffle=False,
    b_g2l_load=False,
):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load)
    print(matmul)
    kernel = tilelang.compile(matmul)
    src_code = kernel.get_kernel_source()
...@@ -244,16 +240,13 @@ def assert_tl_matmul_correctness(M,
    if a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
    elif not a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    else:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
...@@ -266,40 +259,17 @@ def assert_tl_matmul_correctness(M,
@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2, b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_transposed=False, b_preshuffle=True)
if __name__ == "__main__":
......
...@@ -27,8 +27,7 @@ def matmul(
    vec_size = 4 * k_pack
    @T.prim_func
    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
...@@ -111,8 +110,7 @@ def test_gemm_bf16f32f32_nt():
    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=2)
@tilelang.testing.requires_rocm
...@@ -121,8 +119,7 @@ def test_gemm_bf16bf16f32():
    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=2)
def matmul_rs(
...@@ -149,9 +146,9 @@ def matmul_rs(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......
...@@ -5,14 +5,12 @@ import pytest
@tilelang.jit
def simple_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")
    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)
...@@ -28,14 +26,12 @@ def simple_invalid_loop(dtype: str = "bfloat16",
@tilelang.jit
def nested_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")
    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)
...@@ -52,14 +48,12 @@ def nested_invalid_loop(dtype: str = "bfloat16",
@tilelang.jit
def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")
    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)
...@@ -75,14 +69,12 @@ def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16",
@tilelang.jit
def valid_loop_not_use_loop_var(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")
    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)
...@@ -99,14 +91,12 @@ def valid_loop_not_use_loop_var(dtype: str = "bfloat16",
@tilelang.jit
def valid_loop_not_frag(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")
    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_shared = T.alloc_shared([128], accum_dtype)
...@@ -122,14 +112,12 @@ def valid_loop_not_frag(dtype: str = "bfloat16",
@tilelang.jit
def valid_loop_serial(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")
    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_shared = T.alloc_shared([128], accum_dtype)
......
...@@ -30,11 +30,10 @@ Rule:
@tilelang.jit(out_idx=[1])
def nested_continuous_parallels(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
...@@ -46,29 +45,26 @@ def nested_continuous_parallels(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_triple_continuous_parallels(length=256, block1=8, block2=2, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block1 // block2):
                for j in T.Parallel(block1):
                    for k in T.Parallel(block2):
                        B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0
    return main
@tilelang.jit(out_idx=[1])
def nested_noncontinuous_parallels(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
...@@ -103,8 +99,9 @@ is OK.
"""
def matmul_nested_pipelines(
    M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, threads, order, stage, extra_pipeline_repeats
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
...@@ -114,9 +111,9 @@ def matmul_nested_pipelines(M, N, K, block_M, block_N, block_K, trans_A, trans_B
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -180,7 +177,8 @@ def run_gemm_nested_pipelines(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = kernel.get_profiler()
    def ref_program(A, B):
...@@ -193,8 +191,8 @@ def run_gemm_nested_pipelines(
        if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
            # float32 automatically, -0x1000 means
            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C
...@@ -218,11 +216,10 @@ is OK.
@tilelang.jit(out_idx=[1])
def nested_continuous_serials(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block):
...@@ -234,11 +231,10 @@ def nested_continuous_serials(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_noncontinuous_serials(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block):
...@@ -277,11 +273,10 @@ Rule:
@tilelang.jit(out_idx=[1])
def nested_continuous_sp(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block):
...@@ -293,11 +288,10 @@ def nested_continuous_sp(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_continuous_ps(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
...@@ -309,36 +303,32 @@ def nested_continuous_ps(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_continuous_psp(length=256, block1=8, block2=2, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block1 // block2):
                for j in T.serial(block1):
                    for k in T.Parallel(block2):
                        B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0
    return main
@tilelang.jit(out_idx=[1])
def nested_continuous_sps(length=256, block1=8, block2=2, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block1 // block2):
                for j in T.Parallel(block1):
                    for k in T.serial(block2):
                        B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0
    return main
...@@ -399,9 +389,9 @@ def matmul_nested_pipa(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -444,9 +434,9 @@ def matmul_nested_papipa(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -505,7 +495,8 @@ def run_gemm_mixed_pp(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = kernel.get_profiler()
    def ref_program(A, B):
...@@ -514,8 +505,8 @@ def run_gemm_mixed_pp(
        if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
            # float32 automatically, -0x1000 means
            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype)) C = C.to(torch.__getattribute__(out_dtype))
return C return C
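The -0x1000 adjustment above works on the raw float32 bit pattern: the tensor is reinterpreted as int32, 0x1000 is subtracted, and the result is reinterpreted back, nudging each value by a fraction of a tf32 ulp so the float32 reference is closer to what the tf32 mma path computes. A standalone sketch of just that reinterpretation (CPU-only torch, illustrative values, not part of the test):

import torch

x = torch.rand(4, dtype=torch.float32) + 0.5       # strictly positive, avoids the 0.0 edge case
x_adj = (x.view(torch.int32) - 0x1000).view(torch.float32)
# the bit-level adjustment changes each value by less than ~5e-4 relative
assert torch.all((x - x_adj).abs() / x < 1e-3)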
...@@ -543,7 +534,8 @@ def run_gemm_mixed_pp( ...@@ -543,7 +534,8 @@ def run_gemm_mixed_pp(
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}) },
)
def test_mixed_pp(): def test_mixed_pp():
...@@ -576,9 +568,9 @@ def matmul_with_parallel( ...@@ -576,9 +568,9 @@ def matmul_with_parallel(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -637,7 +629,8 @@ def run_gemm_tiled_op_with_parallel( ...@@ -637,7 +629,8 @@ def run_gemm_tiled_op_with_parallel(
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}) },
)
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
def ref_program(A, B): def ref_program(A, B):
...@@ -646,8 +639,8 @@ def run_gemm_tiled_op_with_parallel( ...@@ -646,8 +639,8 @@ def run_gemm_tiled_op_with_parallel(
if in_dtype == "float32": if in_dtype == "float32":
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically; -0x1000 shifts the raw float32 bits by half a tf32 ulp # float32 automatically; -0x1000 shifts the raw float32 bits by half a tf32 ulp
A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype)) C = C.to(torch.__getattribute__(out_dtype))
return C return C
...@@ -675,16 +668,16 @@ def run_gemm_tiled_op_with_parallel( ...@@ -675,16 +668,16 @@ def run_gemm_tiled_op_with_parallel(
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}) },
)
@tilelang.jit(out_idx=[1]) @tilelang.jit(out_idx=[1])
def tir_op_with_parallel(length=256, block=16, dtype="float32"): def tir_op_with_parallel(length=256, block=16, dtype="float32"):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((length,), dtype), A: T.Tensor((length,), dtype),
B: T.Tensor((length,), dtype), B: T.Tensor((length,), dtype),
): ):
with T.Kernel(1, threads=length) as _: with T.Kernel(1, threads=length) as _:
for i in T.Parallel(length // block): for i in T.Parallel(length // block):
...@@ -696,11 +689,10 @@ def tir_op_with_parallel(length=256, block=16, dtype="float32"): ...@@ -696,11 +689,10 @@ def tir_op_with_parallel(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1]) @tilelang.jit(out_idx=[1])
def customize_op_with_parallel(length=256, block=16, dtype="float32"): def customize_op_with_parallel(length=256, block=16, dtype="float32"):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((length,), dtype), A: T.Tensor((length,), dtype),
B: T.Tensor((length,), dtype), B: T.Tensor((length,), dtype),
): ):
with T.Kernel(1, threads=length) as _: with T.Kernel(1, threads=length) as _:
for i in T.Parallel(length // block): for i in T.Parallel(length // block):
......
...@@ -48,6 +48,7 @@ def get_configs(M, N, K, with_roller=False): ...@@ -48,6 +48,7 @@ def get_configs(M, N, K, with_roller=False):
from tilelang.carver.template import MatmulTemplate from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization from tilelang.carver.roller.rasterization import NoRasterization
arch = CUDA("cuda") arch = CUDA("cuda")
topk = 20 topk = 20
...@@ -84,7 +85,6 @@ def get_configs(M, N, K, with_roller=False): ...@@ -84,7 +85,6 @@ def get_configs(M, N, K, with_roller=False):
for config in configs: for config in configs:
print(config) print(config)
else: else:
block_M = [64] block_M = [64]
block_N = [64] block_N = [64]
block_K = [32] block_K = [32]
...@@ -100,7 +100,8 @@ def get_configs(M, N, K, with_roller=False): ...@@ -100,7 +100,8 @@ def get_configs(M, N, K, with_roller=False):
num_stages, num_stages,
thread_num, thread_num,
enable_rasterization, enable_rasterization,
)) )
)
configs = [ configs = [
{ {
...@@ -110,7 +111,8 @@ def get_configs(M, N, K, with_roller=False): ...@@ -110,7 +111,8 @@ def get_configs(M, N, K, with_roller=False):
"num_stages": c[3], "num_stages": c[3],
"thread_num": c[4], "thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat "enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs }
for c in _configs
] ]
return configs return configs
...@@ -190,9 +192,9 @@ def matmul(M, N, K, with_roller): ...@@ -190,9 +192,9 @@ def matmul(M, N, K, with_roller):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
""" """
The compiled TVM function for block-level matrix multiplication. The compiled TVM function for block-level matrix multiplication.
...@@ -206,9 +208,7 @@ def matmul(M, N, K, with_roller): ...@@ -206,9 +208,7 @@ def matmul(M, N, K, with_roller):
""" """
# Bind x-dimension to block index in N, # Bind x-dimension to block index in N,
# y-dimension to block index in M. # y-dimension to block index in M.
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K) # Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K) # Allocate shared memory for B sub-block of shape (block_N, block_K)
...@@ -247,12 +247,16 @@ def matmul(M, N, K, with_roller): ...@@ -247,12 +247,16 @@ def matmul(M, N, K, with_roller):
return main return main
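As the docstring notes, bx walks output tiles along N and by along M, so each block writes a disjoint block_M x block_N tile of C. A plain-Python sketch of that mapping with illustrative sizes (not the tuned values):

M, N, block_M, block_N = 256, 512, 64, 64
grid_x, grid_y = -(-N // block_N), -(-M // block_M)   # same rounding as T.ceildiv
tiles = [
    (by * block_M, bx * block_N)   # top-left corner of the C tile owned by block (bx, by)
    for by in range(grid_y)
    for bx in range(grid_x)
]
assert len(set(tiles)) == grid_x * grid_y   # every block owns a distinct tile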
autotuner = AutoTuner.from_kernel( autotuner = (
kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args( AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller))
.set_compile_args(
out_idx=[-1], out_idx=[-1],
target="auto", target="auto",
).set_profile_args( )
ref_prog=ref_program,) .set_profile_args(
ref_prog=ref_program,
)
)
return autotuner.run(warmup=3, rep=20) return autotuner.run(warmup=3, rep=20)
......
...@@ -30,38 +30,23 @@ def ref_program(A, B): ...@@ -30,38 +30,23 @@ def ref_program(A, B):
def get_configs(): def get_configs():
iter_params = dict( iter_params = dict(block_M=[64], block_N=[64], block_K=[32], num_stages=[0, 1], thread_num=[128], enable_rasterization=[False])
block_M=[64], return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
block_N=[64],
block_K=[32],
num_stages=[0, 1],
thread_num=[128],
enable_rasterization=[False])
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
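The reformatted single-line get_configs is equivalent to the old multi-line version: itertools.product over the per-key value lists yields one tuple per combination, and zipping it back against the keys produces one config dict each. A shortened illustration with fewer keys than the real search space:

import itertools

iter_params = dict(block_M=[64], block_N=[64], num_stages=[0, 1])
configs = [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
# configs == [{'block_M': 64, 'block_N': 64, 'num_stages': 0},
#             {'block_M': 64, 'block_N': 64, 'num_stages': 1}]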
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(out_idx=[-1])
def matmul(M,
N,
K,
block_M=128,
block_N=128,
block_K=32,
num_stages=0,
thread_num=128,
enable_rasterization=False):
@tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M=128, block_N=128, block_K=32, num_stages=0, thread_num=128, enable_rasterization=False):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
""" """
The compiled TVM function for block-level matrix multiplication. The compiled TVM function for block-level matrix multiplication.
...@@ -76,7 +61,6 @@ def matmul(M, ...@@ -76,7 +61,6 @@ def matmul(M,
# Bind x-dimension to block index in N, # Bind x-dimension to block index in N,
# y-dimension to block index in M. # y-dimension to block index in M.
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K) # Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K) # Allocate shared memory for B sub-block of shape (block_N, block_K)
......
...@@ -28,9 +28,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo ...@@ -28,9 +28,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
...@@ -63,6 +63,7 @@ def run_cache_matmul(): ...@@ -63,6 +63,7 @@ def run_cache_matmul():
Reference PyTorch matrix multiplication for comparison. Reference PyTorch matrix multiplication for comparison.
""" """
import torch import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.half) # Assuming dtype="float16" in matmul C = C.to(torch.half) # Assuming dtype="float16" in matmul
return C return C
......
...@@ -29,9 +29,7 @@ class _cudaDeviceAttrNames: ...@@ -29,9 +29,7 @@ class _cudaDeviceAttrNames:
def test_driver_get_device_properties(): def test_driver_get_device_properties():
prop = get_cuda_device_properties() prop = get_cuda_device_properties()
assert prop is not None, "Failed to get CUDA device properties" assert prop is not None, "Failed to get CUDA device properties"
assert isinstance( assert isinstance(prop, torch.cuda._CudaDeviceProperties), "Returned object is not of type _CudaDeviceProperties"
prop,
torch.cuda._CudaDeviceProperties), ("Returned object is not of type _CudaDeviceProperties")
def test_device_get_device_name(): def test_device_get_device_name():
...@@ -48,8 +46,7 @@ def test_device_get_shared_memory_per_block(): ...@@ -48,8 +46,7 @@ def test_device_get_shared_memory_per_block():
def test_device_get_persisting_l2_cache_size(): def test_device_get_persisting_l2_cache_size():
tl_cache_size = get_persisting_l2_cache_max_size() tl_cache_size = get_persisting_l2_cache_max_size()
driver_cache_size = get_device_attribute( driver_cache_size = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize)
_cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize)
assert tl_cache_size == driver_cache_size, "Persisting L2 cache size values do not match" assert tl_cache_size == driver_cache_size, "Persisting L2 cache size values do not match"
...@@ -61,17 +58,14 @@ def test_device_get_num_sms(): ...@@ -61,17 +58,14 @@ def test_device_get_num_sms():
def test_device_get_registers_per_block(): def test_device_get_registers_per_block():
tl_regs_per_block = get_registers_per_block() tl_regs_per_block = get_registers_per_block()
driver_regs_per_block = get_device_attribute( driver_regs_per_block = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock)
_cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock)
assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match" assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match"
def test_device_get_max_dynamic_shared_size_bytes(): def test_device_get_max_dynamic_shared_size_bytes():
tl_dynamic_smem = get_max_dynamic_shared_size_bytes() tl_dynamic_smem = get_max_dynamic_shared_size_bytes()
driver_dynamic_smem = get_device_attribute( driver_dynamic_smem = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor)
_cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor) assert tl_dynamic_smem == driver_dynamic_smem, "Max dynamic shared size bytes values do not match"
assert tl_dynamic_smem == driver_dynamic_smem, (
"Max dynamic shared size bytes values do not match")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -9,16 +9,13 @@ def run_general_matmul_emit_configs(M, N, K, topk: int = 20): ...@@ -9,16 +9,13 @@ def run_general_matmul_emit_configs(M, N, K, topk: int = 20):
arch = auto_infer_current_arch() arch = auto_infer_current_arch()
def gemm(M, N, K): def gemm(M, N, K):
A = te.placeholder((M, K), name='A', dtype='float16') A = te.placeholder((M, K), name="A", dtype="float16")
B = te.placeholder((N, K), name='B', dtype='float16') B = te.placeholder((N, K), name="B", dtype="float16")
# Describe the matrix multiplication in TE # Describe the matrix multiplication in TE
k = te.reduce_axis((0, K), name='k') k = te.reduce_axis((0, K), name="k")
C = te.compute( C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C")
(M, N),
lambda i, j: te.sum(A[i, k].astype('float16') * B[j, k].astype('float16'), axis=[k]),
name='C')
return A, B, C return A, B, C
...@@ -29,8 +26,7 @@ def run_general_matmul_emit_configs(M, N, K, topk: int = 20): ...@@ -29,8 +26,7 @@ def run_general_matmul_emit_configs(M, N, K, topk: int = 20):
tensorized_func, tags = carver.utils.get_tensorized_func_and_tags(func, arch.target) tensorized_func, tags = carver.utils.get_tensorized_func_and_tags(func, arch.target)
print(tags) print(tags)
policy = carver.TensorCorePolicy.from_prim_func( policy = carver.TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags, name="matmul_0")
func=tensorized_func, arch=arch, tags=tags, name="matmul_0")
hints = policy.emit_config(topk=topk) hints = policy.emit_config(topk=topk)
...@@ -59,16 +55,13 @@ def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20): ...@@ -59,16 +55,13 @@ def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20):
arch = auto_infer_current_arch() arch = auto_infer_current_arch()
def gemm(M, N, K): def gemm(M, N, K):
A = te.placeholder((M, K), name='A', dtype='float16') A = te.placeholder((M, K), name="A", dtype="float16")
B = te.placeholder((N, K), name='B', dtype='float16') B = te.placeholder((N, K), name="B", dtype="float16")
# Describe the matrix multiplication in TE # Describe the matrix multiplication in TE
k = te.reduce_axis((0, K), name='k') k = te.reduce_axis((0, K), name="k")
C = te.compute( C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C")
(M, N),
lambda i, j: te.sum(A[i, k].astype('float16') * B[j, k].astype('float16'), axis=[k]),
name='C')
return A, B, C return A, B, C
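The elided lines presumably lower this TE description to a PrimFunc before the carver policy consumes it; a hedged sketch of that step using TVM's standard te.create_prim_func helper (sizes are illustrative, and this is not the omitted code itself):

from tvm import te

A = te.placeholder((1024, 1024), name="A", dtype="float16")
B = te.placeholder((1024, 1024), name="B", dtype="float16")
k = te.reduce_axis((0, 1024), name="k")
C = te.compute((1024, 1024), lambda i, j: te.sum(A[i, k] * B[j, k], axis=[k]), name="C")
func = te.create_prim_func([A, B, C])   # TE expression -> TIR PrimFunc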
......