"docs/en_US/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "cc58a81dc83c5ca469a713866406944b924507c0"
Unverified commit 29051439, authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
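The diff below is a mechanical re-format of existing test files: Yapf's single quotes and hanging-indent wrapping are replaced by ruff format's double quotes and trailing-comma-aware wrapping, with no intended behavioral change. As a minimal sketch of how one of these rewrites can be reproduced locally (assuming ruff is installed and that its formatter accepts "-" to read from stdin; the sample snippet string is illustrative, not taken from the diff):

import subprocess

# One line in the pre-commit (Yapf-era) style: single-quoted string literal.
old_style = "n = T.symbolic('n')\n"

# Pipe it through the formatter; "-" asks ruff to format stdin and print to stdout.
result = subprocess.run(
    ["ruff", "format", "-"],
    input=old_style,
    capture_output=True,
    text=True,
    check=True,
)

# ruff rewrites the literal to double quotes, matching the hunks below.
print(result.stdout)  # expected: n = T.symbolic("n")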
@@ -5,13 +5,14 @@ import torch
 def test_tensor_annot_mul():
     @tilelang.jit
     def example_tensor_annot():
-        n = T.symbolic('n')
+        n = T.symbolic("n")
         @T.prim_func
-        def kernel(A: T.Tensor((n * 4,), T.int32),):
+        def kernel(
+            A: T.Tensor((n * 4,), T.int32),
+        ):
             with T.Kernel(1) as _:
                 for i in range(n * 4):
                     A[i] = 0
@@ -19,20 +20,21 @@ def test_tensor_annot_mul():
         return kernel
     ker = example_tensor_annot()
-    A = torch.arange(16, dtype=torch.int32, device='cuda')
+    A = torch.arange(16, dtype=torch.int32, device="cuda")
     ker(A)
-    expected = torch.zeros(16, dtype=torch.int32, device='cuda')
+    expected = torch.zeros(16, dtype=torch.int32, device="cuda")
     assert torch.equal(A, expected)
 def test_tensor_annot_add():
     @tilelang.jit
     def example_tensor_annot():
-        n = T.symbolic('n')
+        n = T.symbolic("n")
         @T.prim_func
-        def kernel(A: T.Tensor((n + 1,), T.int32),):
+        def kernel(
+            A: T.Tensor((n + 1,), T.int32),
+        ):
             with T.Kernel(1) as _:
                 for i in range(n + 1):
                     A[i] = 0
@@ -40,20 +42,21 @@ def test_tensor_annot_add():
         return kernel
     ker = example_tensor_annot()
-    A = torch.arange(16, dtype=torch.int32, device='cuda')
+    A = torch.arange(16, dtype=torch.int32, device="cuda")
     ker(A)
-    expected = torch.zeros(16, dtype=torch.int32, device='cuda')
+    expected = torch.zeros(16, dtype=torch.int32, device="cuda")
     assert torch.equal(A, expected)
 def test_tensor_annot_mul_add():
     @tilelang.jit
     def example_tensor_annot():
-        n = T.symbolic('n')
+        n = T.symbolic("n")
         @T.prim_func
-        def kernel(A: T.Tensor((n * 3 + 1,), T.int32),):
+        def kernel(
+            A: T.Tensor((n * 3 + 1,), T.int32),
+        ):
             with T.Kernel(1) as _:
                 for i in range(n * 3 + 1):
                     A[i] = 0
@@ -61,11 +64,11 @@ def test_tensor_annot_mul_add():
         return kernel
     ker = example_tensor_annot()
-    A = torch.arange(16, dtype=torch.int32, device='cuda')
+    A = torch.arange(16, dtype=torch.int32, device="cuda")
     ker(A)
-    expected = torch.zeros(16, dtype=torch.int32, device='cuda')
+    expected = torch.zeros(16, dtype=torch.int32, device="cuda")
     assert torch.equal(A, expected)
-if __name__ == '__main__':
+if __name__ == "__main__":
     tilelang.testing.main()
@@ -7,11 +7,10 @@ import torch
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -30,13 +29,8 @@ def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
 def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16", pad_value=0):
     program = tilelang_copy(M, N, block_M, block_N, dtype, pad_value=pad_value)
     kernel = tilelang.compile(
-        program,
-        out_idx=[1],
-        target="cuda",
-        pass_configs={
-            "tl.disable_warp_specialized": True,
-            "tl.disable_tma_lower": True
-        })
+        program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
+    )
     a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
     b = kernel(a)
     ref_b = torch.zeros_like(a)
...
@@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
             accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
             for k in range(K // block_K):
                 if torch.any(BlockMask[i, j, k]):
-                    accu += A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
-                        torch.float32) @ B[k * block_K:(k + 1) * block_K,
-                                           j * block_N:(j + 1) * block_N].to(torch.float32)
-            ref_c[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = (
-                accu.to(torch.float16))
+                    accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[
+                        k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N
+                    ].to(torch.float32)
+            ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
     return ref_c
@@ -35,15 +34,14 @@ def blocksparse_matmul_global(
     dtype="float16",
     accum_dtype="float",
 ):
     block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
         B: T.Tensor((K, N), dtype),
         BlockMask: T.Tensor(block_mask_shape, "bool"),
         C: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_K), dtype)
@@ -80,15 +78,14 @@ def blocksparse_matmul_shared(
     dtype="float16",
     accum_dtype="float",
 ):
     block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
         B: T.Tensor((K, N), dtype),
         BlockMask: T.Tensor(block_mask_shape, "bool"),
         C: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_K), dtype)
@@ -130,15 +127,14 @@ def blocksparse_matmul_local(
     dtype="float16",
     accum_dtype="float",
 ):
     block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
         B: T.Tensor((K, N), dtype),
         BlockMask: T.Tensor(block_mask_shape, "bool"),
         C: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_K), dtype)
@@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     # Create block mask with desired sparsity
     mask_shape = (M // block_M, N // block_N, K // block_K)
     block_mask = torch.rand(mask_shape).cuda() > sparsity
@@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     # Create block mask with desired sparsity
     mask_shape = (M // block_M, N // block_N, K // block_K)
     block_mask = torch.rand(mask_shape).cuda() > sparsity
...
@@ -4,10 +4,9 @@ import tilelang.testing
 def test_assume_remove_boundary_check():
     @tilelang.jit
     def kernel_with_assume():
-        N = T.dynamic('N')
+        N = T.dynamic("N")
         @T.prim_func
         def main(A: T.Tensor((N,), "float32"), l: T.int32, r: T.int32):
@@ -21,20 +20,19 @@ def test_assume_remove_boundary_check():
     jit_kernel = kernel_with_assume()
     source = jit_kernel.get_kernel_source()
-    assert ("if (" not in source)
+    assert "if (" not in source
 def test_assume_enable_vectorization():
     @tilelang.jit
     def kernel_vectorize(M):
-        N = T.dynamic('N')
+        N = T.dynamic("N")
         vectorize_size = 4
         @T.prim_func
         def main(
             A: T.Tensor((M, N), "float32"),
             B: T.Tensor((M, N), "float32"),
         ):
             with T.Kernel(1, threads=32) as _:
                 tid = T.get_thread_binding()
@@ -55,16 +53,15 @@ def test_assume_enable_vectorization():
 def test_assume_complex_indexing():
     @tilelang.jit
     def kernel_complex():
-        M = T.dynamic('M')
-        N = T.dynamic('N')
+        M = T.dynamic("M")
+        N = T.dynamic("N")
         @T.prim_func
         def main(
             A: T.Tensor((M, N), "float32"),
             B: T.Tensor((M, N), "float32"),
         ):
             with T.Kernel(1, threads=32) as _:
                 tid = T.get_thread_binding()
@@ -82,8 +79,8 @@ def test_assume_complex_indexing():
     jit_kernel = kernel_complex()
     source = jit_kernel.get_kernel_source()
-    assert ("if (" not in source)
+    assert "if (" not in source
-if __name__ == '__main__':
+if __name__ == "__main__":
     tilelang.testing.main()
@@ -4,14 +4,12 @@ import tilelang.language as T
 @tilelang.jit
 def atomic_add_program(K, M, N, block_M, block_N, dtype="float"):
     @T.prim_func
     def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
             A_shared = T.alloc_shared((block_M, block_N), dtype)
-            T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
-                   A_shared)
+            T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
             for i, j in T.Parallel(block_M, block_N):
                 T.atomic_add(B[bx * block_M + i, by * block_N + j], A_shared[i, j])
@@ -39,14 +37,12 @@ def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
 def tile_atomic_add_program(K, M, N, block_M, block_N, dtype="float"):
     @T.prim_func
     def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
             A_shared = T.alloc_shared((block_M, block_N), dtype)
-            T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
-                   A_shared)
+            T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
             T.atomic_add(B[bx * block_M, by * block_N], A_shared)
@@ -76,14 +72,12 @@ def run_tile_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
 def atomic_max_program(K, M, N, block_M, block_N, dtype="float"):
     @T.prim_func
     def atomic_max(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
             A_shared = T.alloc_shared((block_M, block_N), dtype)
-            T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
-                   A_shared)
+            T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
             for i, j in T.Parallel(block_M, block_N):
                 T.atomic_max(B[bx * block_M + i, by * block_N + j], A_shared[i, j])
@@ -111,14 +105,12 @@ def run_atomic_max(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
 def atomic_min_program(K, M, N, block_M, block_N, dtype="float"):
     @T.prim_func
     def atomic_min(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
             A_shared = T.alloc_shared((block_M, block_N), dtype)
-            T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
-                   A_shared)
+            T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
             for i, j in T.Parallel(block_M, block_N):
                 T.atomic_min(B[bx * block_M + i, by * block_N + j], A_shared[i, j])
@@ -137,7 +129,7 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
                 B[i, j] = min(B[i, j], A[k, i, j])
     A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda()
-    B = torch.full((M, N), float('inf'), dtype=getattr(torch, dtype)).cuda()
+    B = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda()
     ref_B = B.clone()
     ref_program(A, ref_B)
     kernel(A, B)
@@ -146,7 +138,6 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
 def atomic_load_store_program(M, N, block_M, block_N, dtype="float"):
     @T.prim_func
     def atomic_load_store(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
@@ -172,18 +163,15 @@ def run_atomic_load_store(M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
 def atomic_memory_order_program(K, M, N, block_M, block_N, dtype="float"):
     @T.prim_func
     def atomic_with_memory_order(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
             A_shared = T.alloc_shared((block_M, block_N), dtype)
-            T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
-                   A_shared)
+            T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
             for i, j in T.Parallel(block_M, block_N):
-                T.atomic_add(
-                    B[bx * block_M + i, by * block_N + j], A_shared[i, j], memory_order="relaxed")
+                T.atomic_add(B[bx * block_M + i, by * block_N + j], A_shared[i, j], memory_order="relaxed")
     return atomic_with_memory_order
@@ -208,7 +196,6 @@ def run_atomic_memory_order(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
 def atomic_addx2_program(M, N, block_M, block_N):
     @T.prim_func
     def atomic_addx2(A: T.Tensor((M, N), "float16"), B: T.Tensor((M, N), "float16")):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
@@ -262,10 +249,10 @@ def test_atomic_addx2():
 @tilelang.jit
 def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"):
     @T.prim_func
-    def atomic_different_orders(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor(
-            (M, N), dtype), D: T.Tensor((M, N), dtype)):
+    def atomic_different_orders(
+        A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype), D: T.Tensor((M, N), dtype)
+    ):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
             for i, j in T.Parallel(block_M, block_N):
                 idx_i = bx * block_M + i
@@ -286,18 +273,17 @@ def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"):
     A = torch.randn(M, N, dtype=getattr(torch, dtype)).cuda()
     B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
     C = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
-    D = torch.full((M, N), float('inf'), dtype=getattr(torch, dtype)).cuda()
+    D = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda()
     kernel(A, B, C, D)
     torch.testing.assert_close(B, A, atol=1e-3, rtol=1e-3)
     torch.testing.assert_close(C, torch.maximum(torch.zeros_like(A), A))
-    torch.testing.assert_close(D, torch.minimum(torch.full_like(A, float('inf')), A))
+    torch.testing.assert_close(D, torch.minimum(torch.full_like(A, float("inf")), A))
 @tilelang.jit
 def atomic_addx4_program(M, N, block_M, block_N):
     @T.prim_func
     def atomic_addx4(A: T.Tensor((M, N), "float32"), B: T.Tensor((M, N), "float32")):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
@@ -330,17 +316,14 @@ def run_atomic_addx4(M, N, block_M, block_N):
 @tilelang.jit
 def atomic_return_prev_program(M, N, block_M, block_N, dtype="float"):
     @T.prim_func
-    def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
-                                old_vals: T.Tensor((M, N), dtype)):
+    def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), old_vals: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
             for i, j in T.Parallel(block_M, block_N):
                 idx_i = bx * block_M + i
                 idx_j = by * block_N + j
                 if idx_i < M and idx_j < N:
-                    old_vals[idx_i, idx_j] = T.atomic_add(
-                        B[idx_i, idx_j], A[idx_i, idx_j], return_prev=True)
+                    old_vals[idx_i, idx_j] = T.atomic_add(B[idx_i, idx_j], A[idx_i, idx_j], return_prev=True)
     return atomic_with_return_prev
...
@@ -5,7 +5,6 @@ import torch
 @tilelang.jit(out_idx=[-1])
 def _ceildiv_kernel(a: int, b: int):
     @T.prim_func
     def ceildiv_kernel(A: T.Tensor((1,), "int32")):
         with T.Kernel(1, threads=1) as _:
@@ -30,7 +29,6 @@ def test_ceildiv():
 @tilelang.jit
 def _ceildiv_kernel_dyn(b: int):
     @T.prim_func
     def ceildiv_kernel(A: T.Tensor((1,), "int32"), a: T.int32):
         with T.Kernel(1, threads=1) as _:
...
@@ -8,14 +8,14 @@ import torch
     pass_configs={
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-    },)
+    },
+)
 def chain_equal(N, block_size, dtype="float32"):
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
         B: T.Tensor((N,), dtype),
         C: T.Tensor((N,), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as bx:
             for lane in T.Parallel(block_size):
...
@@ -13,8 +13,8 @@ def clamp_within_bounds(
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
         B: T.Tensor((N,), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
             A_shared = T.alloc_shared([block_N], dtype)
@@ -56,8 +56,8 @@ def clamp_value_range(
     @T.prim_func
     def main(
         A: T.Tensor((1, N), dtype),
         B: T.Tensor((1, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
             # A_shared = T.alloc_shared([1, block_N], dtype=dtype)
...
@@ -5,12 +5,11 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
         B: T.Tensor((N, K), dtype),
         C: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -42,10 +41,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
 def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     program = matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
-    kernel = tilelang.compile(
-        program, out_idx=[2], target="cuda", pass_configs={"tl.disable_tma_lower": True})
+    kernel = tilelang.compile(program, out_idx=[2], target="cuda", pass_configs={"tl.disable_tma_lower": True})
     import torch
     from tilelang.utils import map_torch_type
     a = torch.randn((M, K), dtype=map_torch_type(dtype)).cuda()
     b = torch.randn((N, K), dtype=map_torch_type(dtype)).cuda()
     c = kernel(a, b)
...
@@ -7,11 +7,10 @@ import torch
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 def tilelang_composable_copy(M, N, block_M, block_N, dtype="float16"):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M * N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -35,7 +34,8 @@ def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
     b = kernel(a)
     torch.testing.assert_close(b.flatten(), a.flatten(), rtol=1e-2, atol=1e-2)
...
@@ -7,11 +7,10 @@ import tilelang.testing
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 def tilelang_copy(M, N, block_M, block_N, dtype="float16"):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -27,10 +26,8 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16")
         program,
         out_idx=[1],
         target="cuda",
-        pass_configs={
-            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
-        })
+        pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True},
+    )
     a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
     b = kernel(a)
     torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
@@ -43,11 +40,10 @@ def test_tilelang_copy():
 def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"):
     @T.prim_func
     def main(
         A: T.StridedTensor((M, N), (NN, 1), dtype),
         B: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -57,12 +53,7 @@ def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"):
     return main
-def run_tilelang_copy_with_stride(M=1024,
-                                  N=1024,
-                                  NN=2048,
-                                  block_M=128,
-                                  block_N=128,
-                                  dtype="float16"):
+def run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128, dtype="float16"):
     if isinstance(NN, int):
         assert NN > N, "NN must be greater than N"
     program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype)
@@ -73,7 +64,8 @@ def run_tilelang_copy_with_stride(M=1024,
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-        })
+        },
+    )
     if isinstance(NN, T.Var):
         NN = N * 2
     a = torch.randn(M, NN, device="cuda", dtype=getattr(torch, dtype))
@@ -87,11 +79,10 @@ def test_tilelang_copy_with_stride():
 def tilelang_copy_bufferload(num_tokens, dtype="float16"):
     @T.prim_func
     def main(
         indices: T.Tensor((num_tokens,), "int32"),
         x: T.Tensor((num_tokens,), dtype),
     ):
         with T.Kernel(num_tokens, threads=32) as pid:
             idx = T.alloc_local([1], "int32")
@@ -107,10 +98,8 @@ def run_tilelang_copy_bufferload(num_tokens=128, dtype="float16"):
     tilelang.compile(
         program,
         out_idx=[1],
-        pass_configs={
-            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
-        })
+        pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True},
+    )
 def test_tilelang_copy_bufferload():
@@ -118,11 +107,10 @@ def test_tilelang_copy_bufferload():
 def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float16"):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -132,20 +120,14 @@ def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float
     return main
-def run_tilelang_copy_buffer_load_with_parallel(M=1024,
-                                                N=1024,
-                                                block_M=128,
-                                                block_N=128,
-                                                dtype="float16"):
+def run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
     program = tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype)
     kernel = tilelang.compile(
         program,
         out_idx=[1],
         target="cuda",
-        pass_configs={
-            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
-        })
+        pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True},
+    )
     a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
     b = kernel(a)
     torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
...
@@ -9,8 +9,8 @@ def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float3
     @T.prim_func
     def cumsum(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
@@ -28,8 +28,8 @@ def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="fl
     @T.prim_func
     def cumsum(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
@@ -57,13 +57,16 @@ def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32", sc
         ref_b = torch.empty_like(A)
         for i in range(M // block_M):
             for j in range(N // block_N):
-                ref_b[i * block_M:(i + 1) * block_M,
-                      j * block_N:(j + 1) * block_N] = A[i * block_M:(i + 1) * block_M, j *
-                                                         block_N:(j + 1) * block_N].cumsum(dim=dim)
+                ref_b[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = A[
+                    i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N
+                ].cumsum(dim=dim)
                 if reverse:
-                    ref_b[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) *
-                          block_N] = A[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) *
-                                       block_N].flip(dims=[dim]).cumsum(dim=dim).flip(dims=[dim])
+                    ref_b[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = (
+                        A[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N]
+                        .flip(dims=[dim])
+                        .cumsum(dim=dim)
+                        .flip(dims=[dim])
+                    )
         return ref_b
     tilelang_res = jit_kernel(A)
@@ -76,8 +79,8 @@ def cumsum_smem_test_1d(N, block_N, reverse=False, dtype="float32"):
     @T.prim_func
     def cumsum(
         A: T.Tensor((N,), dtype),
         B: T.Tensor((N,), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
             A_shared = T.alloc_shared((block_N,), dtype)
@@ -94,8 +97,8 @@ def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype="float32"):
     @T.prim_func
     def cumsum(
         A: T.Tensor((N,), dtype),
         B: T.Tensor((N,), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
             A_shared = T.alloc_shared((block_N,), dtype)
...
@@ -8,7 +8,6 @@ from tvm.tir.expr import IntImm, Var
 def test_argument():
     @T.prim_func
     def test_argument(
         t_1: T.bool,
@@ -41,6 +40,7 @@ def test_argument():
 def test_expr():
     from tilelang.language.v2.dtypes import _all_dtypes
     errors = []
     for name in _all_dtypes:
         dtype = getattr(T, name)
@@ -116,33 +116,32 @@ def test_expr():
 def test_dtype_str_repr():
     @T.prim_func
     def test_str_repr():
-        buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope='shared')  # noqa F841
-        buf_2 = T.alloc_buffer((1,), dtype=T.short, scope='shared')  # noqa F841
-        buf_3 = T.alloc_buffer((1,), dtype=T.int, scope='shared')  # noqa F841
-        buf_4 = T.alloc_buffer((1,), dtype=T.long, scope='shared')  # noqa F841
-        buf_5 = T.alloc_buffer((1,), dtype=T.half, scope='shared')  # noqa F841
-        buf_6 = T.alloc_buffer((1,), dtype=T.float, scope='shared')  # noqa F841
-        buf_7 = T.alloc_buffer((1,), dtype=T.long, scope='shared')  # noqa F841
-        buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope='shared')  # noqa F841
-        buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope='shared')  # noqa F841
-        buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope='shared')  # noqa F841
-        buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope='shared')  # noqa F841
-        buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope='shared')  # noqa F841
-        buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope='shared')  # noqa F841
-        buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope='shared')  # noqa F841
-        buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope='shared')  # noqa F841
-        buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope='shared')  # noqa F841
-        buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope='shared')  # noqa F841
-        buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope='shared')  # noqa F841
-        buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope='shared')  # noqa F841
-        buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope='shared')  # noqa F841
-        buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope='shared')  # noqa F841
-        buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope='shared')  # noqa F841
-        buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope='shared')  # noqa F841
-        buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared')  # noqa F841
+        buf_1 = T.alloc_buffer((1,), dtype=T.bool, scope="shared")  # noqa F841
+        buf_2 = T.alloc_buffer((1,), dtype=T.short, scope="shared")  # noqa F841
+        buf_3 = T.alloc_buffer((1,), dtype=T.int, scope="shared")  # noqa F841
+        buf_4 = T.alloc_buffer((1,), dtype=T.long, scope="shared")  # noqa F841
+        buf_5 = T.alloc_buffer((1,), dtype=T.half, scope="shared")  # noqa F841
+        buf_6 = T.alloc_buffer((1,), dtype=T.float, scope="shared")  # noqa F841
+        buf_7 = T.alloc_buffer((1,), dtype=T.long, scope="shared")  # noqa F841
+        buf_8 = T.alloc_buffer((1,), dtype=T.int8, scope="shared")  # noqa F841
+        buf_9 = T.alloc_buffer((1,), dtype=T.int16, scope="shared")  # noqa F841
+        buf_10 = T.alloc_buffer((1,), dtype=T.int32, scope="shared")  # noqa F841
+        buf_11 = T.alloc_buffer((1,), dtype=T.int64, scope="shared")  # noqa F841
+        buf_12 = T.alloc_buffer((1,), dtype=T.uint8, scope="shared")  # noqa F841
+        buf_13 = T.alloc_buffer((1,), dtype=T.uint16, scope="shared")  # noqa F841
+        buf_14 = T.alloc_buffer((1,), dtype=T.uint32, scope="shared")  # noqa F841
+        buf_15 = T.alloc_buffer((1,), dtype=T.uint64, scope="shared")  # noqa F841
+        buf_16 = T.alloc_buffer((1,), dtype=T.float8_e4m3fn, scope="shared")  # noqa F841
+        buf_17 = T.alloc_buffer((1,), dtype=T.float8_e4m3fnuz, scope="shared")  # noqa F841
+        buf_18 = T.alloc_buffer((1,), dtype=T.float8_e5m2, scope="shared")  # noqa F841
+        buf_19 = T.alloc_buffer((1,), dtype=T.float8_e5m2fnuz, scope="shared")  # noqa F841
+        buf_20 = T.alloc_buffer((1,), dtype=T.float8_e8m0fnu, scope="shared")  # noqa F841
+        buf_21 = T.alloc_buffer((1,), dtype=T.float16, scope="shared")  # noqa F841
+        buf_22 = T.alloc_buffer((1,), dtype=T.bfloat16, scope="shared")  # noqa F841
+        buf_23 = T.alloc_buffer((1,), dtype=T.float32, scope="shared")  # noqa F841
+        buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope="shared")  # noqa F841
         # not supported now
@@ -205,7 +204,6 @@ def test_dtype_str_repr():
 def test_var_assign():
     @tilelang.jit(out_idx=-1)
     @T.prim_func
     def test_var_assign(A: T.Tensor((2,), T.int32)):
@@ -223,7 +221,6 @@ def test_var_assign():
 def test_marco_return():
     @T.macro
     def macro_return_constant():
         return 0
@@ -258,11 +255,10 @@ def test_marco_return():
 def test_prim_func_generator():
     @T.prim_func(generator=True)
     def prim_func_gen(
         A=T.Tensor((128,), T.float32),  # noqa: B008
         B=T.Tensor((128,), T.float32),  # noqa: B008
     ):
         with T.Kernel(128) as (tx,):
             T.copy(A[tx], B[tx])
@@ -277,7 +273,6 @@ def test_prim_func_generator():
 def test_serial_for_with_step():
     @tilelang.jit(out_idx=-1)
     @T.prim_func
     def test_stepped_serial(A: T.Tensor((10,), T.int32)):
@@ -291,7 +286,7 @@ def test_serial_for_with_step():
     ker = test_stepped_serial()
     res = ker()
-    ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device='cuda')
+    ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device="cuda")
     assert torch.all(res == ref), f"Expected {ref}, but got {res}"
     @tilelang.jit(out_idx=-1)
@@ -304,17 +299,16 @@ def test_serial_for_with_step():
     ker = test_serial_step_neg()
     res = ker()
-    ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device='cuda')
+    ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device="cuda")
     assert torch.all(res == ref), f"Expected {ref}, but got {res}"
     assert isinstance(T.serial(1, 10, 1), IRBuilderFrame)
-    assert isinstance(T.serial(1, 10, IntImm('int32', 1)), IRBuilderFrame)
-    assert not isinstance(T.serial(1, 10, Var('tmp', 'int32')), IRBuilderFrame)
+    assert isinstance(T.serial(1, 10, IntImm("int32", 1)), IRBuilderFrame)
+    assert not isinstance(T.serial(1, 10, Var("tmp", "int32")), IRBuilderFrame)
     assert not isinstance(T.serial(10, -1, -1), IRBuilderFrame)
 def test_swap_logic():
     @tilelang.jit
     @T.prim_func
     def swap_var(A: T.Tensor[(2,), T.float32]):
@@ -344,7 +338,6 @@ def test_swap_logic():
 def test_while_loop():
     @tilelang.jit(out_idx=-1)
     @T.prim_func
     def test_while_loop(A: T.Tensor((1,), T.int32)):
@@ -374,7 +367,7 @@ def test_var_macro():
                 x = T.alloc_var(T.int32)
                 macro_with_var(x)
-        assert 'x[0] = 1' in prim_call_macro.script()
+        assert "x[0] = 1" in prim_call_macro.script()
     finally:
         pass
@@ -406,7 +399,7 @@ def test_var_macro():
                 x = T.alloc_var(T.int32)
                 macro_with_var(x)
-        assert 'x[0] = 1' in prim_call_macro.script()
+        assert "x[0] = 1" in prim_call_macro.script()
     finally:
         pass
@@ -428,10 +421,8 @@ def test_var_macro():
 def test_frame_inside_macro():
     @tilelang.jit
     def get_sample_kernel():
         @T.macro
         def transform(x):
             return x + 1
@@ -442,7 +433,7 @@ def test_frame_inside_macro():
             idx_out: T.Tensor[(32,), T.int32],
         ):
             with T.Kernel(num_blocks, threads=32) as block_idx:  # noqa: F841
-                fragment = T.alloc_fragment(32, 'int32')
+                fragment = T.alloc_fragment(32, "int32")
                 T.copy(idx_out, fragment)
                 for i in T.Parallel(32):
@@ -467,10 +458,10 @@ def test_buffer_slice_step():
 def test_boolop():
-    a = Var('a', 'int32')
-    b = Var('b', 'int32')
-    c = Var('c', 'int32')
-    d = Var('d', 'int32')
+    a = Var("a", "int32")
+    b = Var("b", "int32")
+    c = Var("c", "int32")
+    d = Var("d", "int32")
     @T.macro
     def cond():
@@ -479,5 +470,5 @@ def test_boolop():
     cond()
-if __name__ == '__main__':
+if __name__ == "__main__":
     tilelang.testing.main()
@@ -23,7 +23,6 @@ def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int:
 @tilelang.jit(out_idx=[-1])
 def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
     @T.prim_func
     def laneid_kernel(A: T.Tensor((num_threads,), "int32")):
         with T.Kernel(1, threads=num_threads) as _:
@@ -35,7 +34,6 @@ def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
 @tilelang.jit(out_idx=[-1])
 def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
     @T.prim_func
     def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")):
         with T.Kernel(1, threads=num_threads) as _:
@@ -47,7 +45,6 @@ def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] =
 @tilelang.jit(out_idx=[-1])
 def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
     @T.prim_func
     def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")):
         with T.Kernel(1, threads=num_threads) as _:
@@ -63,7 +60,6 @@ def _get_warp_group_idx_kernel(
     warp_size: Optional[int] = None,
     warps_per_group: Optional[int] = None,
 ):
     @T.prim_func
     def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")):
         with T.Kernel(1, threads=num_threads) as _:
@@ -75,7 +71,6 @@ def _get_warp_group_idx_kernel(
 @tilelang.jit(out_idx=[-1])
 def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64):
     @T.prim_func
     def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")):
         with T.Kernel(1, threads=num_threads) as _:
...
@@ -4,13 +4,14 @@ import torch
 import tilelang.testing
-@tilelang.jit(out_idx=[1],)
+@tilelang.jit(
+    out_idx=[1],
+)
 def tilelang_if_range(M, N, block_M, block_N, dtype="float16"):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
...
@@ -5,7 +5,6 @@ import tilelang.language as T
 @tilelang.jit(out_idx=-1)
 def get_inf_kernel(dtype: str):
     @T.prim_func
     def main(A: T.Tensor((32,), dtype)):
         with T.Kernel(1, threads=32):
@@ -18,7 +17,7 @@ def _test_infinity(dtype: str):
     kernel = get_inf_kernel(dtype)
     output = kernel()
-    assert torch.all(output == torch.inf), f'check failed for {dtype=}'
+    assert torch.all(output == torch.inf), f"check failed for {dtype=}"
 @tilelang.testing.requires_cuda
...
@@ -9,8 +9,8 @@ def test_language_ldg_codegen():
     @T.prim_func
     def main(
         x: T.Tensor((N,), "float32"),
         y: T.Tensor((N,), "float32"),
     ):
         with T.Kernel(N, threads=32) as pid:
             # Explicitly request read-only cache load for x[pid]
...
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
def _gemm_impl(): def _gemm_impl():
@T.macro @T.macro
def gemm_impl( def gemm_impl(
A: T.Tensor[[int, int], Any], A: T.Tensor[[int, int], Any],
...@@ -37,7 +36,6 @@ def _gemm_impl(): ...@@ -37,7 +36,6 @@ def _gemm_impl():
def test_jit2_gemm_annot(): def test_jit2_gemm_annot():
@tilelang.lazy_jit @tilelang.lazy_jit
def gemm( def gemm(
A: T.Tensor[[int, int], Any], A: T.Tensor[[int, int], Any],
...@@ -54,24 +52,24 @@ def test_jit2_gemm_annot(): ...@@ -54,24 +52,24 @@ def test_jit2_gemm_annot():
return C return C
prod = product([T.float16, T.float32], [T.float32]) prod = product([T.float16, T.float32], [T.float32])
gemm.par_compile([{ gemm.par_compile(
'A': T.Tensor((1024, 1024), dtype=in_dtype), [
'B': T.Tensor((1024, 1024), dtype=in_dtype), {"A": T.Tensor((1024, 1024), dtype=in_dtype), "B": T.Tensor((1024, 1024), dtype=in_dtype), "out_dtype": out_dtype}
'out_dtype': out_dtype for in_dtype, out_dtype in prod
} for in_dtype, out_dtype in prod]) ]
)
for in_dtype, out_dtype in prod: for in_dtype, out_dtype in prod:
in_dtype = in_dtype.torch() in_dtype = in_dtype.torch()
out_dtype = out_dtype.torch() out_dtype = out_dtype.torch()
A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda') A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda') B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
C_ref = out_dtype(A @ B) C_ref = out_dtype(A @ B)
C = gemm(A, B) C = gemm(A, B)
torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)
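One caveat with the `prod` variable above: if `product` here is `itertools.product` (the import is elided in this hunk), the iterator is drained by the list comprehension passed to `par_compile`, so the verification loop that follows would see nothing. Materializing the combinations once sidesteps that; a minimal sketch of the idea with placeholder dtype names:

```python
from itertools import product

# Hypothetical stand-ins for the tilelang dtypes used above.
in_dtypes = ["float16", "float32"]
out_dtypes = ["float32"]

prod = list(product(in_dtypes, out_dtypes))  # materialize once, iterate many times

configs = [{"in_dtype": i, "out_dtype": o} for i, o in prod]
assert len(configs) == 2

# The second pass still sees all combinations because prod is a list, not a generator.
assert [(i, o) for i, o in prod] == [("float16", "float32"), ("float32", "float32")]
```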
def test_jit2_gemm_ptr(): def test_jit2_gemm_ptr():
@tilelang.lazy_jit @tilelang.lazy_jit
def gemm_ptr( def gemm_ptr(
A: T.ptr, A: T.ptr,
...@@ -92,23 +90,19 @@ def test_jit2_gemm_ptr(): ...@@ -92,23 +90,19 @@ def test_jit2_gemm_ptr():
_gemm_impl()(A, B, C, out_dtype, block_M, block_N, block_K) _gemm_impl()(A, B, C, out_dtype, block_M, block_N, block_K)
prod = product([T.float16, T.float32], [T.float32]) prod = product([T.float16, T.float32], [T.float32])
gemm_ptr.par_compile([{ gemm_ptr.par_compile(
'A': T.ptr(), [
'B': T.ptr(), {"A": T.ptr(), "B": T.ptr(), "C": T.ptr(), "M": 1024, "N": 1024, "K": 1024, "dtype": in_dtype, "out_dtype": out_dtype}
'C': T.ptr(), for in_dtype, out_dtype in prod
'M': 1024, ]
'N': 1024, )
'K': 1024,
'dtype': in_dtype,
'out_dtype': out_dtype
} for in_dtype, out_dtype in prod])
for in_dtype, out_dtype in prod: for in_dtype, out_dtype in prod:
in_dtype = in_dtype.torch() in_dtype = in_dtype.torch()
out_dtype = out_dtype.torch() out_dtype = out_dtype.torch()
A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda') A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda') B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
C_ref = out_dtype(A @ B) C_ref = out_dtype(A @ B)
C = torch.empty(1024, 1024, dtype=out_dtype, device='cuda') C = torch.empty(1024, 1024, dtype=out_dtype, device="cuda")
gemm_ptr(A, B, C, 1024, 1024, 1024, in_dtype, out_dtype) gemm_ptr(A, B, C, 1024, 1024, 1024, in_dtype, out_dtype)
torch.testing.assert_close(C, C_ref, atol=1e-2, rtol=1e-2) torch.testing.assert_close(C, C_ref, atol=1e-2, rtol=1e-2)
...@@ -129,8 +123,7 @@ def test_jit2_annot(): ...@@ -129,8 +123,7 @@ def test_jit2_annot():
AnnotTest( AnnotTest(
annot=T.Tensor[[int, int], T.float32], annot=T.Tensor[[int, int], T.float32],
promote=False, promote=False,
match_ok=[torch.randn(1, 1, dtype=torch.float32), match_ok=[torch.randn(1, 1, dtype=torch.float32), T.Tensor((1, 1), dtype=T.float32)],
T.Tensor((1, 1), dtype=T.float32)],
match_ng=[ match_ng=[
torch.randn(1, 1, dtype=torch.float16), torch.randn(1, 1, dtype=torch.float16),
T.Tensor(1, dtype=T.float32), T.Tensor(1, dtype=T.float32),
...@@ -146,8 +139,8 @@ def test_jit2_annot(): ...@@ -146,8 +139,8 @@ def test_jit2_annot():
T.Tensor((1,), dtype=T.float32), T.Tensor((1,), dtype=T.float32),
T.Tensor((1,), dtype=T.float16), T.Tensor((1,), dtype=T.float16),
], ],
match_ng=[torch.randn((1, 1), dtype=torch.float32), match_ng=[torch.randn((1, 1), dtype=torch.float32), T.Tensor((1, 1), dtype=T.float16)],
T.Tensor((1, 1), dtype=T.float16)]), ),
AnnotTest( AnnotTest(
annot=T.Tensor[[int, 1], Any], annot=T.Tensor[[int, 1], Any],
promote=False, promote=False,
...@@ -157,8 +150,8 @@ def test_jit2_annot(): ...@@ -157,8 +150,8 @@ def test_jit2_annot():
T.Tensor((12, 1), T.float32), T.Tensor((12, 1), T.float32),
T.Tensor((12, 1), T.float16), T.Tensor((12, 1), T.float16),
], ],
match_ng=[torch.randn(12, 12, dtype=torch.float32), match_ng=[torch.randn(12, 12, dtype=torch.float32), T.Tensor((12, 12), T.float32)],
T.Tensor((12, 12), T.float32)]), ),
AnnotTest( AnnotTest(
annot=T.Tensor[[T.dyn, 1], Any], annot=T.Tensor[[T.dyn, 1], Any],
promote=False, promote=False,
...@@ -168,43 +161,39 @@ def test_jit2_annot(): ...@@ -168,43 +161,39 @@ def test_jit2_annot():
T.Tensor((12, 1), T.float32), T.Tensor((12, 1), T.float32),
T.Tensor((12, 1), T.float16), T.Tensor((12, 1), T.float16),
], ],
match_ng=[torch.randn(12, 12, dtype=torch.float32), match_ng=[torch.randn(12, 12, dtype=torch.float32), T.Tensor((12, 12), T.float32)],
T.Tensor((12, 12), T.float32)]), ),
AnnotTest( AnnotTest(
annot=T.Tensor[[1024, 1024], T.float32], annot=T.Tensor[[1024, 1024], T.float32],
promote=True, promote=True,
), ),
AnnotTest(annot=T.dyn[int, 'X'], promote=False, match_ok=[1, 2, 3, 4]), AnnotTest(annot=T.dyn[int, "X"], promote=False, match_ok=[1, 2, 3, 4]),
AnnotTest(annot=T.dyn, promote=False, match_ok=[1, 2, 3, 4]) AnnotTest(annot=T.dyn, promote=False, match_ok=[1, 2, 3, 4]),
] ]
for test in tests: for test in tests:
promote = test.annot.promote() promote = test.annot.promote()
promoted = promote is not None promoted = promote is not None
if promoted != test.promote: if promoted != test.promote:
raise AssertionError( raise AssertionError(f"Promote mismatch for {test.annot}: expected {test.promote}, got {promoted}")
f'Promote mismatch for {test.annot}: expected {test.promote}, got {promoted}') with Builder().prim_func("_test"):
with Builder().prim_func('_test'):
for match_ok in test.match_ok: for match_ok in test.match_ok:
try: try:
vt = ArgVarTable() vt = ArgVarTable()
test.annot.create_prim_func_arg('arg', match_ok, vt) test.annot.create_prim_func_arg("arg", match_ok, vt)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
raise AssertionError( raise AssertionError(f"Match failed for {test.annot} with value {match_ok}: {e}") from e
f'Match failed for {test.annot} with value {match_ok}: {e}') from e
for match_ng in test.match_ng: for match_ng in test.match_ng:
try: try:
vt = ArgVarTable() vt = ArgVarTable()
test.annot.create_prim_func_arg('arg', match_ng, vt) test.annot.create_prim_func_arg("arg", match_ng, vt)
raise AssertionError( raise AssertionError(f"Match unexpectedly succeeded for {test.annot} with value {match_ng}")
f'Match unexpectedly succeeded for {test.annot} with value {match_ng}')
except Exception: except Exception:
pass pass
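To spell out what the `promote` flag encodes in these cases (going only by the assertion in the loop above): `annot.promote()` yields a non-None result only when the annotation is fully concrete, while any symbolic dimension defers matching to call time via `create_prim_func_arg`. A minimal sketch of the two ends of that spectrum:

```python
from tilelang import language as T

# Fully concrete shape and dtype: the table above marks this promote=True,
# i.e. promote() returns a usable (non-None) signature up front.
concrete = T.Tensor[[1024, 1024], T.float32]
assert concrete.promote() is not None

# Symbolic dimensions (plain `int`): promote=False, so matching is deferred
# until a real tensor is bound, as exercised by match_ok / match_ng above.
symbolic = T.Tensor[[int, int], T.float32]
assert symbolic.promote() is None
```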
def test_jit2_many_annot(): def test_jit2_many_annot():
@T.macro @T.macro
def copy_impl(A, B): def copy_impl(A, B):
M, N = A.shape M, N = A.shape
...@@ -213,8 +202,7 @@ def test_jit2_many_annot(): ...@@ -213,8 +202,7 @@ def test_jit2_many_annot():
assert N == N_, f"N mismatch {N} {N_}" assert N == N_, f"N mismatch {N} {N_}"
# assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}" # assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by): with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128], B[bx * 128:bx * 128 + 128, T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128])
by * 128:by * 128 + 128])
@tilelang.lazy_jit @tilelang.lazy_jit
def copy1( def copy1(
...@@ -259,20 +247,19 @@ def test_jit2_many_annot(): ...@@ -259,20 +247,19 @@ def test_jit2_many_annot():
copy_impl(A, B) copy_impl(A, B)
for copy in [copy1, copy2, copy3, copy4]: for copy in [copy1, copy2, copy3, copy4]:
A = torch.randn(128, 128, device='cuda') A = torch.randn(128, 128, device="cuda")
B = torch.empty(128, 128, device='cuda') B = torch.empty(128, 128, device="cuda")
copy(A, B) copy(A, B)
assert torch.equal(B, A) assert torch.equal(B, A)
for copy in [copy5, copy6]: for copy in [copy5, copy6]:
A = torch.randn(128, 2, 128, 2, device='cuda') A = torch.randn(128, 2, 128, 2, device="cuda")
B = torch.randn(128, 2, 128, 2, device='cuda') B = torch.randn(128, 2, 128, 2, device="cuda")
copy(A[:, 0, :, 0], B[:, 0, :, 0]) copy(A[:, 0, :, 0], B[:, 0, :, 0])
assert torch.equal(A[:, 0, :, 0], B[:, 0, :, 0]) assert torch.equal(A[:, 0, :, 0], B[:, 0, :, 0])
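The `copy5`/`copy6` cases feed a non-contiguous view, which is why they are annotated with `T.StridedTensor` (explicit strides) rather than plain `T.Tensor`. A small PyTorch-only sketch of the layout such a view actually carries:

```python
import torch

A = torch.randn(128, 2, 128, 2)
view = A[:, 0, :, 0]

# The view keeps the parent's layout: shape (128, 128) but strides (512, 2),
# so neighbouring elements along the last axis sit 2 floats apart, not 1.
assert view.shape == (128, 128)
assert view.stride() == (512, 2)
assert not view.is_contiguous()
```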
def test_jit2_return(): def test_jit2_return():
@T.macro @T.macro
def copy_impl(A): def copy_impl(A):
M, N = A.shape M, N = A.shape
...@@ -283,8 +270,7 @@ def test_jit2_return(): ...@@ -283,8 +270,7 @@ def test_jit2_return():
assert N == N_, f"N mismatch {N} {N_}" assert N == N_, f"N mismatch {N} {N_}"
# assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}" # assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by): with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128], B[bx * 128:bx * 128 + 128, T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128])
by * 128:by * 128 + 128])
return B return B
@tilelang.lazy_jit @tilelang.lazy_jit
...@@ -292,41 +278,52 @@ def test_jit2_return(): ...@@ -292,41 +278,52 @@ def test_jit2_return():
return copy_impl(A) return copy_impl(A)
@tilelang.lazy_jit @tilelang.lazy_jit
def copy1(A: T.Tensor[[int, int], T.float32],): def copy1(
A: T.Tensor[[int, int], T.float32],
):
return copy_impl(A) return copy_impl(A)
@tilelang.lazy_jit @tilelang.lazy_jit
def copy2(A: T.Tensor[[128, 128], T.float32],): def copy2(
A: T.Tensor[[128, 128], T.float32],
):
return copy_impl(A) return copy_impl(A)
@tilelang.lazy_jit @tilelang.lazy_jit
def copy3(A: T.Tensor[[int, 128], T.float32],): def copy3(
A: T.Tensor[[int, 128], T.float32],
):
return copy_impl(A) return copy_impl(A)
@tilelang.lazy_jit @tilelang.lazy_jit
def copy4(A: T.Tensor[[T.dyn, int], T.float32],): def copy4(
A: T.Tensor[[T.dyn, int], T.float32],
):
return copy_impl(A) return copy_impl(A)
@tilelang.lazy_jit @tilelang.lazy_jit
def copy5(A: T.StridedTensor[[int, int], [int, int], T.float32],): def copy5(
A: T.StridedTensor[[int, int], [int, int], T.float32],
):
return copy_impl(A) return copy_impl(A)
@tilelang.lazy_jit @tilelang.lazy_jit
def copy6(A: T.StridedTensor[[T.dyn, int], [int, int], T.float32],): def copy6(
A: T.StridedTensor[[T.dyn, int], [int, int], T.float32],
):
return copy_impl(A) return copy_impl(A)
for copy in [copy0, copy1, copy2, copy3, copy4]: for copy in [copy0, copy1, copy2, copy3, copy4]:
A = torch.randn(128, 128, device='cuda') A = torch.randn(128, 128, device="cuda")
B = copy(A) B = copy(A)
assert torch.equal(B, A) assert torch.equal(B, A)
for copy in [copy5, copy6]: for copy in [copy5, copy6]:
A = torch.randn(128, 2, 128, 2, device='cuda') A = torch.randn(128, 2, 128, 2, device="cuda")
B = copy(A[:, 0, :, 0]) B = copy(A[:, 0, :, 0])
assert torch.equal(A[:, 0, :, 0], B) assert torch.equal(A[:, 0, :, 0], B)
def test_jit2_deepseek_deepgemm(): def test_jit2_deepseek_deepgemm():
@tilelang.lazy_jit @tilelang.lazy_jit
def deep_gemm( def deep_gemm(
A: T.Tensor[[int, int], T.float8_e4m3], A: T.Tensor[[int, int], T.float8_e4m3],
...@@ -351,13 +348,9 @@ def test_jit2_deepseek_deepgemm(): ...@@ -351,13 +348,9 @@ def test_jit2_deepseek_deepgemm():
N, K = B.shape N, K = B.shape
C = T.empty(M, N, dtype=out_dtype) C = T.empty(M, N, dtype=out_dtype)
assert out_dtype in [ assert out_dtype in [T.bfloat16, T.float32], f"Expect out_dtype to be one of [T.float16, T.float32], got {out_dtype}"
T.bfloat16, T.float32 assert scales_a.shape == [M, T.ceildiv(K, group_size)], f"Expect scales_a shape to be f{[M, T.ceildiv(K, group_size)]}"
], f"Expect out_dtype to be one of [T.float16, T.float32], got {out_dtype}" assert scales_b.shape == [N, T.ceildiv(K, group_size)], f"Expect scales_b shape to be f{[N, T.ceildiv(K, group_size)]}"
assert scales_a.shape == [M, T.ceildiv(K, group_size)
], f"Expect scales_a shape to be f{[M, T.ceildiv(K, group_size)]}"
assert scales_b.shape == [N, T.ceildiv(K, group_size)
], f"Expect scales_b shape to be f{[N, T.ceildiv(K, group_size)]}"
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), in_dtype) A_shared = T.alloc_shared((block_M, block_K), in_dtype)
...@@ -421,5 +414,5 @@ def test_jit2_deepseek_deepgemm(): ...@@ -421,5 +414,5 @@ def test_jit2_deepseek_deepgemm():
# M, N, K = 1024, 1024, 8192 # M, N, K = 1024, 1024, 8192
# A = torch.randn((M, K), dtype=torch.float8_e4m3fn, ) # A = torch.randn((M, K), dtype=torch.float8_e4m3fn, )
if __name__ == '__main__': if __name__ == "__main__":
tilelang.testing.main() tilelang.testing.main()
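For the grouped-scale shapes that `deep_gemm` asserts above: with the commented-out example sizes (M, N, K = 1024, 1024, 8192) and a hypothetical group_size of 128, `T.ceildiv(K, group_size)` is 64, so `scales_a` and `scales_b` would both be 1024 x 64. A quick sketch of that bookkeeping:

```python
M, N, K = 1024, 1024, 8192  # sizes from the commented-out example above
group_size = 128            # assumed; the real value is elided in this hunk

k_groups = (K + group_size - 1) // group_size  # T.ceildiv(K, group_size) for concrete ints
assert k_groups == 64
assert [M, k_groups] == [1024, 64]  # shape deep_gemm asserts for scales_a
assert [N, k_groups] == [1024, 64]  # shape deep_gemm asserts for scales_b
```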
...@@ -4,7 +4,6 @@ from tilelang import language as T ...@@ -4,7 +4,6 @@ from tilelang import language as T
def test_let_vectorize_load(): def test_let_vectorize_load():
@T.prim_func @T.prim_func
def main(A_ptr: T.handle): def main(A_ptr: T.handle):
A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
......
...@@ -6,11 +6,10 @@ import torch ...@@ -6,11 +6,10 @@ import torch
# add decorator @tilelang.jit if you want to return a torch function # add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit # @tilelang.jit
def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"): def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
): ):
# Initialize Kernel Context # Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
...@@ -30,13 +29,8 @@ def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"): ...@@ -30,13 +29,8 @@ def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
program = tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype) program = tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype)
kernel = tilelang.compile( kernel = tilelang.compile(
program, program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
out_idx=[1], )
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
})
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a) b = kernel(a)
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
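All of the run_* helpers in this file compile with the same two pass_configs entries. If keeping them in sync matters, hoisting the dict into a shared constant is a straightforward option; a sketch only, reusing the exact keys and call shape that appear above:

```python
import tilelang

# Shared by every run_* helper in this file, so the four call sites stay identical.
MASK_COPY_PASS_CONFIGS = {
    "tl.disable_warp_specialized": True,
    "tl.disable_tma_lower": True,
}


def compile_mask_copy(program):
    # Mirrors the calls above: the second argument is the output, target is CUDA.
    return tilelang.compile(program, out_idx=[1], target="cuda", pass_configs=MASK_COPY_PASS_CONFIGS)
```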
...@@ -49,11 +43,10 @@ def test_tilelang_copy_mask_parallel(): ...@@ -49,11 +43,10 @@ def test_tilelang_copy_mask_parallel():
# add decorator @tilelang.jit if you want to return a torch function # add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit # @tilelang.jit
def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"): def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
): ):
# Initialize Kernel Context # Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
...@@ -72,13 +65,8 @@ def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"): ...@@ -72,13 +65,8 @@ def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
program = tilelang_copy_mask_copy(M, N, block_M, block_N, dtype) program = tilelang_copy_mask_copy(M, N, block_M, block_N, dtype)
kernel = tilelang.compile( kernel = tilelang.compile(
program, program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
out_idx=[1], )
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
})
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a) b = kernel(a)
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
...@@ -91,11 +79,10 @@ def test_tilelang_copy_mask_copy(): ...@@ -91,11 +79,10 @@ def test_tilelang_copy_mask_copy():
# add decorator @tilelang.jit if you want to return a torch function # add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit # @tilelang.jit
def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"): def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
): ):
# Initialize Kernel Context # Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
...@@ -112,20 +99,11 @@ def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"): ...@@ -112,20 +99,11 @@ def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
return main return main
def run_tilelang_copy_mask_parallel_range(M=1024, def run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
N=1024,
block_M=128,
block_N=128,
dtype="float16"):
program = tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype) program = tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype)
kernel = tilelang.compile( kernel = tilelang.compile(
program, program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
out_idx=[1], )
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
})
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a) b = kernel(a)
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
...@@ -138,11 +116,10 @@ def test_tilelang_copy_mask_parallel_range(): ...@@ -138,11 +116,10 @@ def test_tilelang_copy_mask_parallel_range():
# add decorator @tilelang.jit if you want to return a torch function # add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit # @tilelang.jit
def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"): def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
): ):
# Initialize Kernel Context # Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
...@@ -161,13 +138,8 @@ def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"): ...@@ -161,13 +138,8 @@ def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
program = tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype) program = tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype)
kernel = tilelang.compile( kernel = tilelang.compile(
program, program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
out_idx=[1], )
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
})
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a) b = kernel(a)
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2) torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
......