Unverified Commit 29051439 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
......@@ -31,8 +31,7 @@ def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,)
@T.prim_func
def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"),
B: T.Buffer((16,), "float32")):
def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
T.func_attr({"tir.noalias": True})
for i in T.serial(16):
B[i] = A[shift + i]
......
......@@ -9,7 +9,6 @@ tilelang.testing.set_random_seed()
@tilelang.jit(out_idx=[1])
def parallel_elementwise_static(length=256, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -24,7 +23,6 @@ def parallel_elementwise_static(length=256, dtype="float32"):
@tilelang.jit(out_idx=[1])
def parallel_elementwise_dynamic(max_len=512, threads=256, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((max_len,), dtype),
......
......@@ -90,7 +90,8 @@ def run_gemm(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
......@@ -103,8 +104,8 @@ def run_gemm(
if in_dtype == "float32":
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically, -0x1000 meas
A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
......@@ -124,17 +125,9 @@ def test_pipeline_order_stage():
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
def blocksparse_matmul(M,
N,
K,
block_M,
block_N,
block_K,
num_stages,
dtype="float16",
accum_dtype="float"):
},
)
def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, dtype="float16", accum_dtype="float"):
block_mask_shape = (M // block_M, N // block_N, K // block_K)
import tilelang.language as T
......@@ -183,8 +176,7 @@ def run_blocksparse_matmul(num_stages):
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
kernel = blocksparse_matmul(
M, N, K, block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages)
kernel = blocksparse_matmul(M, N, K, block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages)
print(kernel.get_kernel_source())
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
......@@ -200,12 +192,10 @@ def run_blocksparse_matmul(num_stages):
accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
for k in range(K // block_K):
if BlockMask[i, j, k]:
accu += (
A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
torch.float32) @ B[k * block_K:(k + 1) * block_K,
j * block_N:(j + 1) * block_N].to(torch.float32))
ref_c[i * block_M:(i + 1) * block_M,
j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[
k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N
].to(torch.float32)
ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
return ref_c
# Compute the reference result using the naive PyTorch implementation
......
......@@ -7,7 +7,6 @@ from tilelang.utils import map_torch_type
def matmul_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
a_ptr: T.ptr,
......
......@@ -30,7 +30,8 @@ def run_reshape(N, M, dtype):
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = jit_kernel.get_profiler()
def ref_program(A):
......@@ -74,7 +75,8 @@ def run_reshape_smem_1d_2_2d(N, M, dtype):
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = jit_kernel.get_profiler()
def ref_program(A):
......@@ -117,7 +119,8 @@ def run_reshape_smem_2d_2_1d(N, M, dtype):
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = jit_kernel.get_profiler()
def ref_program(A):
......@@ -161,7 +164,8 @@ def run_reshape_fragment(N, M, dtype):
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = jit_kernel.get_profiler()
def ref_program(A):
......@@ -187,9 +191,11 @@ def reshape_layout_transform_shared(N, M, dtype):
with T.Kernel(1, threads=32) as _:
A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
T.annotate_layout({
T.annotate_layout(
{
A_shared: make_mma_swizzle_layout(A_shared),
})
}
)
T.copy(A, A_shared)
A_shared_reshape = T.reshape(A_shared, [N])
T.copy(A_shared_reshape, B)
......@@ -205,7 +211,8 @@ def run_reshape_layout_transform_shared(N, M, dtype):
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = jit_kernel.get_profiler()
def ref_program(A):
......@@ -249,7 +256,8 @@ def run_reduce_after_reshape(N, M, dtype):
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = jit_kernel.get_profiler()
def ref_program(A):
......
......@@ -4,9 +4,10 @@ import torch
import tilelang.testing
@tilelang.jit(out_idx=[1],)
@tilelang.jit(
out_idx=[1],
)
def tilelang_ternary(M, N, block_M, block_N, dtype="float16"):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
......@@ -15,8 +16,7 @@ def tilelang_ternary(M, N, block_M, block_N, dtype="float16"):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = (
A[by * block_M + i, bx * block_N + j] if (by * block_M + i) < (M // 2) else 0)
B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j] if (by * block_M + i) < (M // 2) else 0
return main
......
......@@ -9,10 +9,8 @@ def ref_program(x, y):
@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
@T.prim_func
def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor(
(M, N), out_dtype)):
def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), in_dtype)
B_shared = T.alloc_shared((block_M, block_N), in_dtype)
......@@ -21,7 +19,7 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(B[by * block_M, bx * block_N], B_shared)
for (local_y, local_x) in T.Parallel(block_M, block_N):
for local_y, local_x in T.Parallel(block_M, block_N):
C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
......
......@@ -4,7 +4,6 @@ from tilelang import language as T
def test_unroll_with_step():
@T.prim_func
def main(A_ptr: T.handle):
A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
......@@ -19,7 +18,6 @@ def test_unroll_with_step():
def test_unroll_with_unroll_factor():
@T.prim_func
def main(A_ptr: T.handle):
A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
......
......@@ -4,17 +4,15 @@ import tilelang.testing
def test_var_assign() -> None:
@tilelang.jit(out_idx=-1)
def jit_kernel():
@T.prim_func
def test_var_assign(A: T.Tensor((2,), 'int32')):
def test_var_assign(A: T.Tensor((2,), "int32")):
with T.Kernel(1) as _:
a = T.alloc_var('int32', init=1)
b = T.alloc_var('int32', init=a) # b gets value of a
a = T.alloc_var("int32", init=1)
b = T.alloc_var("int32", init=a) # b gets value of a
a = 2
d = T.alloc_var('int32', init=a) # c gets new value of a
d = T.alloc_var("int32", init=a) # c gets new value of a
A[0] = b
A[1] = d
......@@ -28,5 +26,5 @@ def test_var_assign() -> None:
assert res[1] == 2
if __name__ == '__main__':
if __name__ == "__main__":
tilelang.testing.main()
......@@ -5,7 +5,6 @@ import tilelang.language as T
@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
def vectorize_test(N, M, stride_A, stride_B):
@T.prim_func
def main(
A: T.StridedTensor[(N, M), (1, stride_A), "float32"], # noqa: F821
......@@ -39,9 +38,7 @@ def run_vectorize(N, M, stride_A, stride_B):
code = jit_kernel.get_kernel_source()
vectorize_size = 1
while vectorize_size <= 2 and \
stride_A % (vectorize_size * 2) == 0 and \
stride_B % (vectorize_size * 2) == 0:
while vectorize_size <= 2 and stride_A % (vectorize_size * 2) == 0 and stride_B % (vectorize_size * 2) == 0:
vectorize_size *= 2
if vectorize_size == 4:
......@@ -61,7 +58,6 @@ def test_vectorize():
@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
def vectorize_test_invariant_index(N, M, K):
@T.prim_func
def main(
A: T.Tensor[(N, M), "float32"], # noqa: F821
......
......@@ -73,8 +73,7 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
code = kernel.get_kernel_source()
code_parallel = kernel_parallel.get_kernel_source()
assert check_str in code and check_str in code_parallel, \
f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
assert check_str in code and check_str in code_parallel, f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
def test_vectorized_cast():
......
......@@ -10,6 +10,7 @@ def view_test(N, M, dtype, new_dtype=None):
new_shape = [N // M, M]
if new_dtype:
from tvm import DataType
dtype_src = DataType(dtype)
dtype_dst = DataType(new_dtype)
src_bits = dtype_src.bits
......@@ -37,6 +38,7 @@ def run_view(N, M, dtype, new_dtype=None):
def ref_program(A):
if new_dtype:
from tilelang.utils.tensor import map_torch_type
torch_dtype = map_torch_type(new_dtype)
return A.view(N // M, M).view(dtype=torch_dtype)
return A.view(N // M, M)
......@@ -45,7 +47,6 @@ def run_view(N, M, dtype, new_dtype=None):
def test_reshape_view():
# Test view with same dtype
run_view(1024, 32, "float32")
run_view(2048, 64, "float16")
......@@ -61,6 +62,7 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
new_shape = [N // M, M + 1]
if new_dtype:
from tvm import DataType
dtype_src = DataType(dtype)
dtype_dst = DataType(new_dtype)
src_bits = dtype_src.bits
......
......@@ -7,7 +7,6 @@ import tilelang.language as T
@tilelang.jit
def get_kernel(reduce_op: str, dtype: str):
assert reduce_op in ["sum", "max", "min", "bitand", "bitor"]
@T.prim_func
......@@ -33,16 +32,16 @@ def get_kernel(reduce_op: str, dtype: str):
def test_warp_reduce_sum():
a = torch.randn((32,), dtype=torch.float32, device='cuda')
kernel = get_kernel('sum', 'float32')
a = torch.randn((32,), dtype=torch.float32, device="cuda")
kernel = get_kernel("sum", "float32")
ref = torch.full_like(a, a.sum())
kernel(a)
torch.testing.assert_close(a, ref)
def test_warp_reduce_max():
a = torch.randn((32,), dtype=torch.float32, device='cuda')
kernel = get_kernel("max", 'float32')
a = torch.randn((32,), dtype=torch.float32, device="cuda")
kernel = get_kernel("max", "float32")
print(kernel.get_kernel_source())
ref = torch.full_like(a, a.max())
kernel(a)
......@@ -50,16 +49,16 @@ def test_warp_reduce_max():
def test_warp_reduce_min():
a = torch.randn((32,), dtype=torch.float32, device='cuda')
kernel = get_kernel("min", 'float32')
a = torch.randn((32,), dtype=torch.float32, device="cuda")
kernel = get_kernel("min", "float32")
ref = torch.full_like(a, a.min())
kernel(a)
torch.testing.assert_close(a, ref)
def test_warp_reduce_bitand():
a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
kernel = get_kernel("bitand", 'int32')
a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda")
kernel = get_kernel("bitand", "int32")
ref_val = a[0]
for i in range(1, a.shape[0]):
ref_val = ref_val & a[i]
......@@ -69,8 +68,8 @@ def test_warp_reduce_bitand():
def test_warp_reduce_bitor():
a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
kernel = get_kernel("bitor", 'int32')
a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda")
kernel = get_kernel("bitor", "int32")
ref_val = a[0]
for i in range(1, a.shape[0]):
ref_val = ref_val | a[i]
......
......@@ -12,7 +12,6 @@ VEC_SIZE = 32
@tilelang.jit
def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):
@T.prim_func
def main(
a: T.Buffer((B, M, N), "bfloat16"),
......
......@@ -19,7 +19,6 @@ def bitwise_reduce(
func,
clear=True,
):
@T.prim_func
def reduce_func(
A: T.Tensor((M, N), "int32"),
......@@ -64,7 +63,7 @@ def run_single_bitwise_reduce(
row_pattern = (i & 0xF) << (i % 4) # 4-bit patterns shifted by row
# Column-based pattern: different bit positions set based on column
col_pattern = (1 << (j % 31)) # Single bit set at different positions
col_pattern = 1 << (j % 31) # Single bit set at different positions
# Combine patterns with XOR to create diverse bit distributions
# Add some deterministic "noise" based on position
......@@ -76,7 +75,7 @@ def run_single_bitwise_reduce(
if i % 4 == 0:
a[i, j] &= ~(0x1 << (i // 4))
elif i % 2 == 0:
a[i, j] |= (0x1 << (i // 2))
a[i, j] |= 0x1 << (i // 2)
if name == "reduce_bitand":
expected = torch.full((M,), -1, device="cuda", dtype=torch.int32)
......
......@@ -7,16 +7,16 @@ import re
def get_mathop_lines(source, mathop_name):
"""Extract lines containing the mathop from CUDA source for debugging"""
lines = source.split('\n')
lines = source.split("\n")
relevant_lines = []
for i, line in enumerate(lines):
if mathop_name in line and ('(' in line):
if mathop_name in line and ("(" in line):
# Include some context
start = max(0, i - 1)
end = min(len(lines), i + 2)
relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)])
relevant_lines.append("---")
return '\n'.join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output
return "\n".join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output
def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
......@@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
fastmath_matches = re.findall(fastmath_pattern, source)
non_fastmath_matches = re.findall(non_fastmath_pattern, source)
print(
f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls"
)
print(f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls")
if len(fastmath_matches) > 0:
print(f"Fastmath calls found: {fastmath_matches}")
if len(non_fastmath_matches) > 0:
......@@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name):
check_fastmath_usage(source, mathop_name, expect_fastmath=False)
def run_single_arg_mathop_test(mathop_name,
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
"""
Test single-argument mathops.
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
......@@ -70,8 +62,7 @@ def run_single_arg_mathop_test(mathop_name,
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
bx * block_N + j])
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j])
# Test with FAST_MATH disabled
kernel_no_fastmath = tilelang.compile(
......@@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
},
)
source_no_fastmath = kernel_no_fastmath.get_kernel_source()
......@@ -93,13 +85,7 @@ def run_single_arg_mathop_test(mathop_name,
print(f"✓ {mathop_name} compilation and execution test passed")
def run_two_arg_mathop_test(mathop_name,
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
"""
Test two-argument mathops to ensure they generate non-fastmath CUDA code.
"""
......@@ -112,9 +98,9 @@ def run_two_arg_mathop_test(mathop_name,
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i,
bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
B[by * block_M + i, bx * block_N + j])
C[by * block_M + i, bx * block_N + j] = mathop_func(
A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j]
)
# Test with FAST_MATH disabled
kernel_no_fastmath = tilelang.compile(
......@@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
},
)
# Test with FAST_MATH enabled
kernel_fastmath = tilelang.compile(
......@@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
source_no_fastmath = kernel_no_fastmath.get_kernel_source()
source_fastmath = kernel_fastmath.get_kernel_source()
......@@ -184,7 +172,8 @@ def run_abs_test():
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
},
)
source = kernel.get_kernel_source()
print("\n=== Testing abs (maps to fabs) ===")
......@@ -199,13 +188,7 @@ def run_abs_test():
print("✓ abs numerical test passed")
def run_fastmath_mathop_test(mathop_name,
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
"""
Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
"""
......@@ -217,8 +200,7 @@ def run_fastmath_mathop_test(mathop_name,
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
bx * block_N + j])
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j])
# Test with FAST_MATH enabled
kernel_fastmath = tilelang.compile(
......@@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
source_fastmath = kernel_fastmath.get_kernel_source()
print(f"\n=== Testing {mathop_name} (fastmath version) ===")
print("FAST_MATH=True:")
# Strip the __ prefix for checking in the CUDA source
cuda_mathop_name = mathop_name.lstrip('_')
cuda_mathop_name = mathop_name.lstrip("_")
check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True)
# Test numerical correctness
......
......@@ -5,14 +5,7 @@ import tilelang.testing
import pytest
def run_ieee_math_test(mathop_name,
mathop_func,
rounding_mode="rn",
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=128, block_M=32, block_N=32, dtype="float32"):
"""
Test IEEE-compliant math operations with specified rounding modes.
"""
......@@ -29,11 +22,12 @@ def run_ieee_math_test(mathop_name,
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
D[by * block_M + i,
bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
D[by * block_M + i, bx * block_N + j] = mathop_func(
A[by * block_M + i, bx * block_N + j],
B[by * block_M + i, bx * block_N + j],
C[by * block_M + i,
bx * block_N + j], rounding_mode)
C[by * block_M + i, bx * block_N + j],
rounding_mode,
)
out_idx = [3]
num_inputs = 3
......@@ -47,10 +41,9 @@ def run_ieee_math_test(mathop_name,
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i,
bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
B[by * block_M + i,
bx * block_N + j], rounding_mode)
C[by * block_M + i, bx * block_N + j] = mathop_func(
A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j], rounding_mode
)
out_idx = [2]
num_inputs = 2
......@@ -63,9 +56,7 @@ def run_ieee_math_test(mathop_name,
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i,
bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
rounding_mode)
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], rounding_mode)
out_idx = [1]
num_inputs = 1
......@@ -77,7 +68,8 @@ def run_ieee_math_test(mathop_name,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
},
)
print(f"\n=== Testing {mathop_name} with rounding mode {rounding_mode} ===")
print(f"✓ {mathop_name} compilation test passed")
......@@ -207,7 +199,8 @@ def test_ieee_frsqrt_rn_only():
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
},
)
print("\n=== Testing ieee_frsqrt (rn only) ===")
print("✓ ieee_frsqrt compilation test passed")
......
......@@ -5,9 +5,8 @@ import tilelang.language as T
import torch
@tilelang.jit(execution_backend='torch')
@tilelang.jit(execution_backend="torch")
def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"):
@T.prim_func
def gemm(
A: T.Tensor((M, K), dtype),
......@@ -15,8 +14,8 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="flo
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype, scope='shared')
B_shared = T.alloc_shared((block_K, block_N), dtype, scope='shared')
A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
......@@ -48,13 +47,13 @@ def assert_gemm(
torch_dtype = getattr(torch, dtype)
a, b = None, None
if 'int' in dtype:
a = torch.randint(100, (M, K), dtype=torch_dtype, device='mps')
b = torch.randint(100, (K, N), dtype=torch_dtype, device='mps')
if "int" in dtype:
a = torch.randint(100, (M, K), dtype=torch_dtype, device="mps")
b = torch.randint(100, (K, N), dtype=torch_dtype, device="mps")
else:
a = torch.randn(M, K, dtype=torch_dtype, device='mps')
b = torch.randn(K, N, dtype=torch_dtype, device='mps')
c = torch.zeros(M, N, dtype=torch_dtype, device='mps')
a = torch.randn(M, K, dtype=torch_dtype, device="mps")
b = torch.randn(K, N, dtype=torch_dtype, device="mps")
c = torch.zeros(M, N, dtype=torch_dtype, device="mps")
jit_kernel(a, b, c)
......@@ -70,12 +69,12 @@ def test_gemm_float32():
@tilelang.testing.requires_metal
def test_gemm_float16():
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='float16', atol=1)
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="float16", atol=1)
@tilelang.testing.requires_metal
def test_gemm_int32():
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='int32', atol=1)
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="int32", atol=1)
if __name__ == "__main__":
......
......@@ -88,7 +88,8 @@ def run_matmul_ssr(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
......@@ -106,24 +107,9 @@ def run_matmul_ssr(
def test_gemm_f16f16f16_nt_ssr():
run_matmul_ssr(
16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32)
run_matmul_ssr(
128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64)
run_matmul_ssr(
1024,
1024,
1024,
False,
True,
"float16",
"float16",
"float16",
128,
128,
32,
2,
num_threads=128)
run_matmul_ssr(16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32)
run_matmul_ssr(128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64)
run_matmul_ssr(1024, 1024, 1024, False, True, "float16", "float16", "float16", 128, 128, 32, 2, num_threads=128)
def matmul_rsr(
......@@ -214,7 +200,8 @@ def run_matmul_rsr(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
......@@ -342,7 +329,8 @@ def run_matmul_rrr(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment