Unverified commit 29051439, authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
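The diff below is purely mechanical: the repository switches its Python formatter from yapf to ruff, which collapses yapf's argument-per-line wrapping when a call fits on one line, normalizes string quotes to double quotes, and puts the closing bracket of a multi-line call on its own line with a trailing comma. Repo-wide, such a migration is typically applied with `ruff format .` and checked in CI with `ruff format --check .`; the project's exact formatter configuration is not shown in this commit. A distilled before/after sketch of the pattern, for illustration only (the function and its arguments are hypothetical, not taken from any file in this diff):

```python
# Before: yapf-style wrapping, arguments aligned under the opening parenthesis,
# single quotes preserved as written.
def run_case(name,
             func,
             dtype='float32'):
    return func(name, dtype)


# After: ruff format keeps the signature on one line while it fits the
# configured line length and normalizes quotes to double quotes.
def run_case(name, func, dtype="float32"):
    return func(name, dtype)
```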
@@ -31,8 +31,7 @@ def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,)
@T.prim_func
-def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"),
-                                   B: T.Buffer((16,), "float32")):
+def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    T.func_attr({"tir.noalias": True})
    for i in T.serial(16):
        B[i] = A[shift + i]
...
@@ -9,11 +9,10 @@ tilelang.testing.set_random_seed()
@tilelang.jit(out_idx=[1])
def parallel_elementwise_static(length=256, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length):
@@ -24,12 +23,11 @@ def parallel_elementwise_static(length=256, dtype="float32"):
@tilelang.jit(out_idx=[1])
def parallel_elementwise_dynamic(max_len=512, threads=256, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((max_len,), dtype),
        B: T.Tensor((max_len,), dtype),
        valid_len: T.int32,
    ):
        with T.Kernel(1, threads=threads) as _:
            for i in T.Parallel(max_len):
...
@@ -27,9 +27,9 @@ def matmul(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -90,7 +90,8 @@ def run_gemm(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = kernel.get_profiler()
    def ref_program(A, B):
@@ -103,8 +104,8 @@ def run_gemm(
        if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
            # float32 automatically, -0x1000 meas
-            A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
-            B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
+            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
+            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C
@@ -124,27 +125,19 @@ def test_pipeline_order_stage():
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-    })
-def blocksparse_matmul(M,
-                       N,
-                       K,
-                       block_M,
-                       block_N,
-                       block_K,
-                       num_stages,
-                       dtype="float16",
-                       accum_dtype="float"):
+    },
+)
+def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, dtype="float16", accum_dtype="float"):
    block_mask_shape = (M // block_M, N // block_N, K // block_K)
    import tilelang.language as T
    @T.prim_func
    def block_sparse_matmul(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        BlockMask: T.Tensor(block_mask_shape, "bool"),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
@@ -183,8 +176,7 @@ def run_blocksparse_matmul(num_stages):
    a = torch.randn(M, K).cuda().half()
    b = torch.randn(K, N).cuda().half()
-    kernel = blocksparse_matmul(
-        M, N, K, block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages)
+    kernel = blocksparse_matmul(M, N, K, block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages)
    print(kernel.get_kernel_source())
    # Create block mask with desired sparsity
    mask_shape = (M // block_M, N // block_N, K // block_K)
@@ -200,12 +192,10 @@ def run_blocksparse_matmul(num_stages):
                accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
                for k in range(K // block_K):
                    if BlockMask[i, j, k]:
-                        accu += (
-                            A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
-                                torch.float32) @ B[k * block_K:(k + 1) * block_K,
-                                                   j * block_N:(j + 1) * block_N].to(torch.float32))
-                ref_c[i * block_M:(i + 1) * block_M,
-                      j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
+                        accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[
+                            k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N
+                        ].to(torch.float32)
+                ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
        return ref_c
    # Compute the reference result using the naive PyTorch implementation
...
@@ -7,7 +7,6 @@ from tilelang.utils import map_torch_type
def matmul_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        a_ptr: T.ptr,
...
@@ -10,8 +10,8 @@ def _make_shared_reduce(M, N, dtype, reduce_cb):
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_shared = T.alloc_shared((M, N), dtype)
@@ -35,8 +35,8 @@ def reduce_max_test(M, N, dtype="float16"):
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
@@ -54,8 +54,8 @@ def reduce_sum_test(M, N, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
@@ -145,8 +145,8 @@ def reduce_sum_test_clear(M, N, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
@@ -186,8 +186,8 @@ def reduce_max_test_clear(M, N, dtype="float16"):
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
...
@@ -10,8 +10,8 @@ def reshape_test(N, M, dtype):
    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N // M, M), dtype),
    ):
        with T.Kernel(1) as _:
            A_reshaped = T.reshape(A, [N // M, M])
@@ -30,7 +30,8 @@ def run_reshape(N, M, dtype):
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = jit_kernel.get_profiler()
    def ref_program(A):
@@ -50,8 +51,8 @@ def reshape_test_smem_1d_2_2d(N, M, dtype):
    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N // M, M), dtype),
    ):
        with T.Kernel(1) as _:
            A_shared = T.alloc_shared((N,), dtype)
@@ -74,7 +75,8 @@ def run_reshape_smem_1d_2_2d(N, M, dtype):
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = jit_kernel.get_profiler()
    def ref_program(A):
@@ -93,8 +95,8 @@ def reshape_test_smem_2d_2_1d(N, M, dtype):
    @T.prim_func
    def main(
        A: T.Tensor((N // M, M), dtype),
        B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(1) as _:
            A_shared = T.alloc_shared((N // M, M), dtype)
@@ -117,7 +119,8 @@ def run_reshape_smem_2d_2_1d(N, M, dtype):
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = jit_kernel.get_profiler()
    def ref_program(A):
@@ -136,8 +139,8 @@ def reshape_fragment_test(N, M, dtype):
    @T.prim_func
    def main(
        A: T.Tensor((N // M, M), dtype),
        B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
@@ -161,7 +164,8 @@ def run_reshape_fragment(N, M, dtype):
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = jit_kernel.get_profiler()
    def ref_program(A):
@@ -181,15 +185,17 @@ def reshape_layout_transform_shared(N, M, dtype):
    @T.prim_func
    def main(
        A: T.Tensor((N // M, M), dtype),
        B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
-            T.annotate_layout({
-                A_shared: make_mma_swizzle_layout(A_shared),
-            })
+            T.annotate_layout(
+                {
+                    A_shared: make_mma_swizzle_layout(A_shared),
+                }
+            )
            T.copy(A, A_shared)
            A_shared_reshape = T.reshape(A_shared, [N])
            T.copy(A_shared_reshape, B)
@@ -205,7 +211,8 @@ def run_reshape_layout_transform_shared(N, M, dtype):
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = jit_kernel.get_profiler()
    def ref_program(A):
@@ -224,8 +231,8 @@ def reduce_after_reshape_test(N, M, dtype):
    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N // M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_shared = T.alloc_shared((N,), dtype, scope="shared")
@@ -249,7 +256,8 @@ def run_reduce_after_reshape(N, M, dtype):
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = jit_kernel.get_profiler()
    def ref_program(A):
@@ -268,8 +276,8 @@ def reshape_shape_mismatch_test(N, M, dtype):
    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N // M, M), dtype),
    ):
        with T.Kernel(1) as _:
            A_reshaped = T.reshape(A, [N // M, M + 1])
...
@@ -4,19 +4,19 @@ import torch
import tilelang.testing
-@tilelang.jit(out_idx=[1],)
+@tilelang.jit(
+    out_idx=[1],
+)
def tilelang_ternary(M, N, block_M, block_N, dtype="float16"):
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
-                B[by * block_M + i, bx * block_N + j] = (
-                    A[by * block_M + i, bx * block_N + j] if (by * block_M + i) < (M // 2) else 0)
+                B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j] if (by * block_M + i) < (M // 2) else 0
    return main
...
@@ -9,10 +9,8 @@ def ref_program(x, y):
@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
    @T.prim_func
-    def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor(
-        (M, N), out_dtype)):
+    def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), in_dtype)
            B_shared = T.alloc_shared((block_M, block_N), in_dtype)
@@ -21,7 +19,7 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(B[by * block_M, bx * block_N], B_shared)
-            for (local_y, local_x) in T.Parallel(block_M, block_N):
+            for local_y, local_x in T.Parallel(block_M, block_N):
                C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])
...
@@ -4,7 +4,6 @@ from tilelang import language as T
def test_unroll_with_step():
    @T.prim_func
    def main(A_ptr: T.handle):
        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
@@ -19,7 +18,6 @@ def test_unroll_with_step():
def test_unroll_with_unroll_factor():
    @T.prim_func
    def main(A_ptr: T.handle):
        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
...
@@ -4,17 +4,15 @@ import tilelang.testing
def test_var_assign() -> None:
    @tilelang.jit(out_idx=-1)
    def jit_kernel():
        @T.prim_func
-        def test_var_assign(A: T.Tensor((2,), 'int32')):
+        def test_var_assign(A: T.Tensor((2,), "int32")):
            with T.Kernel(1) as _:
-                a = T.alloc_var('int32', init=1)
-                b = T.alloc_var('int32', init=a)  # b gets value of a
+                a = T.alloc_var("int32", init=1)
+                b = T.alloc_var("int32", init=a)  # b gets value of a
                a = 2
-                d = T.alloc_var('int32', init=a)  # c gets new value of a
+                d = T.alloc_var("int32", init=a)  # c gets new value of a
                A[0] = b
                A[1] = d
@@ -28,5 +26,5 @@ def test_var_assign() -> None:
    assert res[1] == 2
-if __name__ == '__main__':
+if __name__ == "__main__":
    tilelang.testing.main()
@@ -5,11 +5,10 @@ import tilelang.language as T
@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
def vectorize_test(N, M, stride_A, stride_B):
    @T.prim_func
    def main(
        A: T.StridedTensor[(N, M), (1, stride_A), "float32"],  # noqa: F821
        B: T.StridedTensor[(N, M), (1, stride_B), "float32"],  # noqa: F821
    ):
        with T.Kernel(M // 128, threads=128) as (bx):
            tx = T.get_thread_binding(0)
@@ -39,9 +38,7 @@ def run_vectorize(N, M, stride_A, stride_B):
    code = jit_kernel.get_kernel_source()
    vectorize_size = 1
-    while vectorize_size <= 2 and \
-            stride_A % (vectorize_size * 2) == 0 and \
-            stride_B % (vectorize_size * 2) == 0:
+    while vectorize_size <= 2 and stride_A % (vectorize_size * 2) == 0 and stride_B % (vectorize_size * 2) == 0:
        vectorize_size *= 2
    if vectorize_size == 4:
@@ -61,12 +58,11 @@ def test_vectorize():
@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
def vectorize_test_invariant_index(N, M, K):
    @T.prim_func
    def main(
        A: T.Tensor[(N, M), "float32"],  # noqa: F821
        B: T.Tensor[(N, M), "float32"],  # noqa: F821
        C: T.Tensor[(N, M // K), "float32"],  # noqa: F821
    ):
        with T.Kernel(N // 128, threads=128) as (bx):
            tx = T.get_thread_binding(0)
...
@@ -17,8 +17,8 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
    @T.prim_func
    def main(
        A: T.Tensor[(M,), dtype_A],  # noqa: F821
        B: T.Tensor[(M,), dtype_B],  # noqa: F821
    ):
        with T.Kernel(1, threads=128):
            T.copy(A, B)
@@ -32,8 +32,8 @@ def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
    @T.prim_func
    def main(
        A: T.Tensor[(M,), dtype_A],  # noqa: F821
        B: T.Tensor[(M,), dtype_B],  # noqa: F821
    ):
        with T.Kernel(1, threads=128):
            A_local = T.alloc_fragment((M,), dtype_A)
@@ -73,8 +73,7 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
    code = kernel.get_kernel_source()
    code_parallel = kernel_parallel.get_kernel_source()
-    assert check_str in code and check_str in code_parallel, \
-        f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
+    assert check_str in code and check_str in code_parallel, f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
def test_vectorized_cast():
...
@@ -10,6 +10,7 @@ def view_test(N, M, dtype, new_dtype=None):
    new_shape = [N // M, M]
    if new_dtype:
        from tvm import DataType
        dtype_src = DataType(dtype)
        dtype_dst = DataType(new_dtype)
        src_bits = dtype_src.bits
@@ -19,8 +20,8 @@ def view_test(N, M, dtype, new_dtype=None):
    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor(new_shape, new_dtype if new_dtype else dtype),
    ):
        with T.Kernel(1) as _:
            A_viewed = T.view(A, new_shape, dtype=new_dtype)
@@ -37,6 +38,7 @@ def run_view(N, M, dtype, new_dtype=None):
    def ref_program(A):
        if new_dtype:
            from tilelang.utils.tensor import map_torch_type
            torch_dtype = map_torch_type(new_dtype)
            return A.view(N // M, M).view(dtype=torch_dtype)
        return A.view(N // M, M)
@@ -45,7 +47,6 @@ def run_view(N, M, dtype, new_dtype=None):
def test_reshape_view():
    # Test view with same dtype
    run_view(1024, 32, "float32")
    run_view(2048, 64, "float16")
@@ -61,6 +62,7 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
    new_shape = [N // M, M + 1]
    if new_dtype:
        from tvm import DataType
        dtype_src = DataType(dtype)
        dtype_dst = DataType(new_dtype)
        src_bits = dtype_src.bits
@@ -70,8 +72,8 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor(new_shape, new_dtype if new_dtype else dtype),
    ):
        with T.Kernel(1) as _:
            A_viewed = T.view(A, new_shape, dtype=new_dtype)
...
@@ -7,7 +7,6 @@ import tilelang.language as T
@tilelang.jit
def get_kernel(reduce_op: str, dtype: str):
    assert reduce_op in ["sum", "max", "min", "bitand", "bitor"]
    @T.prim_func
@@ -33,16 +32,16 @@ def get_kernel(reduce_op: str, dtype: str):
def test_warp_reduce_sum():
-    a = torch.randn((32,), dtype=torch.float32, device='cuda')
-    kernel = get_kernel('sum', 'float32')
+    a = torch.randn((32,), dtype=torch.float32, device="cuda")
+    kernel = get_kernel("sum", "float32")
    ref = torch.full_like(a, a.sum())
    kernel(a)
    torch.testing.assert_close(a, ref)
def test_warp_reduce_max():
-    a = torch.randn((32,), dtype=torch.float32, device='cuda')
-    kernel = get_kernel("max", 'float32')
+    a = torch.randn((32,), dtype=torch.float32, device="cuda")
+    kernel = get_kernel("max", "float32")
    print(kernel.get_kernel_source())
    ref = torch.full_like(a, a.max())
    kernel(a)
@@ -50,16 +49,16 @@ def test_warp_reduce_max():
def test_warp_reduce_min():
-    a = torch.randn((32,), dtype=torch.float32, device='cuda')
-    kernel = get_kernel("min", 'float32')
+    a = torch.randn((32,), dtype=torch.float32, device="cuda")
+    kernel = get_kernel("min", "float32")
    ref = torch.full_like(a, a.min())
    kernel(a)
    torch.testing.assert_close(a, ref)
def test_warp_reduce_bitand():
-    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
-    kernel = get_kernel("bitand", 'int32')
+    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda")
+    kernel = get_kernel("bitand", "int32")
    ref_val = a[0]
    for i in range(1, a.shape[0]):
        ref_val = ref_val & a[i]
@@ -69,8 +68,8 @@ def test_warp_reduce_bitand():
def test_warp_reduce_bitor():
-    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
-    kernel = get_kernel("bitor", 'int32')
+    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda")
+    kernel = get_kernel("bitor", "int32")
    ref_val = a[0]
    for i in range(1, a.shape[0]):
        ref_val = ref_val | a[i]
...
@@ -12,17 +12,16 @@ VEC_SIZE = 32
@tilelang.jit
def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):
    @T.prim_func
    def main(
        a: T.Buffer((B, M, N), "bfloat16"),
        a_out: T.Buffer((B, M, N), "float32"),
    ):
        with T.Kernel(
            T.ceildiv(M, BLOCK_MN),
            T.ceildiv(N, BLOCK_K),
            B,
            threads=128,
        ) as (pid_m, pid_n, pid_b):
            a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32")
            offs_m = pid_m * BLOCK_MN
...
@@ -19,12 +19,11 @@ def bitwise_reduce(
    func,
    clear=True,
):
    @T.prim_func
    def reduce_func(
        A: T.Tensor((M, N), "int32"),
        B: T.Tensor((M), "int32"),
        Output: T.Tensor((M), "int32"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), "int32")
@@ -64,7 +63,7 @@ def run_single_bitwise_reduce(
            row_pattern = (i & 0xF) << (i % 4)  # 4-bit patterns shifted by row
            # Column-based pattern: different bit positions set based on column
-            col_pattern = (1 << (j % 31))  # Single bit set at different positions
+            col_pattern = 1 << (j % 31)  # Single bit set at different positions
            # Combine patterns with XOR to create diverse bit distributions
            # Add some deterministic "noise" based on position
@@ -76,7 +75,7 @@ def run_single_bitwise_reduce(
            if i % 4 == 0:
                a[i, j] &= ~(0x1 << (i // 4))
            elif i % 2 == 0:
-                a[i, j] |= (0x1 << (i // 2))
+                a[i, j] |= 0x1 << (i // 2)
    if name == "reduce_bitand":
        expected = torch.full((M,), -1, device="cuda", dtype=torch.int32)
...
@@ -7,16 +7,16 @@ import re
def get_mathop_lines(source, mathop_name):
    """Extract lines containing the mathop from CUDA source for debugging"""
-    lines = source.split('\n')
+    lines = source.split("\n")
    relevant_lines = []
    for i, line in enumerate(lines):
-        if mathop_name in line and ('(' in line):
+        if mathop_name in line and ("(" in line):
            # Include some context
            start = max(0, i - 1)
            end = min(len(lines), i + 2)
            relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)])
            relevant_lines.append("---")
-    return '\n'.join(relevant_lines[-10:])  # Show last 10 lines to avoid too much output
+    return "\n".join(relevant_lines[-10:])  # Show last 10 lines to avoid too much output
def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
@@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
    fastmath_matches = re.findall(fastmath_pattern, source)
    non_fastmath_matches = re.findall(non_fastmath_pattern, source)
-    print(
-        f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls"
-    )
+    print(f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls")
    if len(fastmath_matches) > 0:
        print(f"Fastmath calls found: {fastmath_matches}")
    if len(non_fastmath_matches) > 0:
@@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name):
    check_fastmath_usage(source, mathop_name, expect_fastmath=False)
-def run_single_arg_mathop_test(mathop_name,
-                               mathop_func,
-                               M=128,
-                               N=128,
-                               block_M=32,
-                               block_N=32,
-                               dtype="float32"):
+def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
    """
    Test single-argument mathops.
    T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
@@ -65,13 +57,12 @@ def run_single_arg_mathop_test(mathop_name,
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
-                B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
-                                                                      bx * block_N + j])
+                B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j])
    # Test with FAST_MATH disabled
    kernel_no_fastmath = tilelang.compile(
@@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name,
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
-        })
+        },
+    )
    source_no_fastmath = kernel_no_fastmath.get_kernel_source()
@@ -93,28 +85,22 @@ def run_single_arg_mathop_test(mathop_name,
    print(f"✓ {mathop_name} compilation and execution test passed")
-def run_two_arg_mathop_test(mathop_name,
-                            mathop_func,
-                            M=128,
-                            N=128,
-                            block_M=32,
-                            block_N=32,
-                            dtype="float32"):
+def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
    """
    Test two-argument mathops to ensure they generate non-fastmath CUDA code.
    """
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
-                C[by * block_M + i,
-                  bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
-                                                  B[by * block_M + i, bx * block_N + j])
+                C[by * block_M + i, bx * block_N + j] = mathop_func(
+                    A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j]
+                )
    # Test with FAST_MATH disabled
    kernel_no_fastmath = tilelang.compile(
@@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name,
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
-        })
+        },
+    )
    # Test with FAST_MATH enabled
    kernel_fastmath = tilelang.compile(
@@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name,
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        })
+        },
+    )
    source_no_fastmath = kernel_no_fastmath.get_kernel_source()
    source_fastmath = kernel_fastmath.get_kernel_source()
@@ -171,8 +159,8 @@ def run_abs_test():
    @T.prim_func
    def main(
        A: T.Tensor((M, N), "float32"),
        B: T.Tensor((M, N), "float32"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
@@ -184,7 +172,8 @@ def run_abs_test():
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
-        })
+        },
+    )
    source = kernel.get_kernel_source()
    print("\n=== Testing abs (maps to fabs) ===")
@@ -199,26 +188,19 @@ def run_abs_test():
    print("✓ abs numerical test passed")
-def run_fastmath_mathop_test(mathop_name,
-                             mathop_func,
-                             M=128,
-                             N=128,
-                             block_M=32,
-                             block_N=32,
-                             dtype="float32"):
+def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
    """
    Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
    """
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
-                B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
-                                                                      bx * block_N + j])
+                B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j])
    # Test with FAST_MATH enabled
    kernel_fastmath = tilelang.compile(
@@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name,
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        })
+        },
+    )
    source_fastmath = kernel_fastmath.get_kernel_source()
    print(f"\n=== Testing {mathop_name} (fastmath version) ===")
    print("FAST_MATH=True:")
    # Strip the __ prefix for checking in the CUDA source
-    cuda_mathop_name = mathop_name.lstrip('_')
+    cuda_mathop_name = mathop_name.lstrip("_")
    check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True)
    # Test numerical correctness
...
@@ -5,14 +5,7 @@ import tilelang.testing
import pytest
-def run_ieee_math_test(mathop_name,
-                       mathop_func,
-                       rounding_mode="rn",
-                       M=128,
-                       N=128,
-                       block_M=32,
-                       block_N=32,
-                       dtype="float32"):
+def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=128, block_M=32, block_N=32, dtype="float32"):
    """
    Test IEEE-compliant math operations with specified rounding modes.
    """
@@ -22,18 +15,19 @@ def run_ieee_math_test(mathop_name,
        @T.prim_func
        def main_func(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
            C: T.Tensor((M, N), dtype),
            D: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
                for i, j in T.Parallel(block_M, block_N):
-                    D[by * block_M + i,
-                      bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
-                                                      B[by * block_M + i, bx * block_N + j],
-                                                      C[by * block_M + i,
-                                                        bx * block_N + j], rounding_mode)
+                    D[by * block_M + i, bx * block_N + j] = mathop_func(
+                        A[by * block_M + i, bx * block_N + j],
+                        B[by * block_M + i, bx * block_N + j],
+                        C[by * block_M + i, bx * block_N + j],
+                        rounding_mode,
+                    )
        out_idx = [3]
        num_inputs = 3
@@ -41,16 +35,15 @@ def run_ieee_math_test(mathop_name,
        @T.prim_func
        def main_func(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
            C: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
                for i, j in T.Parallel(block_M, block_N):
-                    C[by * block_M + i,
-                      bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
-                                                      B[by * block_M + i,
-                                                        bx * block_N + j], rounding_mode)
+                    C[by * block_M + i, bx * block_N + j] = mathop_func(
+                        A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j], rounding_mode
+                    )
        out_idx = [2]
        num_inputs = 2
@@ -58,14 +51,12 @@ def run_ieee_math_test(mathop_name,
        @T.prim_func
        def main_func(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
                for i, j in T.Parallel(block_M, block_N):
-                    B[by * block_M + i,
-                      bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
-                                                      rounding_mode)
+                    B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], rounding_mode)
        out_idx = [1]
        num_inputs = 1
@@ -77,7 +68,8 @@ def run_ieee_math_test(mathop_name,
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
-        })
+        },
+    )
    print(f"\n=== Testing {mathop_name} with rounding mode {rounding_mode} ===")
    print(f"✓ {mathop_name} compilation test passed")
@@ -194,8 +186,8 @@ def test_ieee_frsqrt_rn_only():
    @T.prim_func
    def main(
        A: T.Tensor((128, 128), "float32"),
        B: T.Tensor((128, 128), "float32"),
    ):
        with T.Kernel(T.ceildiv(128, 32), T.ceildiv(128, 32), threads=128) as (bx, by):
            for i, j in T.Parallel(32, 32):
@@ -207,7 +199,8 @@ def test_ieee_frsqrt_rn_only():
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
-        })
+        },
+    )
    print("\n=== Testing ieee_frsqrt (rn only) ===")
    print("✓ ieee_frsqrt compilation test passed")
...
@@ -5,18 +5,17 @@ import tilelang.language as T
import torch
-@tilelang.jit(execution_backend='torch')
+@tilelang.jit(execution_backend="torch")
def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"):
    @T.prim_func
    def gemm(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
-            A_shared = T.alloc_shared((block_M, block_K), dtype, scope='shared')
-            B_shared = T.alloc_shared((block_K, block_N), dtype, scope='shared')
+            A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
+            B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
@@ -48,13 +47,13 @@ def assert_gemm(
    torch_dtype = getattr(torch, dtype)
    a, b = None, None
-    if 'int' in dtype:
-        a = torch.randint(100, (M, K), dtype=torch_dtype, device='mps')
-        b = torch.randint(100, (K, N), dtype=torch_dtype, device='mps')
+    if "int" in dtype:
+        a = torch.randint(100, (M, K), dtype=torch_dtype, device="mps")
+        b = torch.randint(100, (K, N), dtype=torch_dtype, device="mps")
    else:
-        a = torch.randn(M, K, dtype=torch_dtype, device='mps')
-        b = torch.randn(K, N, dtype=torch_dtype, device='mps')
-    c = torch.zeros(M, N, dtype=torch_dtype, device='mps')
+        a = torch.randn(M, K, dtype=torch_dtype, device="mps")
+        b = torch.randn(K, N, dtype=torch_dtype, device="mps")
+    c = torch.zeros(M, N, dtype=torch_dtype, device="mps")
    jit_kernel(a, b, c)
@@ -70,12 +69,12 @@ def test_gemm_float32():
@tilelang.testing.requires_metal
def test_gemm_float16():
-    assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='float16', atol=1)
+    assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="float16", atol=1)
@tilelang.testing.requires_metal
def test_gemm_int32():
-    assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='int32', atol=1)
+    assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="int32", atol=1)
if __name__ == "__main__":
...
@@ -27,9 +27,9 @@ def matmul_ssr(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
@@ -88,7 +88,8 @@ def run_matmul_ssr(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = kernel.get_profiler()
    def ref_program(A, B):
@@ -106,24 +107,9 @@ def run_matmul_ssr(
def test_gemm_f16f16f16_nt_ssr():
-    run_matmul_ssr(
-        16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32)
-    run_matmul_ssr(
-        128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64)
-    run_matmul_ssr(
-        1024,
-        1024,
-        1024,
-        False,
-        True,
-        "float16",
-        "float16",
-        "float16",
-        128,
-        128,
-        32,
-        2,
-        num_threads=128)
+    run_matmul_ssr(16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32)
+    run_matmul_ssr(128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64)
+    run_matmul_ssr(1024, 1024, 1024, False, True, "float16", "float16", "float16", 128, 128, 32, 2, num_threads=128)
def matmul_rsr(
@@ -151,9 +137,9 @@ def matmul_rsr(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
@@ -214,7 +200,8 @@ def run_matmul_rsr(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = kernel.get_profiler()
    def ref_program(A, B):
@@ -276,9 +263,9 @@ def matmul_rrr(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -342,7 +329,8 @@ def run_matmul_rrr(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = kernel.get_profiler()
    def ref_program(A, B):
...