Unverified Commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phase out Yapf format and embrace ruff format (#1417)

parent e84b24bc
"""Reproduce: device_id mismatch (requires >=2 CUDA devices).
"""
"""Reproduce: device_id mismatch (requires >=2 CUDA devices)."""
import torch
from common import build_matmul_kernel
......
......@@ -7,6 +7,7 @@ or a host-side non-NULL pointer check.
Note: Constructing a true DLTensor with NULL data in PyTorch is not typical; this script
demonstrates passing None, which still reproduces the intended class of failure.
"""
import torch
from common import build_matmul_kernel
......
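For context, the failure mode the docstring above describes can be sketched as follows. This is a minimal, hypothetical repro assuming build_matmul_kernel from common.py compiles the matmul kernel shown later in this diff and that the compiled kernel is called with its input tensors; the exact exception type and message depend on the tilelang version.
import torch
from common import build_matmul_kernel  # assumed repro helper, as imported above
kernel = build_matmul_kernel(M=1024, N=1024, K=1024, target="cuda")
A = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
try:
    # None stands in for a DLTensor whose data pointer is NULL; the host-side
    # argument check is expected to reject it before the kernel launches.
    kernel(A, None)
except (TypeError, RuntimeError) as err:
    print(f"rejected as expected: {err}")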
"""Reproduce: scalar parameter type mismatch (int/bool).
"""
"""Reproduce: scalar parameter type mismatch (int/bool)."""
from common import build_scalar_check_kernel
......
......@@ -3,15 +3,7 @@ import tilelang.language as T
import torch
def make_matmul_prim(M,
N,
K,
block_M=128,
block_N=128,
block_K=32,
dtype="float16",
accum_dtype="float"):
def make_matmul_prim(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......@@ -42,7 +34,6 @@ def build_matmul_kernel(M=1024, N=1024, K=1024, target="cuda"):
def build_scalar_check_kernel(target="cuda"):
@T.prim_func
def scalar_check(x: T.int32, flag: T.bool()):
T.evaluate(0)
......
......@@ -37,7 +37,7 @@ OP_NAMES: Dict[int, str] = {
6: "sqrt",
7: "tanh",
8: "rsqrt",
9: "inv_sqrt"
9: "inv_sqrt",
}
# Block sizes for kernels
......@@ -49,8 +49,7 @@ TILELANG_THREADS = 128
def parse_arguments() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Precision comparison tool for various CUDA implementations")
parser = argparse.ArgumentParser(description="Precision comparison tool for various CUDA implementations")
parser.add_argument("--n", type=int, default=1000000, help="Number of elements to test")
parser.add_argument("--low", type=float, default=-4.0, help="Lower bound for random values")
parser.add_argument("--high", type=float, default=4.0, help="Upper bound for random values")
......@@ -67,7 +66,7 @@ def initialize_cuda() -> torch.nn.Module:
return load(
name="cuda_ops",
sources=["cuda_ops.cu"],
extra_cuda_cflags=[] # No fast_math flags
extra_cuda_cflags=[], # No fast_math flags
)
......@@ -149,8 +148,7 @@ def triton_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, BLOCK_S
@triton.jit
def triton_libdevice_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr,
BLOCK_SIZE: tl.constexpr):
def triton_libdevice_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, BLOCK_SIZE: tl.constexpr):
"""LibDevice Triton kernel for unary operations."""
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
......@@ -191,10 +189,7 @@ def make_tilelang_unary_kernel(M: int, N: int, op_id: int, use_fastmath: bool =
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
):
with T.Kernel(
T.ceildiv(N, TILELANG_BLOCK_N),
T.ceildiv(M, TILELANG_BLOCK_M),
threads=TILELANG_THREADS) as (bx, by):
with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by):
for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
row = by * TILELANG_BLOCK_M + i
col = bx * TILELANG_BLOCK_N + j
......@@ -233,10 +228,7 @@ def make_tilelang_binary_kernel(M: int, N: int):
B: T.Tensor((M, N), "float32"),
C: T.Tensor((M, N), "float32"),
):
with T.Kernel(
T.ceildiv(N, TILELANG_BLOCK_N),
T.ceildiv(M, TILELANG_BLOCK_M),
threads=TILELANG_THREADS) as (bx, by):
with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by):
for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
row = by * TILELANG_BLOCK_M + i
col = bx * TILELANG_BLOCK_N + j
......@@ -247,10 +239,7 @@ def make_tilelang_binary_kernel(M: int, N: int):
return tilelang_binary_kernel
def tilelang_op(x: torch.Tensor,
op_id: int,
y: Optional[torch.Tensor] = None,
use_fastmath: bool = False) -> torch.Tensor:
def tilelang_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None, use_fastmath: bool = False) -> torch.Tensor:
"""TileLang operation interface."""
assert x.is_cuda
......@@ -272,7 +261,8 @@ def tilelang_op(x: torch.Tensor,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: use_fastmath,
})
},
)
out = kernel(x, y)
else: # Unary operation
kernel_func = make_tilelang_unary_kernel(M, N, op_id, use_fastmath)
......@@ -282,7 +272,8 @@ def tilelang_op(x: torch.Tensor,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: use_fastmath,
})
},
)
out = kernel(x)
# Restore original shape
......@@ -293,7 +284,7 @@ def triton_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) ->
"""Standard Triton operation interface."""
assert x.is_cuda
out = torch.empty_like(x)
grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
grid = lambda meta: ((x.numel() + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)
if op_id == 0: # Division - binary operation
assert y is not None, "Division operation requires second operand"
......@@ -304,13 +295,11 @@ def triton_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) ->
return out
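The grid lambda above is a plain ceiling division: it computes how many BLOCK_SIZE-wide program instances are needed to cover all elements. A quick standalone check, with numbers chosen only for illustration:
n, BLOCK_SIZE = 1_000_000, 1024
num_blocks = (n + BLOCK_SIZE - 1) // BLOCK_SIZE  # ceil(n / BLOCK_SIZE) without floats
print(num_blocks)  # 977, since 976 * 1024 = 999424 < 1_000_000 <= 977 * 1024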
def triton_libdevice_op(x: torch.Tensor,
op_id: int,
y: Optional[torch.Tensor] = None) -> torch.Tensor:
def triton_libdevice_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> torch.Tensor:
"""LibDevice Triton operation interface."""
assert x.is_cuda
out = torch.empty_like(x)
grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
grid = lambda meta: ((x.numel() + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)
if op_id == 0: # Division - binary operation
assert y is not None, "Division operation requires second operand"
......@@ -321,9 +310,7 @@ def triton_libdevice_op(x: torch.Tensor,
return out
def get_pytorch_reference(x: torch.Tensor,
op_id: int,
y: Optional[torch.Tensor] = None) -> torch.Tensor:
def get_pytorch_reference(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Get PyTorch reference implementation for the given operation."""
if op_id == 0:
assert y is not None, "Division requires second operand"
......@@ -362,8 +349,10 @@ def summarize_error(tag: str, output: Optional[torch.Tensor], reference: torch.T
abs_err = (output_double - reference_double).abs()
rel_err = abs_err / (reference_double.abs().clamp_min(1e-30))
print(f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, "
f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}")
print(
f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, "
f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}"
)
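As a small worked example of the error summary above (values invented; clamp_min(1e-30) keeps the relative error finite when the reference is exactly zero):
import torch
output = torch.tensor([1.0, 2.0, 0.0], dtype=torch.float64)
reference = torch.tensor([1.1, 2.0, 0.0], dtype=torch.float64)
abs_err = (output - reference).abs()                  # ~[0.1, 0.0, 0.0]
rel_err = abs_err / reference.abs().clamp_min(1e-30)  # ~[0.0909, 0.0, 0.0]
print(f"max abs: {abs_err.max():.3e}, max rel: {rel_err.max():.3e}")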
# Precision comparison function
......@@ -407,9 +396,7 @@ def compare(op_id: int, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> No
results[name] = None
# Print comparison header
print(
f"{'Implementation':<32} {'Max Abs Error':<19} {'Mean Abs Error':<20} {'Max Rel Error':<19} {'Mean Rel Error'}"
)
print(f"{'Implementation':<32} {'Max Abs Error':<19} {'Mean Abs Error':<20} {'Max Rel Error':<19} {'Mean Rel Error'}")
print("-" * 90)
# Compare all implementations against double precision reference
......@@ -427,8 +414,7 @@ def compare(op_id: int, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> No
summarize_error(tag, output, ref_double)
def generate_test_data(op_id: int, n: int, device: torch.device, low: float,
high: float) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
def generate_test_data(op_id: int, n: int, device: torch.device, low: float, high: float) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Generate appropriate test data for each operation."""
if op_id == 0: # Division
x = torch.empty(n, device=device).uniform_(low, high)
......@@ -450,9 +436,7 @@ def generate_test_data(op_id: int, n: int, device: torch.device, low: float,
def main() -> None:
"""Main execution function."""
print(
"Precision comparison between CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang"
)
print("Precision comparison between CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang")
print("=" * 90)
for op_id in range(len(OP_NAMES)):
......
......@@ -10,39 +10,32 @@ env["TILELANG_CLEAR_CACHE"] = "1"
def parse_output(output):
data = {}
for line in output.split('\n'):
for line in output.split("\n"):
line = line.strip()
if line.startswith('Latency:'):
match = re.search(r'Latency: ([\d.]+)', line)
data['latency'] = match.group(1) if match else 'N/A'
elif line.startswith('TFlops:'):
match = re.search(r'TFlops: ([\d.]+)', line)
data['best_tflops'] = match.group(1) if match else 'N/A'
elif line.startswith('Config:'):
data['config'] = line.split('Config: ')[-1]
elif line.startswith('Reference TFlops:'):
match = re.search(r'Reference TFlops: ([\d.]+)', line)
data['ref_tflops'] = match.group(1) if match else 'N/A'
if line.startswith("Latency:"):
match = re.search(r"Latency: ([\d.]+)", line)
data["latency"] = match.group(1) if match else "N/A"
elif line.startswith("TFlops:"):
match = re.search(r"TFlops: ([\d.]+)", line)
data["best_tflops"] = match.group(1) if match else "N/A"
elif line.startswith("Config:"):
data["config"] = line.split("Config: ")[-1]
elif line.startswith("Reference TFlops:"):
match = re.search(r"Reference TFlops: ([\d.]+)", line)
data["ref_tflops"] = match.group(1) if match else "N/A"
return data
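For illustration, parse_output turns the performance script's stdout into a small dict; the sample text below is invented purely to exercise each branch:
sample = "Latency: 1.234\nTFlops: 98.7\nConfig: {'block_M': 128}\nReference TFlops: 95.0"
print(parse_output(sample))
# {'latency': '1.234', 'best_tflops': '98.7', 'config': "{'block_M': 128}", 'ref_tflops': '95.0'}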
output_v1 = subprocess.run(['./tl/bin/python', './maint/scripts/performance.py'],
capture_output=True,
text=True,
env=env).stdout
output_v1 = subprocess.run(["./tl/bin/python", "./maint/scripts/performance.py"], capture_output=True, text=True, env=env).stdout
data_v1 = parse_output(output_v1)
output_v2 = subprocess.run(['./tll/bin/python', './maint/scripts/performance.py'],
capture_output=True,
text=True,
env=env).stdout
output_v2 = subprocess.run(["./tll/bin/python", "./maint/scripts/performance.py"], capture_output=True, text=True, env=env).stdout
data_v2 = parse_output(output_v2)
table = [[
"original", data_v1['latency'], data_v1['best_tflops'], data_v1['ref_tflops'], data_v1['config']
], [
"current", data_v2['latency'], data_v2['best_tflops'], data_v2['ref_tflops'], data_v2['config']
]]
table = [
["original", data_v1["latency"], data_v1["best_tflops"], data_v1["ref_tflops"], data_v1["config"]],
["current", data_v2["latency"], data_v2["best_tflops"], data_v2["ref_tflops"], data_v2["config"]],
]
headers = ["version", "Best Latency (s)", "Best TFlops", "Reference TFlops", "Best Config"]
......
......@@ -8,19 +8,20 @@ def ref_program(A, B):
def get_configs():
configs = [{
configs = [
{
"block_M": 128,
"block_N": 128,
"block_K": 64,
"num_stages": 2,
"thread_num": 256,
"enable_rasteration": True, # keep param name for backward-compat
}]
}
]
return configs
def run(M, N, K):
def kernel(
block_M=None,
block_N=None,
......@@ -38,8 +39,7 @@ def run(M, N, K):
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -60,12 +60,16 @@ def run(M, N, K):
return main
autotuner = AutoTuner.from_kernel(
kernel=kernel, configs=get_configs()).set_compile_args(
autotuner = (
AutoTuner.from_kernel(kernel=kernel, configs=get_configs())
.set_compile_args(
out_idx=[-1],
target="auto",
).set_profile_args(
ref_prog=ref_program,)
)
.set_profile_args(
ref_prog=ref_program,
)
)
return autotuner.run(warmup=3, rep=20)
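A hedged usage sketch of the entry point above (the shape is illustrative; the structure of the returned tuning artifact depends on the installed tilelang version):
if __name__ == "__main__":
    best = run(1024, 1024, 1024)  # tunes the single config returned by get_configs()
    print(best)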
......
......@@ -122,10 +122,7 @@ tilelang = "tilelang"
"tilelang/3rdparty/composable_kernel/include" = "3rdparty/composable_kernel/include"
"tilelang/3rdparty/composable_kernel/library" = "3rdparty/composable_kernel/library"
[tool.yapf]
based_on_style = "yapf"
column_limit = 100
indent_width = 4
[tool.codespell]
ignore-words = "docs/spelling_wordlist.txt"
......@@ -138,7 +135,7 @@ skip = [
[tool.ruff]
target-version = "py39"
line-length = 100
line-length = 140
output-format = "full"
exclude = [
......@@ -146,6 +143,14 @@ exclude = [
"examples/deepseek_v32/inference",
]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
docstring-code-line-length = "dynamic"
[tool.ruff.lint.per-file-ignores]
# Do not upgrade type hint in testing and examples.
# See https://github.com/tile-ai/tilelang/issues/1079 for more information.
......
......@@ -4,4 +4,3 @@ clang-format==21.1.2
clang-tidy==21.1.1
codespell[toml]==2.4.1
ruff==0.14.3
yapf==0.43.0
......@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
"warnings",
"error",
}
if (sum(
len(terminalreporter.stats.get(k, []))
for k in known_types.difference({"skipped", "deselected"})) == 0):
if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0:
terminalreporter.write_sep(
"!",
(f"Error: No tests were collected. "
f"{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
(f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
)
pytest.exit("No tests were collected.", returncode=5)
......@@ -4,7 +4,8 @@ from tilelang import tvm as tvm
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mfma_macro_generator import (
MatrixCoreIntrinEmitter,)
MatrixCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
tilelang.testing.set_random_seed(0)
......@@ -22,7 +23,6 @@ def tl_matmul(
b_transposed=True,
k_pack=1,
):
micro_size_x = micro_size_y = micro_size_k = 16
if in_dtype in {"float8_e4m3fnuz", "int8"}:
......@@ -83,7 +83,6 @@ def tl_matmul(
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -91,10 +90,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -102,7 +103,6 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=0):
# Load A into shared memory
if a_transposed:
T.copy(A[ko * block_K, by * block_M], A_shared)
......@@ -116,7 +116,6 @@ def tl_matmul(
T.copy(B[ko * block_K, bx * block_N], B_shared)
for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):
# Load A into fragment
mfma_emitter.ldmatrix_a(
A_local,
......@@ -160,17 +159,8 @@ def tl_matmul(
return main
def assert_tl_matmul_correctness(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype="float32",
a_transposed=False,
b_transposed=True,
k_pack=1):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
k_pack)
def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32", a_transposed=False, b_transposed=True, k_pack=1):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack)
print(matmul)
kernel = tilelang.compile(matmul)
src_code = kernel.get_kernel_source()
......@@ -201,16 +191,13 @@ def assert_tl_matmul_correctness(M,
if a_transposed and b_transposed:
# Get Reference Result
ref_c = torch.matmul(A.T.to(torch.float32),
B.T.to(torch.float32)).to(getattr(torch, out_dtype))
ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
elif a_transposed and not b_transposed:
# Get Reference Result
ref_c = torch.matmul(A.T.to(torch.float32),
B.to(torch.float32)).to(getattr(torch, out_dtype))
ref_c = torch.matmul(A.T.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
elif not a_transposed and b_transposed:
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32),
B.T.to(torch.float32)).to(getattr(torch, out_dtype))
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
else:
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
......@@ -228,16 +215,13 @@ def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32")
assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2)
assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False)
assert_tl_matmul_correctness(
128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2)
assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2)
if __name__ == "__main__":
......
......@@ -23,7 +23,6 @@ def tl_matmul(
b_preshuffle=False,
b_g2l_load=False,
):
micro_size_x = micro_size_y = micro_size_k = 16
if in_dtype in {"float8_e4m3fnuz", "int8"}:
......@@ -53,18 +52,21 @@ def tl_matmul(
A_shape = (K, M) if a_transposed else (M, K)
if b_preshuffle:
B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y,
pack_size_k) if b_transposed else (K // pack_size_k, N // micro_size_y,
pack_size_k, micro_size_y)
B_shape = (
(N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k)
if b_transposed
else (K // pack_size_k, N // micro_size_y, pack_size_k, micro_size_y)
)
else:
B_shape = (N, K) if b_transposed else (K, N)
A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
if b_preshuffle:
B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
pack_size_k) if b_transposed else (block_K // pack_size_k,
block_N // micro_size_y, pack_size_k,
micro_size_y)
B_shared_shape = (
(block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k)
if b_transposed
else (block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y)
)
else:
B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
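A quick numeric illustration of the pre-shuffled B layout above; the sizes are assumed for the example, with k_pack=1 so that pack_size_k equals micro_size_k=16:
# Transposed, pre-shuffled B for N=256, K=512 with 16x16 micro tiles.
N, K, micro_size_y, pack_size_k = 256, 512, 16, 16
B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k)
print(B_shape)  # (16, 32, 16, 16) -- same element count as the flat (N, K) layout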
......@@ -99,16 +101,17 @@ def tl_matmul(
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
})
}
)
num_ko = K // block_K
num_ki = block_K // (k_pack * micro_size_k)
......@@ -119,7 +122,6 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined(num_ko, num_stages=0):
# Load A into shared memory
if a_transposed:
T.copy(A[ko * block_K, by * block_M], A_shared)
......@@ -129,20 +131,13 @@ def tl_matmul(
# Load B into shared memory
if b_g2l_load is False:
if b_transposed:
for j, k, jj, kk in T.Parallel(block_N // micro_size_y,
block_K // pack_size_k, micro_size_y,
pack_size_k):
B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j,
ko * block_K // pack_size_k + k, jj, kk]
for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k):
B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j, ko * block_K // pack_size_k + k, jj, kk]
else:
for k, j, kk, jj in T.Parallel(block_K // pack_size_k,
block_N // micro_size_y, pack_size_k,
micro_size_y):
B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k,
bx * block_N // micro_size_y + j, kk, jj]
for k, j, kk, jj in T.Parallel(block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y):
B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k, bx * block_N // micro_size_y + j, kk, jj]
for ki in T.serial(0, num_ki):
# Load A S2L
mfma_emitter.ldmatrix_a(
A_local,
......@@ -194,7 +189,8 @@ def shuffle_weight(
return x.contiguous()
def assert_tl_matmul_correctness(M,
def assert_tl_matmul_correctness(
M,
N,
K,
in_dtype,
......@@ -204,9 +200,9 @@ def assert_tl_matmul_correctness(M,
b_transposed=True,
k_pack=1,
b_preshuffle=False,
b_g2l_load=False):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
k_pack, b_preshuffle, b_g2l_load)
b_g2l_load=False,
):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load)
print(matmul)
kernel = tilelang.compile(matmul)
src_code = kernel.get_kernel_source()
......@@ -244,16 +240,13 @@ def assert_tl_matmul_correctness(M,
if a_transposed and b_transposed:
# Get Reference Result
ref_c = torch.matmul(A.T.to(torch.float32),
B.T.to(torch.float32)).to(getattr(torch, out_dtype))
ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
elif a_transposed and not b_transposed:
# Get Reference Result
ref_c = torch.matmul(A.T.to(torch.float32),
B.to(torch.float32)).to(getattr(torch, out_dtype))
ref_c = torch.matmul(A.T.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
elif not a_transposed and b_transposed:
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32),
B.T.to(torch.float32)).to(getattr(torch, out_dtype))
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
else:
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
......@@ -266,40 +259,17 @@ def assert_tl_matmul_correctness(M,
@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
assert_tl_matmul_correctness(
256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
assert_tl_matmul_correctness(
256,
256,
512,
"int8",
"int32",
b_transposed=False,
accum_dtype="int32",
k_pack=2,
b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2, b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True)
assert_tl_matmul_correctness(
256,
256,
512,
"float8_e4m3fnuz",
"float32",
k_pack=2,
b_transposed=False,
b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_transposed=False, b_preshuffle=True)
if __name__ == "__main__":
......
......@@ -27,8 +27,7 @@ def matmul(
vec_size = 4 * k_pack
@T.prim_func
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor(
(M, N), out_dtype)):
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
......@@ -111,8 +110,7 @@ def test_gemm_bf16f32f32_nt():
run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm(
1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=2)
run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=2)
@tilelang.testing.requires_rocm
......@@ -121,8 +119,7 @@ def test_gemm_bf16bf16f32():
run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm(
1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=2)
run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=2)
def matmul_rs(
......
......@@ -5,9 +5,7 @@ import pytest
@tilelang.jit
def simple_invalid_loop(dtype: str = "bfloat16",
accum_dtype: str = "float32",
num_threads: int = 128):
def simple_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
A = T.dynamic("A")
@T.prim_func
......@@ -28,9 +26,7 @@ def simple_invalid_loop(dtype: str = "bfloat16",
@tilelang.jit
def nested_invalid_loop(dtype: str = "bfloat16",
accum_dtype: str = "float32",
num_threads: int = 128):
def nested_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
A = T.dynamic("A")
@T.prim_func
......@@ -52,9 +48,7 @@ def nested_invalid_loop(dtype: str = "bfloat16",
@tilelang.jit
def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16",
accum_dtype: str = "float32",
num_threads: int = 128):
def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
A = T.dynamic("A")
@T.prim_func
......@@ -75,9 +69,7 @@ def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16",
@tilelang.jit
def valid_loop_not_use_loop_var(dtype: str = "bfloat16",
accum_dtype: str = "float32",
num_threads: int = 128):
def valid_loop_not_use_loop_var(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
A = T.dynamic("A")
@T.prim_func
......@@ -99,9 +91,7 @@ def valid_loop_not_use_loop_var(dtype: str = "bfloat16",
@tilelang.jit
def valid_loop_not_frag(dtype: str = "bfloat16",
accum_dtype: str = "float32",
num_threads: int = 128):
def valid_loop_not_frag(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
A = T.dynamic("A")
@T.prim_func
......@@ -122,9 +112,7 @@ def valid_loop_not_frag(dtype: str = "bfloat16",
@tilelang.jit
def valid_loop_serial(dtype: str = "bfloat16",
accum_dtype: str = "float32",
num_threads: int = 128):
def valid_loop_serial(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
A = T.dynamic("A")
@T.prim_func
......
......@@ -30,7 +30,6 @@ Rule:
@tilelang.jit(out_idx=[1])
def nested_continuous_parallels(length=256, block=16, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -46,7 +45,6 @@ def nested_continuous_parallels(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_triple_continuous_parallels(length=256, block1=8, block2=2, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -56,15 +54,13 @@ def nested_triple_continuous_parallels(length=256, block1=8, block2=2, dtype="fl
for i in T.Parallel(length // block1 // block2):
for j in T.Parallel(block1):
for k in T.Parallel(block2):
B[i * block1 * block2 + j * block2 +
k] = A[i * block1 * block2 + j * block2 + k] + 1.0
B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0
return main
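The flattened index in the triple-nested Parallel loops above visits every element of the length-256 buffer exactly once; a quick standalone check using the signature's defaults:
length, block1, block2 = 256, 8, 2
flat = [
    i * block1 * block2 + j * block2 + k
    for i in range(length // block1 // block2)
    for j in range(block1)
    for k in range(block2)
]
assert flat == list(range(length))  # one-to-one, in-order coverage of all 256 elements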
@tilelang.jit(out_idx=[1])
def nested_noncontinuous_parallels(length=256, block=16, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -103,8 +99,9 @@ is OK.
"""
def matmul_nested_pipelines(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
out_dtype, accum_dtype, threads, order, stage, extra_pipeline_repeats):
def matmul_nested_pipelines(
M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, threads, order, stage, extra_pipeline_repeats
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
......@@ -180,7 +177,8 @@ def run_gemm_nested_pipelines(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
......@@ -193,8 +191,8 @@ def run_gemm_nested_pipelines(
if in_dtype == "float32":
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically; -0x1000 adjusts the low mantissa bits of the raw int32 view
A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
......@@ -218,7 +216,6 @@ is OK.
@tilelang.jit(out_idx=[1])
def nested_continuous_serials(length=256, block=16, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -234,7 +231,6 @@ def nested_continuous_serials(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_noncontinuous_serials(length=256, block=16, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -277,7 +273,6 @@ Rule:
@tilelang.jit(out_idx=[1])
def nested_continuous_sp(length=256, block=16, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -293,7 +288,6 @@ def nested_continuous_sp(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_continuous_ps(length=256, block=16, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -309,7 +303,6 @@ def nested_continuous_ps(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_continuous_psp(length=256, block1=8, block2=2, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -319,15 +312,13 @@ def nested_continuous_psp(length=256, block1=8, block2=2, dtype="float32"):
for i in T.Parallel(length // block1 // block2):
for j in T.serial(block1):
for k in T.Parallel(block2):
B[i * block1 * block2 + j * block2 +
k] = A[i * block1 * block2 + j * block2 + k] + 1.0
B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0
return main
@tilelang.jit(out_idx=[1])
def nested_continuous_sps(length=256, block1=8, block2=2, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -337,8 +328,7 @@ def nested_continuous_sps(length=256, block1=8, block2=2, dtype="float32"):
for i in T.serial(length // block1 // block2):
for j in T.Parallel(block1):
for k in T.serial(block2):
B[i * block1 * block2 + j * block2 +
k] = A[i * block1 * block2 + j * block2 + k] + 1.0
B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0
return main
......@@ -505,7 +495,8 @@ def run_gemm_mixed_pp(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
......@@ -514,8 +505,8 @@ def run_gemm_mixed_pp(
if in_dtype == "float32":
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically; -0x1000 adjusts the low mantissa bits of the raw int32 view
A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
......@@ -543,7 +534,8 @@ def run_gemm_mixed_pp(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
def test_mixed_pp():
......@@ -637,7 +629,8 @@ def run_gemm_tiled_op_with_parallel(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
......@@ -646,8 +639,8 @@ def run_gemm_tiled_op_with_parallel(
if in_dtype == "float32":
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically; -0x1000 adjusts the low mantissa bits of the raw int32 view
A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
......@@ -675,12 +668,12 @@ def run_gemm_tiled_op_with_parallel(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
@tilelang.jit(out_idx=[1])
def tir_op_with_parallel(length=256, block=16, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -696,7 +689,6 @@ def tir_op_with_parallel(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def customize_op_with_parallel(length=256, block=16, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......
......@@ -48,6 +48,7 @@ def get_configs(M, N, K, with_roller=False):
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization
arch = CUDA("cuda")
topk = 20
......@@ -84,7 +85,6 @@ def get_configs(M, N, K, with_roller=False):
for config in configs:
print(config)
else:
block_M = [64]
block_N = [64]
block_K = [32]
......@@ -100,7 +100,8 @@ def get_configs(M, N, K, with_roller=False):
num_stages,
thread_num,
enable_rasterization,
))
)
)
configs = [
{
......@@ -110,7 +111,8 @@ def get_configs(M, N, K, with_roller=False):
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs
}
for c in _configs
]
return configs
......@@ -206,9 +208,7 @@ def matmul(M, N, K, with_roller):
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
......@@ -247,12 +247,16 @@ def matmul(M, N, K, with_roller):
return main
autotuner = AutoTuner.from_kernel(
kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
autotuner = (
AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller))
.set_compile_args(
out_idx=[-1],
target="auto",
).set_profile_args(
ref_prog=ref_program,)
)
.set_profile_args(
ref_prog=ref_program,
)
)
return autotuner.run(warmup=3, rep=20)
......
......@@ -30,30 +30,15 @@ def ref_program(A, B):
def get_configs():
iter_params = dict(
block_M=[64],
block_N=[64],
block_K=[32],
num_stages=[0, 1],
thread_num=[128],
enable_rasterization=[False])
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(out_idx=[-1])
def matmul(M,
N,
K,
block_M=128,
block_N=128,
block_K=32,
num_stages=0,
thread_num=128,
enable_rasterization=False):
iter_params = dict(block_M=[64], block_N=[64], block_K=[32], num_stages=[0, 1], thread_num=[128], enable_rasterization=[False])
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
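The config sweep above is the usual dict-of-lists cartesian product; a standalone sketch of the same pattern, with parameter names and values that are purely illustrative:
import itertools
iter_params = dict(block_M=[64, 128], num_stages=[0, 1])
configs = [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
print(configs)
# [{'block_M': 64, 'num_stages': 0}, {'block_M': 64, 'num_stages': 1},
#  {'block_M': 128, 'num_stages': 0}, {'block_M': 128, 'num_stages': 1}]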
@tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M=128, block_N=128, block_K=32, num_stages=0, thread_num=128, enable_rasterization=False):
dtype = "float16"
accum_dtype = "float"
......@@ -76,7 +61,6 @@ def matmul(M,
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
......
......@@ -63,6 +63,7 @@ def run_cache_matmul():
Reference PyTorch matrix multiplication for comparison.
"""
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.half) # Assuming dtype="float16" in matmul
return C
......
......@@ -29,9 +29,7 @@ class _cudaDeviceAttrNames:
def test_driver_get_device_properties():
prop = get_cuda_device_properties()
assert prop is not None, "Failed to get CUDA device properties"
assert isinstance(
prop,
torch.cuda._CudaDeviceProperties), ("Returned object is not of type _CudaDeviceProperties")
assert isinstance(prop, torch.cuda._CudaDeviceProperties), "Returned object is not of type _CudaDeviceProperties"
def test_device_get_device_name():
......@@ -48,8 +46,7 @@ def test_device_get_shared_memory_per_block():
def test_device_get_persisting_l2_cache_size():
tl_cache_size = get_persisting_l2_cache_max_size()
driver_cache_size = get_device_attribute(
_cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize)
driver_cache_size = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize)
assert tl_cache_size == driver_cache_size, "Persisting L2 cache size values do not match"
......@@ -61,17 +58,14 @@ def test_device_get_num_sms():
def test_device_get_registers_per_block():
tl_regs_per_block = get_registers_per_block()
driver_regs_per_block = get_device_attribute(
_cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock)
driver_regs_per_block = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock)
assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match"
def test_device_get_max_dynamic_shared_size_bytes():
tl_dynamic_smem = get_max_dynamic_shared_size_bytes()
driver_dynamic_smem = get_device_attribute(
_cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor)
assert tl_dynamic_smem == driver_dynamic_smem, (
"Max dynamic shared size bytes values do not match")
driver_dynamic_smem = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor)
assert tl_dynamic_smem == driver_dynamic_smem, "Max dynamic shared size bytes values do not match"
if __name__ == "__main__":
......
......@@ -9,16 +9,13 @@ def run_general_matmul_emit_configs(M, N, K, topk: int = 20):
arch = auto_infer_current_arch()
def gemm(M, N, K):
A = te.placeholder((M, K), name='A', dtype='float16')
B = te.placeholder((N, K), name='B', dtype='float16')
A = te.placeholder((M, K), name="A", dtype="float16")
B = te.placeholder((N, K), name="B", dtype="float16")
# Describe the matrix multiplication in TE
k = te.reduce_axis((0, K), name='k')
k = te.reduce_axis((0, K), name="k")
C = te.compute(
(M, N),
lambda i, j: te.sum(A[i, k].astype('float16') * B[j, k].astype('float16'), axis=[k]),
name='C')
C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C")
return A, B, C
......@@ -29,8 +26,7 @@ def run_general_matmul_emit_configs(M, N, K, topk: int = 20):
tensorized_func, tags = carver.utils.get_tensorized_func_and_tags(func, arch.target)
print(tags)
policy = carver.TensorCorePolicy.from_prim_func(
func=tensorized_func, arch=arch, tags=tags, name="matmul_0")
policy = carver.TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags, name="matmul_0")
hints = policy.emit_config(topk=topk)
......@@ -59,16 +55,13 @@ def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20):
arch = auto_infer_current_arch()
def gemm(M, N, K):
A = te.placeholder((M, K), name='A', dtype='float16')
B = te.placeholder((N, K), name='B', dtype='float16')
A = te.placeholder((M, K), name="A", dtype="float16")
B = te.placeholder((N, K), name="B", dtype="float16")
# Describe the matrix multiplication in TE
k = te.reduce_axis((0, K), name='k')
k = te.reduce_axis((0, K), name="k")
C = te.compute(
(M, N),
lambda i, j: te.sum(A[i, k].astype('float16') * B[j, k].astype('float16'), axis=[k]),
name='C')
C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C")
return A, B, C
......