Unverified commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
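
The diff below changes formatting only. To make the style shift concrete, here is a minimal, tilelang-free sketch of the wrapping change, using plain-Python stand-ins modeled on the `blocksparse_matmul` signature further down (the `_yapf`/`_ruff` suffixes are ours, for illustration only):

```python
# yapf (old): arguments wrapped with a hanging indent aligned to the opening paren.
def blocksparse_matmul_yapf(M,
                            N,
                            K,
                            block_M,
                            block_N,
                            block_K,
                            num_stages,
                            dtype="float16",
                            accum_dtype="float"):
    return (M, N, K, block_M, block_N, block_K, num_stages, dtype, accum_dtype)


# ruff format (new): the signature collapses onto one line when it fits the
# configured line length; argument lists with a trailing comma stay one-per-line.
def blocksparse_matmul_ruff(M, N, K, block_M, block_N, block_K, num_stages, dtype="float16", accum_dtype="float"):
    return (M, N, K, block_M, block_N, block_K, num_stages, dtype, accum_dtype)
```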
......@@ -31,8 +31,7 @@ def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,)
@T.prim_func
def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"),
B: T.Buffer((16,), "float32")):
def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
T.func_attr({"tir.noalias": True})
for i in T.serial(16):
B[i] = A[shift + i]
......
......@@ -9,11 +9,10 @@ tilelang.testing.set_random_seed()
@tilelang.jit(out_idx=[1])
def parallel_elementwise_static(length=256, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
B: T.Tensor((length,), dtype),
A: T.Tensor((length,), dtype),
B: T.Tensor((length,), dtype),
):
with T.Kernel(1, threads=length) as _:
for i in T.Parallel(length):
......@@ -24,12 +23,11 @@ def parallel_elementwise_static(length=256, dtype="float32"):
@tilelang.jit(out_idx=[1])
def parallel_elementwise_dynamic(max_len=512, threads=256, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((max_len,), dtype),
B: T.Tensor((max_len,), dtype),
valid_len: T.int32,
A: T.Tensor((max_len,), dtype),
B: T.Tensor((max_len,), dtype),
valid_len: T.int32,
):
with T.Kernel(1, threads=threads) as _:
for i in T.Parallel(max_len):
......
......@@ -27,9 +27,9 @@ def matmul(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -90,7 +90,8 @@ def run_gemm(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
......@@ -103,8 +104,8 @@ def run_gemm(
if in_dtype == "float32":
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically, -0x1000 means
A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
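
For reference, the `-0x1000` bit adjustment in `ref_program` can be pulled out into a small standalone helper; a minimal sketch that only assumes the same int32-view subtraction is wanted (the helper name is ours, not from the repository):

```python
import torch


def tf32_adjust(x: torch.Tensor) -> torch.Tensor:
    # Mirror ref_program's adjustment for the tf32 MMA comparison: reinterpret
    # the float32 bits as int32, subtract 0x1000, and reinterpret as float32.
    assert x.dtype == torch.float32
    return (x.view(torch.int32) - 0x1000).view(torch.float32)


# usage, mirroring the test:
# A = tf32_adjust(A)
# B = tf32_adjust(B)
```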
......@@ -124,27 +125,19 @@ def test_pipeline_order_stage():
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
def blocksparse_matmul(M,
N,
K,
block_M,
block_N,
block_K,
num_stages,
dtype="float16",
accum_dtype="float"):
},
)
def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, dtype="float16", accum_dtype="float"):
block_mask_shape = (M // block_M, N // block_N, K // block_K)
import tilelang.language as T
@T.prim_func
def block_sparse_matmul(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
......@@ -183,8 +176,7 @@ def run_blocksparse_matmul(num_stages):
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
kernel = blocksparse_matmul(
M, N, K, block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages)
kernel = blocksparse_matmul(M, N, K, block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages)
print(kernel.get_kernel_source())
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
......@@ -200,12 +192,10 @@ def run_blocksparse_matmul(num_stages):
accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
for k in range(K // block_K):
if BlockMask[i, j, k]:
accu += (
A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
torch.float32) @ B[k * block_K:(k + 1) * block_K,
j * block_N:(j + 1) * block_N].to(torch.float32))
ref_c[i * block_M:(i + 1) * block_M,
j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[
k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N
].to(torch.float32)
ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
return ref_c
# Compute the reference result using the naive PyTorch implementation
......
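
The reformatted accumulation loop above can also be read as a self-contained reference function; a sketch of the same computation (the function name is ours, the float16 output and float32 accumulation follow the test):

```python
import torch


def blocksparse_ref(A, B, block_mask, block_M, block_N, block_K):
    # Naive block-sparse matmul: accumulate A[i, k] @ B[k, j] tiles in float32
    # only where block_mask[i, j, k] is True, then store the tile as float16.
    M, K = A.shape
    _, N = B.shape
    ref_c = torch.zeros(M, N, dtype=torch.float16, device=A.device)
    for i in range(M // block_M):
        for j in range(N // block_N):
            accu = torch.zeros(block_M, block_N, dtype=torch.float32, device=A.device)
            for k in range(K // block_K):
                if block_mask[i, j, k]:
                    a_tile = A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32)
                    b_tile = B[k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N].to(torch.float32)
                    accu += a_tile @ b_tile
            ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
    return ref_c
```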
......@@ -7,7 +7,6 @@ from tilelang.utils import map_torch_type
def matmul_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
a_ptr: T.ptr,
......
......@@ -10,8 +10,8 @@ def _make_shared_reduce(M, N, dtype, reduce_cb):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
):
with T.Kernel(1) as _:
A_shared = T.alloc_shared((M, N), dtype)
......@@ -35,8 +35,8 @@ def reduce_max_test(M, N, dtype="float16"):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
):
with T.Kernel(1) as _:
A_local = T.alloc_fragment((M, N), dtype)
......@@ -54,8 +54,8 @@ def reduce_sum_test(M, N, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
):
with T.Kernel(1) as _:
A_local = T.alloc_fragment((M, N), dtype)
......@@ -145,8 +145,8 @@ def reduce_sum_test_clear(M, N, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
):
with T.Kernel(1, threads=32) as _:
A_local = T.alloc_fragment((M, N), dtype)
......@@ -186,8 +186,8 @@ def reduce_max_test_clear(M, N, dtype="float16"):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
):
with T.Kernel(1, threads=32) as _:
A_local = T.alloc_fragment((M, N), dtype)
......
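
The reference programs for these reduce tests are elided from the diff; as a hedged sketch, the row-wise reductions they exercise correspond to the following PyTorch one-liners (the helper name is ours):

```python
import torch


def rowwise_reduce_ref(A: torch.Tensor, op: str = "max") -> torch.Tensor:
    # Collapse an (M, N) tensor to an (M,) result along the second axis,
    # matching the B: T.Tensor((M,), dtype) outputs above.
    if op == "max":
        return A.max(dim=1).values
    if op == "sum":
        return A.sum(dim=1)
    raise ValueError(f"unsupported op: {op}")
```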
......@@ -10,8 +10,8 @@ def reshape_test(N, M, dtype):
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N // M, M), dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor((N // M, M), dtype),
):
with T.Kernel(1) as _:
A_reshaped = T.reshape(A, [N // M, M])
......@@ -30,7 +30,8 @@ def run_reshape(N, M, dtype):
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = jit_kernel.get_profiler()
def ref_program(A):
......@@ -50,8 +51,8 @@ def reshape_test_smem_1d_2_2d(N, M, dtype):
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N // M, M), dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor((N // M, M), dtype),
):
with T.Kernel(1) as _:
A_shared = T.alloc_shared((N,), dtype)
......@@ -74,7 +75,8 @@ def run_reshape_smem_1d_2_2d(N, M, dtype):
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = jit_kernel.get_profiler()
def ref_program(A):
......@@ -93,8 +95,8 @@ def reshape_test_smem_2d_2_1d(N, M, dtype):
@T.prim_func
def main(
A: T.Tensor((N // M, M), dtype),
B: T.Tensor((N,), dtype),
A: T.Tensor((N // M, M), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(1) as _:
A_shared = T.alloc_shared((N // M, M), dtype)
......@@ -117,7 +119,8 @@ def run_reshape_smem_2d_2_1d(N, M, dtype):
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = jit_kernel.get_profiler()
def ref_program(A):
......@@ -136,8 +139,8 @@ def reshape_fragment_test(N, M, dtype):
@T.prim_func
def main(
A: T.Tensor((N // M, M), dtype),
B: T.Tensor((N,), dtype),
A: T.Tensor((N // M, M), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(1, threads=32) as _:
A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
......@@ -161,7 +164,8 @@ def run_reshape_fragment(N, M, dtype):
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = jit_kernel.get_profiler()
def ref_program(A):
......@@ -181,15 +185,17 @@ def reshape_layout_transform_shared(N, M, dtype):
@T.prim_func
def main(
A: T.Tensor((N // M, M), dtype),
B: T.Tensor((N,), dtype),
A: T.Tensor((N // M, M), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(1, threads=32) as _:
A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
T.annotate_layout({
A_shared: make_mma_swizzle_layout(A_shared),
})
T.annotate_layout(
{
A_shared: make_mma_swizzle_layout(A_shared),
}
)
T.copy(A, A_shared)
A_shared_reshape = T.reshape(A_shared, [N])
T.copy(A_shared_reshape, B)
......@@ -205,7 +211,8 @@ def run_reshape_layout_transform_shared(N, M, dtype):
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = jit_kernel.get_profiler()
def ref_program(A):
......@@ -224,8 +231,8 @@ def reduce_after_reshape_test(N, M, dtype):
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N // M,), dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor((N // M,), dtype),
):
with T.Kernel(1, threads=32) as _:
A_shared = T.alloc_shared((N,), dtype, scope="shared")
......@@ -249,7 +256,8 @@ def run_reduce_after_reshape(N, M, dtype):
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = jit_kernel.get_profiler()
def ref_program(A):
......@@ -268,8 +276,8 @@ def reshape_shape_mismatch_test(N, M, dtype):
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N // M, M), dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor((N // M, M), dtype),
):
with T.Kernel(1) as _:
A_reshaped = T.reshape(A, [N // M, M + 1])
......
......@@ -4,19 +4,19 @@ import torch
import tilelang.testing
@tilelang.jit(out_idx=[1],)
@tilelang.jit(
out_idx=[1],
)
def tilelang_ternary(M, N, block_M, block_N, dtype="float16"):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = (
A[by * block_M + i, bx * block_N + j] if (by * block_M + i) < (M // 2) else 0)
B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j] if (by * block_M + i) < (M // 2) else 0
return main
......
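
As a sanity check on the ternary kernel above, the expected output can be written directly in PyTorch; a sketch under the assumption that the elided test compares against exactly this row-split behaviour:

```python
import torch


def ternary_ref(A: torch.Tensor) -> torch.Tensor:
    # Rows in the top half keep A's values, rows in the bottom half become 0,
    # matching the predicate (by * block_M + i) < (M // 2) in the kernel.
    M = A.shape[0]
    B = torch.zeros_like(A)
    B[: M // 2] = A[: M // 2]
    return B
```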
......@@ -9,10 +9,8 @@ def ref_program(x, y):
@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
@T.prim_func
def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor(
(M, N), out_dtype)):
def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), in_dtype)
B_shared = T.alloc_shared((block_M, block_N), in_dtype)
......@@ -21,7 +19,7 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(B[by * block_M, bx * block_N], B_shared)
for (local_y, local_x) in T.Parallel(block_M, block_N):
for local_y, local_x in T.Parallel(block_M, block_N):
C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
......
......@@ -4,7 +4,6 @@ from tilelang import language as T
def test_unroll_with_step():
@T.prim_func
def main(A_ptr: T.handle):
A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
......@@ -19,7 +18,6 @@ def test_unroll_with_step():
def test_unroll_with_unroll_factor():
@T.prim_func
def main(A_ptr: T.handle):
A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
......
......@@ -4,17 +4,15 @@ import tilelang.testing
def test_var_assign() -> None:
@tilelang.jit(out_idx=-1)
def jit_kernel():
@T.prim_func
def test_var_assign(A: T.Tensor((2,), 'int32')):
def test_var_assign(A: T.Tensor((2,), "int32")):
with T.Kernel(1) as _:
a = T.alloc_var('int32', init=1)
b = T.alloc_var('int32', init=a) # b gets value of a
a = T.alloc_var("int32", init=1)
b = T.alloc_var("int32", init=a) # b gets value of a
a = 2
d = T.alloc_var('int32', init=a) # d gets new value of a
d = T.alloc_var("int32", init=a) # d gets new value of a
A[0] = b
A[1] = d
......@@ -28,5 +26,5 @@ def test_var_assign() -> None:
assert res[1] == 2
if __name__ == '__main__':
if __name__ == "__main__":
tilelang.testing.main()
......@@ -5,11 +5,10 @@ import tilelang.language as T
@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
def vectorize_test(N, M, stride_A, stride_B):
@T.prim_func
def main(
A: T.StridedTensor[(N, M), (1, stride_A), "float32"], # noqa: F821
B: T.StridedTensor[(N, M), (1, stride_B), "float32"], # noqa: F821
A: T.StridedTensor[(N, M), (1, stride_A), "float32"], # noqa: F821
B: T.StridedTensor[(N, M), (1, stride_B), "float32"], # noqa: F821
):
with T.Kernel(M // 128, threads=128) as (bx):
tx = T.get_thread_binding(0)
......@@ -39,9 +38,7 @@ def run_vectorize(N, M, stride_A, stride_B):
code = jit_kernel.get_kernel_source()
vectorize_size = 1
while vectorize_size <= 2 and \
stride_A % (vectorize_size * 2) == 0 and \
stride_B % (vectorize_size * 2) == 0:
while vectorize_size <= 2 and stride_A % (vectorize_size * 2) == 0 and stride_B % (vectorize_size * 2) == 0:
vectorize_size *= 2
if vectorize_size == 4:
......@@ -61,12 +58,11 @@ def test_vectorize():
@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
def vectorize_test_invariant_index(N, M, K):
@T.prim_func
def main(
A: T.Tensor[(N, M), "float32"], # noqa: F821
B: T.Tensor[(N, M), "float32"], # noqa: F821
C: T.Tensor[(N, M // K), "float32"], # noqa: F821
A: T.Tensor[(N, M), "float32"], # noqa: F821
B: T.Tensor[(N, M), "float32"], # noqa: F821
C: T.Tensor[(N, M // K), "float32"], # noqa: F821
):
with T.Kernel(N // 128, threads=128) as (bx):
tx = T.get_thread_binding(0)
......
......@@ -17,8 +17,8 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
@T.prim_func
def main(
A: T.Tensor[(M,), dtype_A], # noqa: F821
B: T.Tensor[(M,), dtype_B], # noqa: F821
A: T.Tensor[(M,), dtype_A], # noqa: F821
B: T.Tensor[(M,), dtype_B], # noqa: F821
):
with T.Kernel(1, threads=128):
T.copy(A, B)
......@@ -32,8 +32,8 @@ def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
@T.prim_func
def main(
A: T.Tensor[(M,), dtype_A], # noqa: F821
B: T.Tensor[(M,), dtype_B], # noqa: F821
A: T.Tensor[(M,), dtype_A], # noqa: F821
B: T.Tensor[(M,), dtype_B], # noqa: F821
):
with T.Kernel(1, threads=128):
A_local = T.alloc_fragment((M,), dtype_A)
......@@ -73,8 +73,7 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
code = kernel.get_kernel_source()
code_parallel = kernel_parallel.get_kernel_source()
assert check_str in code and check_str in code_parallel, \
f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
assert check_str in code and check_str in code_parallel, f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
def test_vectorized_cast():
......
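
The collapsed assertion above relies on the f-string `=` specifier to embed both the variable name and its value; a tiny illustration (the values are made up):

```python
lanes = 4
src_dtype_str, dst_dtype_str = "float16", "float32"
msg = f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
print(msg)  # Cast float16 to float32 with lanes=4 is not vectorized!
```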
......@@ -10,6 +10,7 @@ def view_test(N, M, dtype, new_dtype=None):
new_shape = [N // M, M]
if new_dtype:
from tvm import DataType
dtype_src = DataType(dtype)
dtype_dst = DataType(new_dtype)
src_bits = dtype_src.bits
......@@ -19,8 +20,8 @@ def view_test(N, M, dtype, new_dtype=None):
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor(new_shape, new_dtype if new_dtype else dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor(new_shape, new_dtype if new_dtype else dtype),
):
with T.Kernel(1) as _:
A_viewed = T.view(A, new_shape, dtype=new_dtype)
......@@ -37,6 +38,7 @@ def run_view(N, M, dtype, new_dtype=None):
def ref_program(A):
if new_dtype:
from tilelang.utils.tensor import map_torch_type
torch_dtype = map_torch_type(new_dtype)
return A.view(N // M, M).view(dtype=torch_dtype)
return A.view(N // M, M)
......@@ -45,7 +47,6 @@ def run_view(N, M, dtype, new_dtype=None):
def test_reshape_view():
# Test view with same dtype
run_view(1024, 32, "float32")
run_view(2048, 64, "float16")
......@@ -61,6 +62,7 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
new_shape = [N // M, M + 1]
if new_dtype:
from tvm import DataType
dtype_src = DataType(dtype)
dtype_dst = DataType(new_dtype)
src_bits = dtype_src.bits
......@@ -70,8 +72,8 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor(new_shape, new_dtype if new_dtype else dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor(new_shape, new_dtype if new_dtype else dtype),
):
with T.Kernel(1) as _:
A_viewed = T.view(A, new_shape, dtype=new_dtype)
......
......@@ -7,7 +7,6 @@ import tilelang.language as T
@tilelang.jit
def get_kernel(reduce_op: str, dtype: str):
assert reduce_op in ["sum", "max", "min", "bitand", "bitor"]
@T.prim_func
......@@ -33,16 +32,16 @@ def get_kernel(reduce_op: str, dtype: str):
def test_warp_reduce_sum():
a = torch.randn((32,), dtype=torch.float32, device='cuda')
kernel = get_kernel('sum', 'float32')
a = torch.randn((32,), dtype=torch.float32, device="cuda")
kernel = get_kernel("sum", "float32")
ref = torch.full_like(a, a.sum())
kernel(a)
torch.testing.assert_close(a, ref)
def test_warp_reduce_max():
a = torch.randn((32,), dtype=torch.float32, device='cuda')
kernel = get_kernel("max", 'float32')
a = torch.randn((32,), dtype=torch.float32, device="cuda")
kernel = get_kernel("max", "float32")
print(kernel.get_kernel_source())
ref = torch.full_like(a, a.max())
kernel(a)
......@@ -50,16 +49,16 @@ def test_warp_reduce_max():
def test_warp_reduce_min():
a = torch.randn((32,), dtype=torch.float32, device='cuda')
kernel = get_kernel("min", 'float32')
a = torch.randn((32,), dtype=torch.float32, device="cuda")
kernel = get_kernel("min", "float32")
ref = torch.full_like(a, a.min())
kernel(a)
torch.testing.assert_close(a, ref)
def test_warp_reduce_bitand():
a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
kernel = get_kernel("bitand", 'int32')
a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda")
kernel = get_kernel("bitand", "int32")
ref_val = a[0]
for i in range(1, a.shape[0]):
ref_val = ref_val & a[i]
......@@ -69,8 +68,8 @@ def test_warp_reduce_bitand():
def test_warp_reduce_bitor():
a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
kernel = get_kernel("bitor", 'int32')
a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda")
kernel = get_kernel("bitor", "int32")
ref_val = a[0]
for i in range(1, a.shape[0]):
ref_val = ref_val | a[i]
......
......@@ -12,17 +12,16 @@ VEC_SIZE = 32
@tilelang.jit
def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):
@T.prim_func
def main(
a: T.Buffer((B, M, N), "bfloat16"),
a_out: T.Buffer((B, M, N), "float32"),
a: T.Buffer((B, M, N), "bfloat16"),
a_out: T.Buffer((B, M, N), "float32"),
):
with T.Kernel(
T.ceildiv(M, BLOCK_MN),
T.ceildiv(N, BLOCK_K),
B,
threads=128,
T.ceildiv(M, BLOCK_MN),
T.ceildiv(N, BLOCK_K),
B,
threads=128,
) as (pid_m, pid_n, pid_b):
a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32")
offs_m = pid_m * BLOCK_MN
......
......@@ -19,12 +19,11 @@ def bitwise_reduce(
func,
clear=True,
):
@T.prim_func
def reduce_func(
A: T.Tensor((M, N), "int32"),
B: T.Tensor((M), "int32"),
Output: T.Tensor((M), "int32"),
A: T.Tensor((M, N), "int32"),
B: T.Tensor((M), "int32"),
Output: T.Tensor((M), "int32"),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), "int32")
......@@ -64,7 +63,7 @@ def run_single_bitwise_reduce(
row_pattern = (i & 0xF) << (i % 4) # 4-bit patterns shifted by row
# Column-based pattern: different bit positions set based on column
col_pattern = (1 << (j % 31)) # Single bit set at different positions
col_pattern = 1 << (j % 31) # Single bit set at different positions
# Combine patterns with XOR to create diverse bit distributions
# Add some deterministic "noise" based on position
......@@ -76,7 +75,7 @@ def run_single_bitwise_reduce(
if i % 4 == 0:
a[i, j] &= ~(0x1 << (i // 4))
elif i % 2 == 0:
a[i, j] |= (0x1 << (i // 2))
a[i, j] |= 0x1 << (i // 2)
if name == "reduce_bitand":
expected = torch.full((M,), -1, device="cuda", dtype=torch.int32)
......
......@@ -7,16 +7,16 @@ import re
def get_mathop_lines(source, mathop_name):
"""Extract lines containing the mathop from CUDA source for debugging"""
lines = source.split('\n')
lines = source.split("\n")
relevant_lines = []
for i, line in enumerate(lines):
if mathop_name in line and ('(' in line):
if mathop_name in line and ("(" in line):
# Include some context
start = max(0, i - 1)
end = min(len(lines), i + 2)
relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)])
relevant_lines.append("---")
return '\n'.join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output
return "\n".join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output
def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
......@@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
fastmath_matches = re.findall(fastmath_pattern, source)
non_fastmath_matches = re.findall(non_fastmath_pattern, source)
print(
f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls"
)
print(f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls")
if len(fastmath_matches) > 0:
print(f"Fastmath calls found: {fastmath_matches}")
if len(non_fastmath_matches) > 0:
......@@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name):
check_fastmath_usage(source, mathop_name, expect_fastmath=False)
def run_single_arg_mathop_test(mathop_name,
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
"""
Test single-argument mathops.
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
......@@ -65,13 +57,12 @@ def run_single_arg_mathop_test(mathop_name,
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
bx * block_N + j])
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j])
# Test with FAST_MATH disabled
kernel_no_fastmath = tilelang.compile(
......@@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
},
)
source_no_fastmath = kernel_no_fastmath.get_kernel_source()
......@@ -93,28 +85,22 @@ def run_single_arg_mathop_test(mathop_name,
print(f"✓ {mathop_name} compilation and execution test passed")
def run_two_arg_mathop_test(mathop_name,
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
"""
Test two-argument mathops to ensure they generate non-fastmath CUDA code.
"""
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i,
bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
B[by * block_M + i, bx * block_N + j])
C[by * block_M + i, bx * block_N + j] = mathop_func(
A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j]
)
# Test with FAST_MATH disabled
kernel_no_fastmath = tilelang.compile(
......@@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
},
)
# Test with FAST_MATH enabled
kernel_fastmath = tilelang.compile(
......@@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
source_no_fastmath = kernel_no_fastmath.get_kernel_source()
source_fastmath = kernel_fastmath.get_kernel_source()
......@@ -171,8 +159,8 @@ def run_abs_test():
@T.prim_func
def main(
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
......@@ -184,7 +172,8 @@ def run_abs_test():
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
},
)
source = kernel.get_kernel_source()
print("\n=== Testing abs (maps to fabs) ===")
......@@ -199,26 +188,19 @@ def run_abs_test():
print("✓ abs numerical test passed")
def run_fastmath_mathop_test(mathop_name,
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
"""
Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
"""
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
bx * block_N + j])
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j])
# Test with FAST_MATH enabled
kernel_fastmath = tilelang.compile(
......@@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
source_fastmath = kernel_fastmath.get_kernel_source()
print(f"\n=== Testing {mathop_name} (fastmath version) ===")
print("FAST_MATH=True:")
# Strip the __ prefix for checking in the CUDA source
cuda_mathop_name = mathop_name.lstrip('_')
cuda_mathop_name = mathop_name.lstrip("_")
check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True)
# Test numerical correctness
......
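
`check_fastmath_usage` builds regexes whose exact patterns are elided from this diff; a hypothetical reconstruction of the idea (the patterns below are assumptions, not the repository's code) counts fastmath versus plain calls in the generated CUDA source like this:

```python
import re


def count_math_calls(source: str, mathop_name: str) -> tuple[int, int]:
    # Hypothetical sketch: "__expf(" counts as a fastmath call, a bare "expf("
    # (not preceded by a word character) counts as a non-fastmath call.
    fastmath = len(re.findall(rf"__{mathop_name}f?\(", source))
    plain = len(re.findall(rf"(?<![\w_]){mathop_name}f?\(", source))
    return fastmath, plain


# count_math_calls(kernel.get_kernel_source(), "exp") -> (n_fastmath, n_plain)
```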
......@@ -5,14 +5,7 @@ import tilelang.testing
import pytest
def run_ieee_math_test(mathop_name,
mathop_func,
rounding_mode="rn",
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=128, block_M=32, block_N=32, dtype="float32"):
"""
Test IEEE-compliant math operations with specified rounding modes.
"""
......@@ -22,18 +15,19 @@ def run_ieee_math_test(mathop_name,
@T.prim_func
def main_func(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
C: T.Tensor((M, N), dtype),
D: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
C: T.Tensor((M, N), dtype),
D: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
D[by * block_M + i,
bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
B[by * block_M + i, bx * block_N + j],
C[by * block_M + i,
bx * block_N + j], rounding_mode)
D[by * block_M + i, bx * block_N + j] = mathop_func(
A[by * block_M + i, bx * block_N + j],
B[by * block_M + i, bx * block_N + j],
C[by * block_M + i, bx * block_N + j],
rounding_mode,
)
out_idx = [3]
num_inputs = 3
......@@ -41,16 +35,15 @@ def run_ieee_math_test(mathop_name,
@T.prim_func
def main_func(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i,
bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
B[by * block_M + i,
bx * block_N + j], rounding_mode)
C[by * block_M + i, bx * block_N + j] = mathop_func(
A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j], rounding_mode
)
out_idx = [2]
num_inputs = 2
......@@ -58,14 +51,12 @@ def run_ieee_math_test(mathop_name,
@T.prim_func
def main_func(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i,
bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
rounding_mode)
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], rounding_mode)
out_idx = [1]
num_inputs = 1
......@@ -77,7 +68,8 @@ def run_ieee_math_test(mathop_name,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
},
)
print(f"\n=== Testing {mathop_name} with rounding mode {rounding_mode} ===")
print(f"✓ {mathop_name} compilation test passed")
......@@ -194,8 +186,8 @@ def test_ieee_frsqrt_rn_only():
@T.prim_func
def main(
A: T.Tensor((128, 128), "float32"),
B: T.Tensor((128, 128), "float32"),
A: T.Tensor((128, 128), "float32"),
B: T.Tensor((128, 128), "float32"),
):
with T.Kernel(T.ceildiv(128, 32), T.ceildiv(128, 32), threads=128) as (bx, by):
for i, j in T.Parallel(32, 32):
......@@ -207,7 +199,8 @@ def test_ieee_frsqrt_rn_only():
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
},
)
print("\n=== Testing ieee_frsqrt (rn only) ===")
print("✓ ieee_frsqrt compilation test passed")
......
......@@ -5,18 +5,17 @@ import tilelang.language as T
import torch
@tilelang.jit(execution_backend='torch')
@tilelang.jit(execution_backend="torch")
def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"):
@T.prim_func
def gemm(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype, scope='shared')
B_shared = T.alloc_shared((block_K, block_N), dtype, scope='shared')
A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
......@@ -48,13 +47,13 @@ def assert_gemm(
torch_dtype = getattr(torch, dtype)
a, b = None, None
if 'int' in dtype:
a = torch.randint(100, (M, K), dtype=torch_dtype, device='mps')
b = torch.randint(100, (K, N), dtype=torch_dtype, device='mps')
if "int" in dtype:
a = torch.randint(100, (M, K), dtype=torch_dtype, device="mps")
b = torch.randint(100, (K, N), dtype=torch_dtype, device="mps")
else:
a = torch.randn(M, K, dtype=torch_dtype, device='mps')
b = torch.randn(K, N, dtype=torch_dtype, device='mps')
c = torch.zeros(M, N, dtype=torch_dtype, device='mps')
a = torch.randn(M, K, dtype=torch_dtype, device="mps")
b = torch.randn(K, N, dtype=torch_dtype, device="mps")
c = torch.zeros(M, N, dtype=torch_dtype, device="mps")
jit_kernel(a, b, c)
......@@ -70,12 +69,12 @@ def test_gemm_float32():
@tilelang.testing.requires_metal
def test_gemm_float16():
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='float16', atol=1)
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="float16", atol=1)
@tilelang.testing.requires_metal
def test_gemm_int32():
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='int32', atol=1)
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="int32", atol=1)
if __name__ == "__main__":
......
......@@ -27,9 +27,9 @@ def matmul_ssr(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
......@@ -88,7 +88,8 @@ def run_matmul_ssr(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
......@@ -106,24 +107,9 @@ def run_matmul_ssr(
def test_gemm_f16f16f16_nt_ssr():
run_matmul_ssr(
16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32)
run_matmul_ssr(
128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64)
run_matmul_ssr(
1024,
1024,
1024,
False,
True,
"float16",
"float16",
"float16",
128,
128,
32,
2,
num_threads=128)
run_matmul_ssr(16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32)
run_matmul_ssr(128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64)
run_matmul_ssr(1024, 1024, 1024, False, True, "float16", "float16", "float16", 128, 128, 32, 2, num_threads=128)
def matmul_rsr(
......@@ -151,9 +137,9 @@ def matmul_rsr(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
......@@ -214,7 +200,8 @@ def run_matmul_rsr(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
......@@ -276,9 +263,9 @@ def matmul_rrr(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -342,7 +329,8 @@ def run_matmul_rrr(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
......
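
The `ref_program` bodies in these GEMM variants are elided; a hedged sketch of what such a reference typically looks like for the transpose-aware cases (the helper name, float32 accumulation, and float16 default output are our assumptions):

```python
import torch


def gemm_ref(A, B, trans_A=False, trans_B=False, out_dtype=torch.float16):
    # Plain-matmul reference with optional operand transposes, accumulating in
    # float32 and casting the result to the requested output dtype.
    a = A.T if trans_A else A
    b = B.T if trans_B else B
    return (a.to(torch.float32) @ b.to(torch.float32)).to(out_dtype)
```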