Unverified commit 29051439, authored by Lei Wang and committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
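
The hunks below all reflect the same set of ruff format conventions: string literals are normalized to double quotes, a "magic" trailing comma forces an argument list onto one line per parameter, wrapped calls and dict literals without a trailing comma are collapsed back onto a single line when they fit, and slices with non-trivial bounds gain spaces around the colon. A minimal, behavior-neutral sketch of that before/after (the function and variable names here are illustrative and do not appear in this diff):

# Before (yapf-era layout): single quotes, hanging-indent wrapping, tight slice colons.
def tail_sum_yapf(xs, block, i):
    label = 'tail'
    window = xs[i * block:(i + 1) * block]
    return {
        'label': label,
        'total': sum(window)
    }


# After (ruff format): double quotes, the dict collapses onto one line because it has
# no trailing comma and fits, and the complex slice bounds get spaces around ":".
def tail_sum_ruff(xs, block, i):
    label = "tail"
    window = xs[i * block : (i + 1) * block]
    return {"label": label, "total": sum(window)}


# A trailing comma in a signature is treated as "magic": ruff keeps one parameter per
# line, which is why one-argument kernels like `def kernel(A: T.Tensor(...),):` expand
# into multi-line signatures in the hunks below.
def kernel_ruff(
    A,
):
    return A

Each pair differs only in layout and quoting, never in behavior, which is why the tests in this diff change formatting without touching logic.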
......@@ -5,13 +5,14 @@ import torch
def test_tensor_annot_mul():
@tilelang.jit
def example_tensor_annot():
n = T.symbolic('n')
n = T.symbolic("n")
@T.prim_func
def kernel(A: T.Tensor((n * 4,), T.int32),):
def kernel(
A: T.Tensor((n * 4,), T.int32),
):
with T.Kernel(1) as _:
for i in range(n * 4):
A[i] = 0
......@@ -19,20 +20,21 @@ def test_tensor_annot_mul():
return kernel
ker = example_tensor_annot()
A = torch.arange(16, dtype=torch.int32, device='cuda')
A = torch.arange(16, dtype=torch.int32, device="cuda")
ker(A)
expected = torch.zeros(16, dtype=torch.int32, device='cuda')
expected = torch.zeros(16, dtype=torch.int32, device="cuda")
assert torch.equal(A, expected)
def test_tensor_annot_add():
@tilelang.jit
def example_tensor_annot():
n = T.symbolic('n')
n = T.symbolic("n")
@T.prim_func
def kernel(A: T.Tensor((n + 1,), T.int32),):
def kernel(
A: T.Tensor((n + 1,), T.int32),
):
with T.Kernel(1) as _:
for i in range(n + 1):
A[i] = 0
......@@ -40,20 +42,21 @@ def test_tensor_annot_add():
return kernel
ker = example_tensor_annot()
A = torch.arange(16, dtype=torch.int32, device='cuda')
A = torch.arange(16, dtype=torch.int32, device="cuda")
ker(A)
expected = torch.zeros(16, dtype=torch.int32, device='cuda')
expected = torch.zeros(16, dtype=torch.int32, device="cuda")
assert torch.equal(A, expected)
def test_tensor_annot_mul_add():
@tilelang.jit
def example_tensor_annot():
n = T.symbolic('n')
n = T.symbolic("n")
@T.prim_func
def kernel(A: T.Tensor((n * 3 + 1,), T.int32),):
def kernel(
A: T.Tensor((n * 3 + 1,), T.int32),
):
with T.Kernel(1) as _:
for i in range(n * 3 + 1):
A[i] = 0
......@@ -61,11 +64,11 @@ def test_tensor_annot_mul_add():
return kernel
ker = example_tensor_annot()
A = torch.arange(16, dtype=torch.int32, device='cuda')
A = torch.arange(16, dtype=torch.int32, device="cuda")
ker(A)
expected = torch.zeros(16, dtype=torch.int32, device='cuda')
expected = torch.zeros(16, dtype=torch.int32, device="cuda")
assert torch.equal(A, expected)
if __name__ == '__main__':
if __name__ == "__main__":
tilelang.testing.main()
......@@ -7,11 +7,10 @@ import torch
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
......@@ -30,13 +29,8 @@ def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16", pad_value=0):
program = tilelang_copy(M, N, block_M, block_N, dtype, pad_value=pad_value)
kernel = tilelang.compile(
program,
out_idx=[1],
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
})
program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
)
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
ref_b = torch.zeros_like(a)
......
......@@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
for k in range(K // block_K):
if torch.any(BlockMask[i, j, k]):
accu += A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
torch.float32) @ B[k * block_K:(k + 1) * block_K,
j * block_N:(j + 1) * block_N].to(torch.float32)
ref_c[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = (
accu.to(torch.float16))
accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[
k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N
].to(torch.float32)
ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
return ref_c
......@@ -35,15 +34,14 @@ def blocksparse_matmul_global(
dtype="float16",
accum_dtype="float",
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
......@@ -80,15 +78,14 @@ def blocksparse_matmul_shared(
dtype="float16",
accum_dtype="float",
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
......@@ -130,15 +127,14 @@ def blocksparse_matmul_local(
dtype="float16",
accum_dtype="float",
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
......@@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
......@@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
......
......@@ -4,10 +4,9 @@ import tilelang.testing
def test_assume_remove_boundary_check():
@tilelang.jit
def kernel_with_assume():
N = T.dynamic('N')
N = T.dynamic("N")
@T.prim_func
def main(A: T.Tensor((N,), "float32"), l: T.int32, r: T.int32):
......@@ -21,20 +20,19 @@ def test_assume_remove_boundary_check():
jit_kernel = kernel_with_assume()
source = jit_kernel.get_kernel_source()
assert ("if (" not in source)
assert "if (" not in source
def test_assume_enable_vectorization():
@tilelang.jit
def kernel_vectorize(M):
N = T.dynamic('N')
N = T.dynamic("N")
vectorize_size = 4
@T.prim_func
def main(
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
):
with T.Kernel(1, threads=32) as _:
tid = T.get_thread_binding()
......@@ -55,16 +53,15 @@ def test_assume_enable_vectorization():
def test_assume_complex_indexing():
@tilelang.jit
def kernel_complex():
M = T.dynamic('M')
N = T.dynamic('N')
M = T.dynamic("M")
N = T.dynamic("N")
@T.prim_func
def main(
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
):
with T.Kernel(1, threads=32) as _:
tid = T.get_thread_binding()
......@@ -82,8 +79,8 @@ def test_assume_complex_indexing():
jit_kernel = kernel_complex()
source = jit_kernel.get_kernel_source()
assert ("if (" not in source)
assert "if (" not in source
if __name__ == '__main__':
if __name__ == "__main__":
tilelang.testing.main()
......@@ -4,14 +4,12 @@ import tilelang.language as T
@tilelang.jit
def atomic_add_program(K, M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
A_shared)
T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
for i, j in T.Parallel(block_M, block_N):
T.atomic_add(B[bx * block_M + i, by * block_N + j], A_shared[i, j])
......@@ -39,14 +37,12 @@ def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
@tilelang.jit
def tile_atomic_add_program(K, M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
A_shared)
T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
T.atomic_add(B[bx * block_M, by * block_N], A_shared)
......@@ -76,14 +72,12 @@ def run_tile_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
@tilelang.jit
def atomic_max_program(K, M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_max(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
A_shared)
T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
for i, j in T.Parallel(block_M, block_N):
T.atomic_max(B[bx * block_M + i, by * block_N + j], A_shared[i, j])
......@@ -111,14 +105,12 @@ def run_atomic_max(K, M, N, block_M, block_N, dtype="float32"):
@tilelang.jit
def atomic_min_program(K, M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_min(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
A_shared)
T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
for i, j in T.Parallel(block_M, block_N):
T.atomic_min(B[bx * block_M + i, by * block_N + j], A_shared[i, j])
......@@ -137,7 +129,7 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
B[i, j] = min(B[i, j], A[k, i, j])
A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda()
B = torch.full((M, N), float('inf'), dtype=getattr(torch, dtype)).cuda()
B = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda()
ref_B = B.clone()
ref_program(A, ref_B)
kernel(A, B)
......@@ -146,7 +138,6 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
@tilelang.jit
def atomic_load_store_program(M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_load_store(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
......@@ -172,18 +163,15 @@ def run_atomic_load_store(M, N, block_M, block_N, dtype="float32"):
@tilelang.jit
def atomic_memory_order_program(K, M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_with_memory_order(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
A_shared)
T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
for i, j in T.Parallel(block_M, block_N):
T.atomic_add(
B[bx * block_M + i, by * block_N + j], A_shared[i, j], memory_order="relaxed")
T.atomic_add(B[bx * block_M + i, by * block_N + j], A_shared[i, j], memory_order="relaxed")
return atomic_with_memory_order
......@@ -208,7 +196,6 @@ def run_atomic_memory_order(K, M, N, block_M, block_N, dtype="float32"):
@tilelang.jit
def atomic_addx2_program(M, N, block_M, block_N):
@T.prim_func
def atomic_addx2(A: T.Tensor((M, N), "float16"), B: T.Tensor((M, N), "float16")):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
......@@ -262,10 +249,10 @@ def test_atomic_addx2():
@tilelang.jit
def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_different_orders(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor(
(M, N), dtype), D: T.Tensor((M, N), dtype)):
def atomic_different_orders(
A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype), D: T.Tensor((M, N), dtype)
):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
idx_i = bx * block_M + i
......@@ -286,18 +273,17 @@ def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"):
A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda()
B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
C = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
D = torch.full((M, N), float('inf'), dtype=getattr(torch, dtype)).cuda()
D = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda()
kernel(A, B, C, D)
torch.testing.assert_close(B, A, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(C, torch.maximum(torch.zeros_like(A), A))
torch.testing.assert_close(D, torch.minimum(torch.full_like(A, float('inf')), A))
torch.testing.assert_close(D, torch.minimum(torch.full_like(A, float("inf")), A))
@tilelang.jit
def atomic_addx4_program(M, N, block_M, block_N):
@T.prim_func
def atomic_addx4(A: T.Tensor((M, N), "float32"), B: T.Tensor((M, N), "float32")):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
......@@ -330,17 +316,14 @@ def run_atomic_addx4(M, N, block_M, block_N):
@tilelang.jit
def atomic_return_prev_program(M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
old_vals: T.Tensor((M, N), dtype)):
def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), old_vals: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
idx_i = bx * block_M + i
idx_j = by * block_N + j
if idx_i < M and idx_j < N:
old_vals[idx_i, idx_j] = T.atomic_add(
B[idx_i, idx_j], A[idx_i, idx_j], return_prev=True)
old_vals[idx_i, idx_j] = T.atomic_add(B[idx_i, idx_j], A[idx_i, idx_j], return_prev=True)
return atomic_with_return_prev
......
......@@ -5,7 +5,6 @@ import torch
@tilelang.jit(out_idx=[-1])
def _ceildiv_kernel(a: int, b: int):
@T.prim_func
def ceildiv_kernel(A: T.Tensor((1,), "int32")):
with T.Kernel(1, threads=1) as _:
......@@ -30,7 +29,6 @@ def test_ceildiv():
@tilelang.jit
def _ceildiv_kernel_dyn(b: int):
@T.prim_func
def ceildiv_kernel(A: T.Tensor((1,), "int32"), a: T.int32):
with T.Kernel(1, threads=1) as _:
......
......@@ -8,14 +8,14 @@ import torch
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},)
},
)
def chain_equal(N, block_size, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
C: T.Tensor((N,), dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
C: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as bx:
for lane in T.Parallel(block_size):
......
......@@ -13,8 +13,8 @@ def clamp_within_bounds(
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
A_shared = T.alloc_shared([block_N], dtype)
......@@ -56,8 +56,8 @@ def clamp_value_range(
@T.prim_func
def main(
A: T.Tensor((1, N), dtype),
B: T.Tensor((1, N), dtype),
A: T.Tensor((1, N), dtype),
B: T.Tensor((1, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
# A_shared = T.alloc_shared([1, block_N], dtype=dtype)
......
......@@ -5,12 +5,11 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
......@@ -42,10 +41,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
program = matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
kernel = tilelang.compile(
program, out_idx=[2], target="cuda", pass_configs={"tl.disable_tma_lower": True})
kernel = tilelang.compile(program, out_idx=[2], target="cuda", pass_configs={"tl.disable_tma_lower": True})
import torch
from tilelang.utils import map_torch_type
a = torch.randn((M, K), dtype=map_torch_type(dtype)).cuda()
b = torch.randn((N, K), dtype=map_torch_type(dtype)).cuda()
c = kernel(a, b)
......
......@@ -7,11 +7,10 @@ import torch
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_composable_copy(M, N, block_M, block_N, dtype="float16"):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M * N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M * N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
......@@ -35,7 +34,8 @@ def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
torch.testing.assert_close(b.flatten(), a.flatten(), rtol=1e-2, atol=1e-2)
......
......@@ -7,11 +7,10 @@ import tilelang.testing
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_copy(M, N, block_M, block_N, dtype="float16"):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
......@@ -27,10 +26,8 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16")
program,
out_idx=[1],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
})
pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True},
)
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
......@@ -43,11 +40,10 @@ def test_tilelang_copy():
def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"):
@T.prim_func
def main(
A: T.StridedTensor((M, N), (NN, 1), dtype),
B: T.Tensor((M, N), dtype),
A: T.StridedTensor((M, N), (NN, 1), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
......@@ -57,12 +53,7 @@ def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"):
return main
def run_tilelang_copy_with_stride(M=1024,
N=1024,
NN=2048,
block_M=128,
block_N=128,
dtype="float16"):
def run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128, dtype="float16"):
if isinstance(NN, int):
assert NN > N, "NN must be greater than N"
program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype)
......@@ -73,7 +64,8 @@ def run_tilelang_copy_with_stride(M=1024,
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
})
},
)
if isinstance(NN, T.Var):
NN = N * 2
a = torch.randn(M, NN, device="cuda", dtype=getattr(torch, dtype))
......@@ -87,11 +79,10 @@ def test_tilelang_copy_with_stride():
def tilelang_copy_bufferload(num_tokens, dtype="float16"):
@T.prim_func
def main(
indices: T.Tensor((num_tokens,), "int32"),
x: T.Tensor((num_tokens,), dtype),
indices: T.Tensor((num_tokens,), "int32"),
x: T.Tensor((num_tokens,), dtype),
):
with T.Kernel(num_tokens, threads=32) as pid:
idx = T.alloc_local([1], "int32")
......@@ -107,10 +98,8 @@ def run_tilelang_copy_bufferload(num_tokens=128, dtype="float16"):
tilelang.compile(
program,
out_idx=[1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
})
pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True},
)
def test_tilelang_copy_bufferload():
......@@ -118,11 +107,10 @@ def test_tilelang_copy_bufferload():
def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float16"):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
......@@ -132,20 +120,14 @@ def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float
return main
def run_tilelang_copy_buffer_load_with_parallel(M=1024,
N=1024,
block_M=128,
block_N=128,
dtype="float16"):
def run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
program = tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype)
kernel = tilelang.compile(
program,
out_idx=[1],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
})
pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True},
)
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
......
......@@ -9,8 +9,8 @@ def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float3
@T.prim_func
def cumsum(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
......@@ -28,8 +28,8 @@ def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="fl
@T.prim_func
def cumsum(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
......@@ -57,13 +57,16 @@ def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32", sc
ref_b = torch.empty_like(A)
for i in range(M // block_M):
for j in range(N // block_N):
ref_b[i * block_M:(i + 1) * block_M,
j * block_N:(j + 1) * block_N] = A[i * block_M:(i + 1) * block_M, j *
block_N:(j + 1) * block_N].cumsum(dim=dim)
ref_b[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = A[
i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N
].cumsum(dim=dim)
if reverse:
ref_b[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) *
block_N] = A[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) *
block_N].flip(dims=[dim]).cumsum(dim=dim).flip(dims=[dim])
ref_b[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = (
A[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N]
.flip(dims=[dim])
.cumsum(dim=dim)
.flip(dims=[dim])
)
return ref_b
tilelang_res = jit_kernel(A)
......@@ -76,8 +79,8 @@ def cumsum_smem_test_1d(N, block_N, reverse=False, dtype="float32"):
@T.prim_func
def cumsum(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
A_shared = T.alloc_shared((block_N,), dtype)
......@@ -94,8 +97,8 @@ def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype="float32"):
@T.prim_func
def cumsum(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
A_shared = T.alloc_shared((block_N,), dtype)
......
......@@ -8,7 +8,6 @@ from tvm.tir.expr import IntImm, Var
def test_argument():
@T.prim_func
def test_argument(
t_1: T.bool,
......@@ -41,6 +40,7 @@ def test_argument():
def test_expr():
from tilelang.language.v2.dtypes import _all_dtypes
errors = []
for name in _all_dtypes:
dtype = getattr(T, name)
......@@ -116,33 +116,32 @@ def test_expr():
def test_dtype_str_repr():
@T.prim_func
def test_str_repr():
buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared') # noqa F841
buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared') # noqa F841
buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared') # noqa F841
buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841
buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared') # noqa F841
buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared') # noqa F841
buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared') # noqa F841
buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared') # noqa F841
buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared') # noqa F841
buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared') # noqa F841
buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared') # noqa F841
buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared') # noqa F841
buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared') # noqa F841
buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared') # noqa F841
buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared') # noqa F841
buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared') # noqa F841
buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared') # noqa F841
buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared') # noqa F841
buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared') # noqa F841
buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared') # noqa F841
buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared') # noqa F841
buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared') # noqa F841
buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared') # noqa F841
buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') # noqa F841
buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope="shared") # noqa F841
buf_2 = T.alloc_buffer((1,), dtype=T.short, scope="shared") # noqa F841
buf_3 = T.alloc_buffer((1,), dtype=T.int, scope="shared") # noqa F841
buf_4 = T.alloc_buffer((1,), dtype=T.long, scope="shared") # noqa F841
buf_5 = T.alloc_buffer((1,), dtype=T.half, scope="shared") # noqa F841
buf_6 = T.alloc_buffer((1,), dtype=T.float, scope="shared") # noqa F841
buf_7 = T.alloc_buffer((1,), dtype=T.long, scope="shared") # noqa F841
buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope="shared") # noqa F841
buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope="shared") # noqa F841
buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope="shared") # noqa F841
buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope="shared") # noqa F841
buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope="shared") # noqa F841
buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope="shared") # noqa F841
buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope="shared") # noqa F841
buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope="shared") # noqa F841
buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope="shared") # noqa F841
buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope="shared") # noqa F841
buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope="shared") # noqa F841
buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope="shared") # noqa F841
buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope="shared") # noqa F841
buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope="shared") # noqa F841
buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope="shared") # noqa F841
buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope="shared") # noqa F841
buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope="shared") # noqa F841
# not supported now
......@@ -205,7 +204,6 @@ def test_dtype_str_repr():
def test_var_assign():
@tilelang.jit(out_idx=-1)
@T.prim_func
def test_var_assign(A: T.Tensor((2,), T.int32)):
......@@ -223,7 +221,6 @@ def test_var_assign():
def test_marco_return():
@T.macro
def macro_return_constant():
return 0
......@@ -258,11 +255,10 @@ def test_marco_return():
def test_prim_func_generator():
@T.prim_func(generator=True)
def prim_func_gen(
A=T.Tensor((128,), T.float32), # noqa: B008
B=T.Tensor((128,), T.float32), # noqa: B008
A=T.Tensor((128,), T.float32), # noqa: B008
B=T.Tensor((128,), T.float32), # noqa: B008
):
with T.Kernel(128) as (tx,):
T.copy(A[tx], B[tx])
......@@ -277,7 +273,6 @@ def test_prim_func_generator():
def test_serial_for_with_step():
@tilelang.jit(out_idx=-1)
@T.prim_func
def test_stepped_serial(A: T.Tensor((10,), T.int32)):
......@@ -291,7 +286,7 @@ def test_serial_for_with_step():
ker = test_stepped_serial()
res = ker()
ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device='cuda')
ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device="cuda")
assert torch.all(res == ref), f"Expected {ref}, but got {res}"
@tilelang.jit(out_idx=-1)
......@@ -304,17 +299,16 @@ def test_serial_for_with_step():
ker = test_serial_step_neg()
res = ker()
ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device='cuda')
ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device="cuda")
assert torch.all(res == ref), f"Expected {ref}, but got {res}"
assert isinstance(T.serial(1, 10, 1), IRBuilderFrame)
assert isinstance(T.serial(1, 10, IntImm('int32', 1)), IRBuilderFrame)
assert not isinstance(T.serial(1, 10, Var('tmp', 'int32')), IRBuilderFrame)
assert isinstance(T.serial(1, 10, IntImm("int32", 1)), IRBuilderFrame)
assert not isinstance(T.serial(1, 10, Var("tmp", "int32")), IRBuilderFrame)
assert not isinstance(T.serial(10, -1, -1), IRBuilderFrame)
def test_swap_logic():
@tilelang.jit
@T.prim_func
def swap_var(A: T.Tensor[(2,), T.float32]):
......@@ -344,7 +338,6 @@ def test_swap_logic():
def test_while_loop():
@tilelang.jit(out_idx=-1)
@T.prim_func
def test_while_loop(A: T.Tensor((1,), T.int32)):
......@@ -374,7 +367,7 @@ def test_var_macro():
x = T.alloc_var(T.int32)
macro_with_var(x)
assert 'x[0] = 1' in prim_call_macro.script()
assert "x[0] = 1" in prim_call_macro.script()
finally:
pass
......@@ -406,7 +399,7 @@ def test_var_macro():
x = T.alloc_var(T.int32)
macro_with_var(x)
assert 'x[0] = 1' in prim_call_macro.script()
assert "x[0] = 1" in prim_call_macro.script()
finally:
pass
......@@ -428,10 +421,8 @@ def test_var_macro():
def test_frame_inside_macro():
@tilelang.jit
def get_sample_kernel():
@T.macro
def transform(x):
return x + 1
......@@ -442,7 +433,7 @@ def test_frame_inside_macro():
idx_out: T.Tensor[(32,), T.int32],
):
with T.Kernel(num_blocks, threads=32) as block_idx: # noqa: F841
fragment = T.alloc_fragment(32, 'int32')
fragment = T.alloc_fragment(32, "int32")
T.copy(idx_out, fragment)
for i in T.Parallel(32):
......@@ -467,10 +458,10 @@ def test_buffer_slice_step():
def test_boolop():
a = Var('a', 'int32')
b = Var('b', 'int32')
c = Var('c', 'int32')
d = Var('d', 'int32')
a = Var("a", "int32")
b = Var("b", "int32")
c = Var("c", "int32")
d = Var("d", "int32")
@T.macro
def cond():
......@@ -479,5 +470,5 @@ def test_boolop():
cond()
if __name__ == '__main__':
if __name__ == "__main__":
tilelang.testing.main()
......@@ -23,7 +23,6 @@ def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int:
@tilelang.jit(out_idx=[-1])
def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@T.prim_func
def laneid_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
......@@ -35,7 +34,6 @@ def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@tilelang.jit(out_idx=[-1])
def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@T.prim_func
def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
......@@ -47,7 +45,6 @@ def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] =
@tilelang.jit(out_idx=[-1])
def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@T.prim_func
def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
......@@ -63,7 +60,6 @@ def _get_warp_group_idx_kernel(
warp_size: Optional[int] = None,
warps_per_group: Optional[int] = None,
):
@T.prim_func
def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
......@@ -75,7 +71,6 @@ def _get_warp_group_idx_kernel(
@tilelang.jit(out_idx=[-1])
def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64):
@T.prim_func
def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
......
......@@ -4,13 +4,14 @@ import torch
import tilelang.testing
@tilelang.jit(out_idx=[1],)
@tilelang.jit(
out_idx=[1],
)
def tilelang_if_range(M, N, block_M, block_N, dtype="float16"):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
......
......@@ -5,7 +5,6 @@ import tilelang.language as T
@tilelang.jit(out_idx=-1)
def get_inf_kernel(dtype: str):
@T.prim_func
def main(A: T.Tensor((32,), dtype)):
with T.Kernel(1, threads=32):
......@@ -18,7 +17,7 @@ def _test_infinity(dtype: str):
kernel = get_inf_kernel(dtype)
output = kernel()
assert torch.all(output == torch.inf), f'check failed for {dtype=}'
assert torch.all(output == torch.inf), f"check failed for {dtype=}"
@tilelang.testing.requires_cuda
......
......@@ -9,8 +9,8 @@ def test_language_ldg_codegen():
@T.prim_func
def main(
x: T.Tensor((N,), "float32"),
y: T.Tensor((N,), "float32"),
x: T.Tensor((N,), "float32"),
y: T.Tensor((N,), "float32"),
):
with T.Kernel(N, threads=32) as pid:
# Explicitly request read-only cache load for x[pid]
......
......@@ -8,7 +8,6 @@ import torch
def _gemm_impl():
@T.macro
def gemm_impl(
A: T.Tensor[[int, int], Any],
......@@ -37,7 +36,6 @@ def _gemm_impl():
def test_jit2_gemm_annot():
@tilelang.lazy_jit
def gemm(
A: T.Tensor[[int, int], Any],
......@@ -54,24 +52,24 @@ def test_jit2_gemm_annot():
return C
prod = product([T.float16, T.float32], [T.float32])
gemm.par_compile([{
'A': T.Tensor((1024, 1024), dtype=in_dtype),
'B': T.Tensor((1024, 1024), dtype=in_dtype),
'out_dtype': out_dtype
} for in_dtype, out_dtype in prod])
gemm.par_compile(
[
{"A": T.Tensor((1024, 1024), dtype=in_dtype), "B": T.Tensor((1024, 1024), dtype=in_dtype), "out_dtype": out_dtype}
for in_dtype, out_dtype in prod
]
)
for in_dtype, out_dtype in prod:
in_dtype = in_dtype.torch()
out_dtype = out_dtype.torch()
A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
C_ref = out_dtype(A @ B)
C = gemm(A, B)
torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)
def test_jit2_gemm_ptr():
@tilelang.lazy_jit
def gemm_ptr(
A: T.ptr,
......@@ -92,23 +90,19 @@ def test_jit2_gemm_ptr():
_gemm_impl()(A, B, C, out_dtype, block_M, block_N, block_K)
prod = product([T.float16, T.float32], [T.float32])
gemm_ptr.par_compile([{
'A': T.ptr(),
'B': T.ptr(),
'C': T.ptr(),
'M': 1024,
'N': 1024,
'K': 1024,
'dtype': in_dtype,
'out_dtype': out_dtype
} for in_dtype, out_dtype in prod])
gemm_ptr.par_compile(
[
{"A": T.ptr(), "B": T.ptr(), "C": T.ptr(), "M": 1024, "N": 1024, "K": 1024, "dtype": in_dtype, "out_dtype": out_dtype}
for in_dtype, out_dtype in prod
]
)
for in_dtype, out_dtype in prod:
in_dtype = in_dtype.torch()
out_dtype = out_dtype.torch()
A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
C_ref = out_dtype(A @ B)
C = torch.empty(1024, 1024, dtype=out_dtype, device='cuda')
C = torch.empty(1024, 1024, dtype=out_dtype, device="cuda")
gemm_ptr(A, B, C, 1024, 1024, 1024, in_dtype, out_dtype)
torch.testing.assert_close(C, C_ref, atol=1e-2, rtol=1e-2)
......@@ -129,8 +123,7 @@ def test_jit2_annot():
AnnotTest(
annot=T.Tensor[[int, int], T.float32],
promote=False,
match_ok=[torch.randn(1, 1, dtype=torch.float32),
T.Tensor((1, 1), dtype=T.float32)],
match_ok=[torch.randn(1, 1, dtype=torch.float32), T.Tensor((1, 1), dtype=T.float32)],
match_ng=[
torch.randn(1, 1, dtype=torch.float16),
T.Tensor(1, dtype=T.float32),
......@@ -146,8 +139,8 @@ def test_jit2_annot():
T.Tensor((1,), dtype=T.float32),
T.Tensor((1,), dtype=T.float16),
],
match_ng=[torch.randn((1, 1), dtype=torch.float32),
T.Tensor((1, 1), dtype=T.float16)]),
match_ng=[torch.randn((1, 1), dtype=torch.float32), T.Tensor((1, 1), dtype=T.float16)],
),
AnnotTest(
annot=T.Tensor[[int, 1], Any],
promote=False,
......@@ -157,8 +150,8 @@ def test_jit2_annot():
T.Tensor((12, 1), T.float32),
T.Tensor((12, 1), T.float16),
],
match_ng=[torch.randn(12, 12, dtype=torch.float32),
T.Tensor((12, 12), T.float32)]),
match_ng=[torch.randn(12, 12, dtype=torch.float32), T.Tensor((12, 12), T.float32)],
),
AnnotTest(
annot=T.Tensor[[T.dyn, 1], Any],
promote=False,
......@@ -168,43 +161,39 @@ def test_jit2_annot():
T.Tensor((12, 1), T.float32),
T.Tensor((12, 1), T.float16),
],
match_ng=[torch.randn(12, 12, dtype=torch.float32),
T.Tensor((12, 12), T.float32)]),
match_ng=[torch.randn(12, 12, dtype=torch.float32), T.Tensor((12, 12), T.float32)],
),
AnnotTest(
annot=T.Tensor[[1024, 1024], T.float32],
promote=True,
),
AnnotTest(annot=T.dyn[int, 'X'], promote=False, match_ok=[1, 2, 3, 4]),
AnnotTest(annot=T.dyn, promote=False, match_ok=[1, 2, 3, 4])
AnnotTest(annot=T.dyn[int, "X"], promote=False, match_ok=[1, 2, 3, 4]),
AnnotTest(annot=T.dyn, promote=False, match_ok=[1, 2, 3, 4]),
]
for test in tests:
promote = test.annot.promote()
promoted = promote is not None
if promoted != test.promote:
raise AssertionError(
f'Promote mismatch for {test.annot}: expected {test.promote}, got {promoted}')
with Builder().prim_func('_test'):
raise AssertionError(f"Promote mismatch for {test.annot}: expected {test.promote}, got {promoted}")
with Builder().prim_func("_test"):
for match_ok in test.match_ok:
try:
vt = ArgVarTable()
test.annot.create_prim_func_arg('arg', match_ok, vt)
test.annot.create_prim_func_arg("arg", match_ok, vt)
except Exception as e:
traceback.print_exc()
raise AssertionError(
f'Match failed for {test.annot} with value {match_ok}: {e}') from e
raise AssertionError(f"Match failed for {test.annot} with value {match_ok}: {e}") from e
for match_ng in test.match_ng:
try:
vt = ArgVarTable()
test.annot.create_prim_func_arg('arg', match_ng, vt)
raise AssertionError(
f'Match unexpectedly succeeded for {test.annot} with value {match_ng}')
test.annot.create_prim_func_arg("arg", match_ng, vt)
raise AssertionError(f"Match unexpectedly succeeded for {test.annot} with value {match_ng}")
except Exception:
pass
def test_jit2_many_annot():
@T.macro
def copy_impl(A, B):
M, N = A.shape
......@@ -213,8 +202,7 @@ def test_jit2_many_annot():
assert N == N_, f"N mismatch {N} {N_}"
# assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128], B[bx * 128:bx * 128 + 128,
by * 128:by * 128 + 128])
T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128])
@tilelang.lazy_jit
def copy1(
......@@ -259,20 +247,19 @@ def test_jit2_many_annot():
copy_impl(A, B)
for copy in [copy1, copy2, copy3, copy4]:
A = torch.randn(128, 128, device='cuda')
B = torch.empty(128, 128, device='cuda')
A = torch.randn(128, 128, device="cuda")
B = torch.empty(128, 128, device="cuda")
copy(A, B)
assert torch.equal(B, A)
for copy in [copy5, copy6]:
A = torch.randn(128, 2, 128, 2, device='cuda')
B = torch.randn(128, 2, 128, 2, device='cuda')
A = torch.randn(128, 2, 128, 2, device="cuda")
B = torch.randn(128, 2, 128, 2, device="cuda")
copy(A[:, 0, :, 0], B[:, 0, :, 0])
assert torch.equal(A[:, 0, :, 0], B[:, 0, :, 0])
def test_jit2_return():
@T.macro
def copy_impl(A):
M, N = A.shape
......@@ -283,8 +270,7 @@ def test_jit2_return():
assert N == N_, f"N mismatch {N} {N_}"
# assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128], B[bx * 128:bx * 128 + 128,
by * 128:by * 128 + 128])
T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128])
return B
@tilelang.lazy_jit
......@@ -292,41 +278,52 @@ def test_jit2_return():
return copy_impl(A)
@tilelang.lazy_jit
def copy1(A: T.Tensor[[int, int], T.float32],):
def copy1(
A: T.Tensor[[int, int], T.float32],
):
return copy_impl(A)
@tilelang.lazy_jit
def copy2(A: T.Tensor[[128, 128], T.float32],):
def copy2(
A: T.Tensor[[128, 128], T.float32],
):
return copy_impl(A)
@tilelang.lazy_jit
def copy3(A: T.Tensor[[int, 128], T.float32],):
def copy3(
A: T.Tensor[[int, 128], T.float32],
):
return copy_impl(A)
@tilelang.lazy_jit
def copy4(A: T.Tensor[[T.dyn, int], T.float32],):
def copy4(
A: T.Tensor[[T.dyn, int], T.float32],
):
return copy_impl(A)
@tilelang.lazy_jit
def copy5(A: T.StridedTensor[[int, int], [int, int], T.float32],):
def copy5(
A: T.StridedTensor[[int, int], [int, int], T.float32],
):
return copy_impl(A)
@tilelang.lazy_jit
def copy6(A: T.StridedTensor[[T.dyn, int], [int, int], T.float32],):
def copy6(
A: T.StridedTensor[[T.dyn, int], [int, int], T.float32],
):
return copy_impl(A)
for copy in [copy0, copy1, copy2, copy3, copy4]:
A = torch.randn(128, 128, device='cuda')
A = torch.randn(128, 128, device="cuda")
B = copy(A)
assert torch.equal(B, A)
for copy in [copy5, copy6]:
A = torch.randn(128, 2, 128, 2, device='cuda')
A = torch.randn(128, 2, 128, 2, device="cuda")
B = copy(A[:, 0, :, 0])
assert torch.equal(A[:, 0, :, 0], B)
def test_jit2_deepseek_deepgemm():
@tilelang.lazy_jit
def deep_gemm(
A: T.Tensor[[int, int], T.float8_e4m3],
......@@ -351,13 +348,9 @@ def test_jit2_deepseek_deepgemm():
N, K = B.shape
C = T.empty(M, N, dtype=out_dtype)
assert out_dtype in [
T.bfloat16, T.float32
], f"Expect out_dtype to be one of [T.float16, T.float32], got {out_dtype}"
assert scales_a.shape == [M, T.ceildiv(K, group_size)
], f"Expect scales_a shape to be f{[M, T.ceildiv(K, group_size)]}"
assert scales_b.shape == [N, T.ceildiv(K, group_size)
], f"Expect scales_b shape to be f{[N, T.ceildiv(K, group_size)]}"
assert out_dtype in [T.bfloat16, T.float32], f"Expect out_dtype to be one of [T.float16, T.float32], got {out_dtype}"
assert scales_a.shape == [M, T.ceildiv(K, group_size)], f"Expect scales_a shape to be f{[M, T.ceildiv(K, group_size)]}"
assert scales_b.shape == [N, T.ceildiv(K, group_size)], f"Expect scales_b shape to be f{[N, T.ceildiv(K, group_size)]}"
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), in_dtype)
......@@ -421,5 +414,5 @@ def test_jit2_deepseek_deepgemm():
# M, N, K = 1024, 1024, 8192
# A = torch.randn((M, K), dtype=torch.float8_e4m3fn, )
if __name__ == '__main__':
if __name__ == "__main__":
tilelang.testing.main()
......@@ -4,7 +4,6 @@ from tilelang import language as T
def test_let_vectorize_load():
@T.prim_func
def main(A_ptr: T.handle):
A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
......
......@@ -6,11 +6,10 @@ import torch
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
......@@ -30,13 +29,8 @@ def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
program = tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype)
kernel = tilelang.compile(
program,
out_idx=[1],
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
})
program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
)
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
......@@ -49,11 +43,10 @@ def test_tilelang_copy_mask_parallel():
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
......@@ -72,13 +65,8 @@ def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
program = tilelang_copy_mask_copy(M, N, block_M, block_N, dtype)
kernel = tilelang.compile(
program,
out_idx=[1],
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
})
program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
)
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
......@@ -91,11 +79,10 @@ def test_tilelang_copy_mask_copy():
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
......@@ -112,20 +99,11 @@ def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
return main
def run_tilelang_copy_mask_parallel_range(M=1024,
N=1024,
block_M=128,
block_N=128,
dtype="float16"):
def run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
program = tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype)
kernel = tilelang.compile(
program,
out_idx=[1],
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
})
program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
)
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
......@@ -138,11 +116,10 @@ def test_tilelang_copy_mask_parallel_range():
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
......@@ -161,13 +138,8 @@ def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
program = tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype)
kernel = tilelang.compile(
program,
out_idx=[1],
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
})
program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
)
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
......