Unverified Commit 29051439 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
...@@ -4,10 +4,7 @@ from tilelang.carver.arch import auto_infer_current_arch ...@@ -4,10 +4,7 @@ from tilelang.carver.arch import auto_infer_current_arch
from typing import List from typing import List
def run_general_reduction_recommend_hints(structure: str = "SSR", def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[int] = None, dtype: str = "float16", topk: int = 20):
shape: List[int] = None,
dtype: str = "float16",
topk: int = 20):
arch = auto_infer_current_arch() arch = auto_infer_current_arch()
carve_template = carver.GeneralReductionTemplate( carve_template = carver.GeneralReductionTemplate(
structure=structure, structure=structure,
...@@ -28,9 +25,7 @@ def test_general_reduction_recommend_hints(): ...@@ -28,9 +25,7 @@ def test_general_reduction_recommend_hints():
run_general_reduction_recommend_hints("SRS", [1024, 1024, 1024], "float16") run_general_reduction_recommend_hints("SRS", [1024, 1024, 1024], "float16")
def run_elementwise_recommend_hints(shape: List[int] = None, def run_elementwise_recommend_hints(shape: List[int] = None, dtype: str = "float16", topk: int = 20):
dtype: str = "float16",
topk: int = 20):
arch = auto_infer_current_arch() arch = auto_infer_current_arch()
carve_template = carver.ElementwiseTemplate( carve_template = carver.ElementwiseTemplate(
shape=shape, shape=shape,
...@@ -81,11 +76,9 @@ def test_matmul_recommend_hints(): ...@@ -81,11 +76,9 @@ def test_matmul_recommend_hints():
run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float32", "float16") run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float32", "float16")
def run_gemv_recommend_hints(N: int = 1024, def run_gemv_recommend_hints(
K: int = 1024, N: int = 1024, K: int = 1024, in_dtype: str = "float16", out_dtype: str = "float16", accum_dtype: str = "float16"
in_dtype: str = "float16", ):
out_dtype: str = "float16",
accum_dtype: str = "float16"):
arch = auto_infer_current_arch() arch = auto_infer_current_arch()
carve_template = carver.GEMVTemplate( carve_template = carver.GEMVTemplate(
N=N, N=N,
......
...@@ -23,7 +23,8 @@ def _compile_kernel_without_inplace(): ...@@ -23,7 +23,8 @@ def _compile_kernel_without_inplace():
@tilelang.jit( @tilelang.jit(
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE: True, tilelang.PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE: True,
},) },
)
def _compile_kernel_with_inplace(): def _compile_kernel_with_inplace():
num_tokens = T.symbolic("num_tokens") num_tokens = T.symbolic("num_tokens")
......
...@@ -26,9 +26,9 @@ def matmul( ...@@ -26,9 +26,9 @@ def matmul(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -88,7 +88,8 @@ def run_gemm( ...@@ -88,7 +88,8 @@ def run_gemm(
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: disable_warp_specialized, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: disable_warp_specialized,
}) },
)
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
def ref_program(A, B): def ref_program(A, B):
......
...@@ -10,9 +10,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo ...@@ -10,9 +10,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
@T.prim_func @T.prim_func
def matmul( def matmul(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by):
A_local = T.alloc_local((block_M, block_K), dtype) A_local = T.alloc_local((block_M, block_K), dtype)
...@@ -31,7 +31,6 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo ...@@ -31,7 +31,6 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
# ) # )
for ko in T.Pipelined(K // block_K, num_stages=num_stages): for ko in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, ko * block_K], A_local) T.copy(A[by * block_M, ko * block_K], A_local)
# Or Copy with Parallel # Or Copy with Parallel
...@@ -62,14 +61,13 @@ def test_matmul_codegen(): ...@@ -62,14 +61,13 @@ def test_matmul_codegen():
def test_matmul_compile(): def test_matmul_compile():
def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
# a simple kernel just for jit test # a simple kernel just for jit test
@T.prim_func @T.prim_func
def matmul( def matmul(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by):
A_local = T.alloc_local((block_M, block_K), dtype) A_local = T.alloc_local((block_M, block_K), dtype)
......
...@@ -7,7 +7,6 @@ import tilelang.language as T ...@@ -7,7 +7,6 @@ import tilelang.language as T
# TODO(dyq) It intentionally triggers a device-side assert so we can't include this in CI # TODO(dyq) It intentionally triggers a device-side assert so we can't include this in CI
# Please run manually when you want to verify that device_assert actually traps on GPU. # Please run manually when you want to verify that device_assert actually traps on GPU.
def _manual_device_assert_triggered(): def _manual_device_assert_triggered():
@T.prim_func @T.prim_func
def program(): def program():
with T.Kernel(threads=128): with T.Kernel(threads=128):
...@@ -20,7 +19,6 @@ def _manual_device_assert_triggered(): ...@@ -20,7 +19,6 @@ def _manual_device_assert_triggered():
def test_device_assert_no_trigger(): def test_device_assert_no_trigger():
@T.prim_func @T.prim_func
def program(): def program():
with T.Kernel(threads=128): with T.Kernel(threads=128):
......
...@@ -6,7 +6,6 @@ import tilelang.language as T ...@@ -6,7 +6,6 @@ import tilelang.language as T
def debug_print_buffer(M=16, N=16, dtype="float16"): def debug_print_buffer(M=16, N=16, dtype="float16"):
@T.prim_func @T.prim_func
def program(Q: T.Tensor((M, N), dtype)): def program(Q: T.Tensor((M, N), dtype)):
with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz): with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
...@@ -19,24 +18,24 @@ def debug_print_buffer(M=16, N=16, dtype="float16"): ...@@ -19,24 +18,24 @@ def debug_print_buffer(M=16, N=16, dtype="float16"):
def test_debug_print_buffer(): def test_debug_print_buffer():
debug_print_buffer(dtype='bool') debug_print_buffer(dtype="bool")
debug_print_buffer(dtype='int8') debug_print_buffer(dtype="int8")
debug_print_buffer(dtype='int16') debug_print_buffer(dtype="int16")
debug_print_buffer(dtype='int32') debug_print_buffer(dtype="int32")
debug_print_buffer(dtype='int64') debug_print_buffer(dtype="int64")
debug_print_buffer(dtype='uint8') debug_print_buffer(dtype="uint8")
debug_print_buffer(dtype='uint16') debug_print_buffer(dtype="uint16")
debug_print_buffer(dtype='uint32') debug_print_buffer(dtype="uint32")
debug_print_buffer(dtype='uint64') debug_print_buffer(dtype="uint64")
debug_print_buffer(dtype='float16') debug_print_buffer(dtype="float16")
debug_print_buffer(dtype='float32') debug_print_buffer(dtype="float32")
debug_print_buffer(dtype='float64') debug_print_buffer(dtype="float64")
debug_print_buffer(dtype='bfloat16') debug_print_buffer(dtype="bfloat16")
debug_print_buffer(dtype='float8_e4m3') debug_print_buffer(dtype="float8_e4m3")
debug_print_buffer(dtype='float8_e4m3fn') debug_print_buffer(dtype="float8_e4m3fn")
debug_print_buffer(dtype='float8_e4m3fnuz') debug_print_buffer(dtype="float8_e4m3fnuz")
debug_print_buffer(dtype='float8_e5m2') debug_print_buffer(dtype="float8_e5m2")
debug_print_buffer(dtype='float8_e5m2fnuz') debug_print_buffer(dtype="float8_e5m2fnuz")
def debug_print_buffer_conditional(M=16, N=16): def debug_print_buffer_conditional(M=16, N=16):
......
...@@ -5,7 +5,7 @@ import tilelang.testing ...@@ -5,7 +5,7 @@ import tilelang.testing
from tvm import DataType from tvm import DataType
import tilelang.language as T import tilelang.language as T
from tilelang.intrinsics.utils import get_swizzle_layout from tilelang.intrinsics.utils import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (TensorCoreIntrinEmitter) from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter
tilelang.testing.set_random_seed(0) tilelang.testing.set_random_seed(0)
...@@ -96,12 +96,11 @@ def tl_matmul_macro( ...@@ -96,12 +96,11 @@ def tl_matmul_macro(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
...@@ -109,10 +108,12 @@ def tl_matmul_macro( ...@@ -109,10 +108,12 @@ def tl_matmul_macro(
B_local = T.alloc_local((warp_cols * local_size), in_dtype) B_local = T.alloc_local((warp_cols * local_size), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
T.annotate_layout({ T.annotate_layout(
A_shared: make_swizzle_layout(A_shared), {
B_shared: make_swizzle_layout(B_shared), A_shared: make_swizzle_layout(A_shared),
}) B_shared: make_swizzle_layout(B_shared),
}
)
# Improve L2 Cache # Improve L2 Cache
T.use_swizzle(panel_size=10) T.use_swizzle(panel_size=10)
...@@ -120,7 +121,6 @@ def tl_matmul_macro( ...@@ -120,7 +121,6 @@ def tl_matmul_macro(
T.clear(C_local) T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage): for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory # Load A into shared memory
for i, k in T.Parallel(block_M, block_K): for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k] A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
...@@ -130,7 +130,6 @@ def tl_matmul_macro( ...@@ -130,7 +130,6 @@ def tl_matmul_macro(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)): for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment # Load A into fragment
mma_emitter.ldmatrix_a( mma_emitter.ldmatrix_a(
A_local, A_local,
...@@ -207,8 +206,7 @@ def tl_matmul_block( ...@@ -207,8 +206,7 @@ def tl_matmul_block(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func @T.prim_func
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
(M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype)
...@@ -306,8 +304,7 @@ def tl_matmul_block_all_dynamic( ...@@ -306,8 +304,7 @@ def tl_matmul_block_all_dynamic(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func @T.prim_func
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
(M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype)
...@@ -417,7 +414,7 @@ def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( ...@@ -417,7 +414,7 @@ def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
) )
pass_configs = { pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_DYNAMIC_TAIL_SPLIT: dynamic_alignment != 0, tilelang.PassConfigKey.TL_DISABLE_DYNAMIC_TAIL_SPLIT: dynamic_alignment != 0,
tilelang.PassConfigKey.TL_DYNAMIC_ALIGNMENT: dynamic_alignment tilelang.PassConfigKey.TL_DYNAMIC_ALIGNMENT: dynamic_alignment,
} }
if M % 64 == 0 or N % 64 == 0 or K % 64 != 0: if M % 64 == 0 or N % 64 == 0 or K % 64 != 0:
# workaround for hopper tma lower pass # workaround for hopper tma lower pass
...@@ -462,55 +459,31 @@ def test_assert_tl_matmul_macro(): ...@@ -462,55 +459,31 @@ def test_assert_tl_matmul_macro():
def test_assert_tl_matmul_block(): def test_assert_tl_matmul_block():
assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16", "float16", assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
64, 64, 32) assert_tl_matmul_block_correctness(67, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
assert_tl_matmul_block_correctness(67, 128, 128, False, False, "float16", "float16", "float16", assert_tl_matmul_block_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
64, 64, 32)
assert_tl_matmul_block_correctness(36, 128, 128, False, False, "float16", "float16", "float16",
64, 64, 32)
def test_assert_tl_matmul_block_all_dynamic(): def test_assert_tl_matmul_block_all_dynamic():
assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16", assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
"float16", "float16", 64, 64, 32) assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16", assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
"float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16",
"float16", 64, 64, 32)
def test_assert_tl_matmul_block_all_dynamic_with_pass_config(): def test_assert_tl_matmul_block_all_dynamic_with_pass_config():
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
128, 128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=8
128, )
128,
False,
False,
"float16",
"float16",
"float16",
64,
64,
32,
dynamic_alignment=8)
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
64, 64, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=8
128, )
128,
False,
False,
"float16",
"float16",
"float16",
64,
64,
32,
dynamic_alignment=8)
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
64, 128, 60, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=4) 64, 128, 60, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=4
)
# Tail split is enabled with dynamic alignment 0 # Tail split is enabled with dynamic alignment 0
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
64, 128, 64, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=0) 64, 128, 64, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=0
)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -25,10 +25,8 @@ def tl_matmul_block_static( ...@@ -25,10 +25,8 @@ def tl_matmul_block_static(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func @T.prim_func
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
(M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -137,10 +135,8 @@ def tl_matmul_block_dynamic_m( ...@@ -137,10 +135,8 @@ def tl_matmul_block_dynamic_m(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func @T.prim_func
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
(M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -247,10 +243,8 @@ def tl_matmul_block_dynamic_mn( ...@@ -247,10 +243,8 @@ def tl_matmul_block_dynamic_mn(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func @T.prim_func
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
(M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -357,10 +351,8 @@ def tl_matmul_block_dynamic_mnk( ...@@ -357,10 +351,8 @@ def tl_matmul_block_dynamic_mnk(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func @T.prim_func
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor( def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
(M, N), out_dtype)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -445,8 +437,7 @@ def assert_tl_matmul_block_dynamic_mnk( ...@@ -445,8 +437,7 @@ def assert_tl_matmul_block_dynamic_mnk(
def run_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K): def run_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16", assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16", "float16", "float32")
"float16", "float32")
def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
...@@ -462,10 +453,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): ...@@ -462,10 +453,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
"float16", "float16",
"float16", "float16",
"float32", "float32",
pass_configs={ pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8},
"tl.disable_dynamic_tail_split": True, )
"tl.dynamic_alignment": 8
})
assert_tl_matmul_block_dynamic_m( assert_tl_matmul_block_dynamic_m(
M, M,
N, N,
...@@ -478,7 +467,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): ...@@ -478,7 +467,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
"float16", "float16",
"float16", "float16",
"float32", "float32",
pass_configs={"tl.disable_dynamic_tail_split": False}) pass_configs={"tl.disable_dynamic_tail_split": False},
)
def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
...@@ -494,10 +484,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): ...@@ -494,10 +484,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
"float16", "float16",
"float16", "float16",
"float32", "float32",
pass_configs={ pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8},
"tl.disable_dynamic_tail_split": True, )
"tl.dynamic_alignment": 8
})
assert_tl_matmul_block_dynamic_mn( assert_tl_matmul_block_dynamic_mn(
M, M,
N, N,
...@@ -510,7 +498,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): ...@@ -510,7 +498,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
"float16", "float16",
"float16", "float16",
"float32", "float32",
pass_configs={"tl.disable_dynamic_tail_split": False}) pass_configs={"tl.disable_dynamic_tail_split": False},
)
def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
...@@ -526,10 +515,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): ...@@ -526,10 +515,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
"float16", "float16",
"float16", "float16",
"float32", "float32",
pass_configs={ pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 4},
"tl.disable_dynamic_tail_split": True, )
"tl.dynamic_alignment": 4
})
assert_tl_matmul_block_dynamic_mnk( assert_tl_matmul_block_dynamic_mnk(
M, M,
N, N,
...@@ -542,7 +529,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): ...@@ -542,7 +529,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
"float16", "float16",
"float16", "float16",
"float32", "float32",
pass_configs={"tl.disable_dynamic_tail_split": False}) pass_configs={"tl.disable_dynamic_tail_split": False},
)
def test_all(): def test_all():
......
...@@ -7,16 +7,16 @@ import re ...@@ -7,16 +7,16 @@ import re
def get_mathop_lines(source, mathop_name): def get_mathop_lines(source, mathop_name):
"""Extract lines containing the mathop from CUDA source for debugging""" """Extract lines containing the mathop from CUDA source for debugging"""
lines = source.split('\n') lines = source.split("\n")
relevant_lines = [] relevant_lines = []
for i, line in enumerate(lines): for i, line in enumerate(lines):
if mathop_name in line and ('(' in line): if mathop_name in line and ("(" in line):
# Include some context # Include some context
start = max(0, i - 1) start = max(0, i - 1)
end = min(len(lines), i + 2) end = min(len(lines), i + 2)
relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)]) relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)])
relevant_lines.append("---") relevant_lines.append("---")
return '\n'.join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output return "\n".join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output
def check_fastmath_usage(source, mathop_name, expect_fastmath=False): def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
...@@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False): ...@@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
fastmath_matches = re.findall(fastmath_pattern, source) fastmath_matches = re.findall(fastmath_pattern, source)
non_fastmath_matches = re.findall(non_fastmath_pattern, source) non_fastmath_matches = re.findall(non_fastmath_pattern, source)
print( print(f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls")
f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls"
)
if len(fastmath_matches) > 0: if len(fastmath_matches) > 0:
print(f"Fastmath calls found: {fastmath_matches}") print(f"Fastmath calls found: {fastmath_matches}")
if len(non_fastmath_matches) > 0: if len(non_fastmath_matches) > 0:
...@@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name): ...@@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name):
check_fastmath_usage(source, mathop_name, expect_fastmath=False) check_fastmath_usage(source, mathop_name, expect_fastmath=False)
def run_single_arg_mathop_test(mathop_name, def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
""" """
Test single-argument mathops. Test single-argument mathops.
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath) T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
...@@ -65,13 +57,12 @@ def run_single_arg_mathop_test(mathop_name, ...@@ -65,13 +57,12 @@ def run_single_arg_mathop_test(mathop_name,
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j])
bx * block_N + j])
# Test with FAST_MATH disabled # Test with FAST_MATH disabled
kernel_no_fastmath = tilelang.compile( kernel_no_fastmath = tilelang.compile(
...@@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name, ...@@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name,
target="cuda", target="cuda",
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
}) },
)
source_no_fastmath = kernel_no_fastmath.get_kernel_source() source_no_fastmath = kernel_no_fastmath.get_kernel_source()
...@@ -93,28 +85,22 @@ def run_single_arg_mathop_test(mathop_name, ...@@ -93,28 +85,22 @@ def run_single_arg_mathop_test(mathop_name,
print(f"✓ {mathop_name} compilation and execution test passed") print(f"✓ {mathop_name} compilation and execution test passed")
def run_two_arg_mathop_test(mathop_name, def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
""" """
Test two-argument mathops to ensure they generate non-fastmath CUDA code. Test two-argument mathops to ensure they generate non-fastmath CUDA code.
""" """
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, C[by * block_M + i, bx * block_N + j] = mathop_func(
bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j]
B[by * block_M + i, bx * block_N + j]) )
# Test with FAST_MATH disabled # Test with FAST_MATH disabled
kernel_no_fastmath = tilelang.compile( kernel_no_fastmath = tilelang.compile(
...@@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name, ...@@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name,
target="cuda", target="cuda",
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
}) },
)
# Test with FAST_MATH enabled # Test with FAST_MATH enabled
kernel_fastmath = tilelang.compile( kernel_fastmath = tilelang.compile(
...@@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name, ...@@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name,
target="cuda", target="cuda",
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
source_no_fastmath = kernel_no_fastmath.get_kernel_source() source_no_fastmath = kernel_no_fastmath.get_kernel_source()
source_fastmath = kernel_fastmath.get_kernel_source() source_fastmath = kernel_fastmath.get_kernel_source()
...@@ -171,8 +159,8 @@ def run_abs_test(): ...@@ -171,8 +159,8 @@ def run_abs_test():
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), "float32"), A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"), B: T.Tensor((M, N), "float32"),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
...@@ -184,7 +172,8 @@ def run_abs_test(): ...@@ -184,7 +172,8 @@ def run_abs_test():
target="cuda", target="cuda",
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
}) },
)
source = kernel.get_kernel_source() source = kernel.get_kernel_source()
print("\n=== Testing abs (maps to fabs) ===") print("\n=== Testing abs (maps to fabs) ===")
...@@ -199,26 +188,19 @@ def run_abs_test(): ...@@ -199,26 +188,19 @@ def run_abs_test():
print("✓ abs numerical test passed") print("✓ abs numerical test passed")
def run_fastmath_mathop_test(mathop_name, def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
""" """
Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix). Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
""" """
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j])
bx * block_N + j])
# Test with FAST_MATH enabled # Test with FAST_MATH enabled
kernel_fastmath = tilelang.compile( kernel_fastmath = tilelang.compile(
...@@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name, ...@@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name,
target="cuda", target="cuda",
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
source_fastmath = kernel_fastmath.get_kernel_source() source_fastmath = kernel_fastmath.get_kernel_source()
print(f"\n=== Testing {mathop_name} (fastmath version) ===") print(f"\n=== Testing {mathop_name} (fastmath version) ===")
print("FAST_MATH=True:") print("FAST_MATH=True:")
# Strip the __ prefix for checking in the CUDA source # Strip the __ prefix for checking in the CUDA source
cuda_mathop_name = mathop_name.lstrip('_') cuda_mathop_name = mathop_name.lstrip("_")
check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True) check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True)
# Test numerical correctness # Test numerical correctness
......
...@@ -8,14 +8,15 @@ from tilelang import language as T ...@@ -8,14 +8,15 @@ from tilelang import language as T
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},) },
)
def _cumsum_view_infer_layout(hidden): def _cumsum_view_infer_layout(hidden):
num_tokens = T.dynamic('num_tokens') num_tokens = T.dynamic("num_tokens")
@T.prim_func @T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']): def buggy_kernel(x: T.Tensor[(num_tokens, hidden), "float"]):
with T.Kernel(num_tokens, threads=128) as pid: with T.Kernel(num_tokens, threads=128) as pid:
smem = T.alloc_shared((hidden,), dtype='float') smem = T.alloc_shared((hidden,), dtype="float")
T.copy(x[pid, :], smem) T.copy(x[pid, :], smem)
T.cumsum(T.view(smem, (1, hidden)), dim=1) T.cumsum(T.view(smem, (1, hidden)), dim=1)
...@@ -24,10 +25,10 @@ def _cumsum_view_infer_layout(hidden): ...@@ -24,10 +25,10 @@ def _cumsum_view_infer_layout(hidden):
def test_cumsum_view_infer_layout(): def test_cumsum_view_infer_layout():
hidden = 128 hidden = 128
x = torch.randn(1, hidden, device='cuda', dtype=torch.float) x = torch.randn(1, hidden, device="cuda", dtype=torch.float)
kernel = _cumsum_view_infer_layout(hidden) kernel = _cumsum_view_infer_layout(hidden)
kernel(x) kernel(x)
if __name__ == '__main__': if __name__ == "__main__":
tilelang.testing.main() tilelang.testing.main()
...@@ -8,12 +8,13 @@ from tilelang import language as T ...@@ -8,12 +8,13 @@ from tilelang import language as T
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},) },
)
def _fill_with_static_region_kernel(): def _fill_with_static_region_kernel():
num_tokens = T.symbolic('num_tokens') num_tokens = T.symbolic("num_tokens")
@T.prim_func @T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']): # noqa: F821 def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821
with T.Kernel(num_tokens, threads=128) as _: with T.Kernel(num_tokens, threads=128) as _:
T.fill(x[0:128], 0) T.fill(x[0:128], 0)
...@@ -24,14 +25,15 @@ def _fill_with_static_region_kernel(): ...@@ -24,14 +25,15 @@ def _fill_with_static_region_kernel():
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},) },
)
def _fill_with_dynamic_region_kernel(): def _fill_with_dynamic_region_kernel():
num_tokens = T.symbolic('num_tokens') num_tokens = T.symbolic("num_tokens")
@T.prim_func @T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']): # noqa: F821 def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821
with T.Kernel(num_tokens, threads=128) as _: with T.Kernel(num_tokens, threads=128) as _:
a, b = T.alloc_var('int'), T.alloc_var('int') a, b = T.alloc_var("int"), T.alloc_var("int")
T.fill(x[a:b], 0) T.fill(x[a:b], 0)
return buggy_kernel return buggy_kernel
...@@ -39,15 +41,15 @@ def _fill_with_dynamic_region_kernel(): ...@@ -39,15 +41,15 @@ def _fill_with_dynamic_region_kernel():
def test_fill_with_static_region_kernel(): def test_fill_with_static_region_kernel():
kernel = _fill_with_static_region_kernel() kernel = _fill_with_static_region_kernel()
x = torch.zeros((256,), dtype=torch.int64, device='cuda') x = torch.zeros((256,), dtype=torch.int64, device="cuda")
kernel(x) kernel(x)
def test_fill_with_dynamic_region_kernel(): def test_fill_with_dynamic_region_kernel():
kernel = _fill_with_dynamic_region_kernel() kernel = _fill_with_dynamic_region_kernel()
x = torch.zeros((256,), dtype=torch.int64, device='cuda') x = torch.zeros((256,), dtype=torch.int64, device="cuda")
kernel(x) kernel(x)
if __name__ == '__main__': if __name__ == "__main__":
tilelang.testing.main() tilelang.testing.main()
...@@ -4,25 +4,23 @@ import tilelang.language as T ...@@ -4,25 +4,23 @@ import tilelang.language as T
def test_int64_address(): def test_int64_address():
@tilelang.jit @tilelang.jit
def set_cache_kernel( def set_cache_kernel(
S, S,
D, D,
pos_ty='int64', pos_ty="int64",
dtype="float32", dtype="float32",
): ):
@T.prim_func @T.prim_func
def main( def main(
pos: T pos: T.Tensor(
.Tensor(
[ [
S, S,
], pos_ty ],
pos_ty,
), # type: ignore `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32` ), # type: ignore `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32`
value: T.Tensor([S, D], dtype), # type: ignore value: T.Tensor([S, D], dtype), # type: ignore
cache: T.Tensor([S, D], dtype), # type: ignore cache: T.Tensor([S, D], dtype), # type: ignore
): ):
with T.Kernel(S, threads=128) as bx: with T.Kernel(S, threads=128) as bx:
slot = pos[bx] slot = pos[bx]
...@@ -34,11 +32,11 @@ def test_int64_address(): ...@@ -34,11 +32,11 @@ def test_int64_address():
D = 2 D = 2
S = 10 S = 10
cache = torch.rand((S, D), device="cuda", dtype=torch.float32) cache = torch.rand((S, D), device="cuda", dtype=torch.float32)
value = torch.rand((S, D), device='cuda', dtype=torch.float32) value = torch.rand((S, D), device="cuda", dtype=torch.float32)
pos_int64 = torch.arange(S, device='cuda', dtype=torch.int64) pos_int64 = torch.arange(S, device="cuda", dtype=torch.int64)
pos_int32 = torch.arange(S, device='cuda', dtype=torch.int32) pos_int32 = torch.arange(S, device="cuda", dtype=torch.int32)
kernel_int64 = set_cache_kernel(S, D, 'int64') kernel_int64 = set_cache_kernel(S, D, "int64")
kernel_int32 = set_cache_kernel(S, D, 'int32') kernel_int32 = set_cache_kernel(S, D, "int32")
kernel_int64(pos_int64, value, cache) kernel_int64(pos_int64, value, cache)
torch.testing.assert_close(cache, value) torch.testing.assert_close(cache, value)
kernel_int32(pos_int32, value, cache) kernel_int32(pos_int32, value, cache)
......
...@@ -3,13 +3,17 @@ import tilelang.language as T ...@@ -3,13 +3,17 @@ import tilelang.language as T
def test_issue_1198(): def test_issue_1198():
@T.prim_func @T.prim_func
def foo(x: T.Buffer([ def foo(
32, x: T.Buffer(
], "int32")): [
32,
],
"int32",
),
):
pass pass
if __name__ == '__main__': if __name__ == "__main__":
tilelang.testing.main() tilelang.testing.main()
...@@ -6,11 +6,10 @@ import torch ...@@ -6,11 +6,10 @@ import torch
@tilelang.jit @tilelang.jit
def _tmp_var_kernel(N, block_N, dtype="float"): def _tmp_var_kernel(N, block_N, dtype="float"):
@T.prim_func @T.prim_func
def kernel( def kernel(
A: T.Tensor((N,), dtype), A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype), B: T.Tensor((N,), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx: with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:
for i in T.Parallel(block_N): for i in T.Parallel(block_N):
......
...@@ -8,7 +8,6 @@ import tilelang.language as T ...@@ -8,7 +8,6 @@ import tilelang.language as T
@tilelang.jit @tilelang.jit
def _empty_kernel(): def _empty_kernel():
@T.prim_func @T.prim_func
def empty_kernel(): def empty_kernel():
with T.Kernel(1, threads=32) as thread_idx: with T.Kernel(1, threads=32) as thread_idx:
...@@ -51,7 +50,6 @@ def test_empty_with_dead_code_kernel(): ...@@ -51,7 +50,6 @@ def test_empty_with_dead_code_kernel():
@tilelang.jit @tilelang.jit
def _empty_kernel_with_binding_variants(use_tuple_binding: bool = False): def _empty_kernel_with_binding_variants(use_tuple_binding: bool = False):
@T.prim_func @T.prim_func
def kernel_with_tuple_kernel_binding(): def kernel_with_tuple_kernel_binding():
with T.Kernel(1, threads=32) as (pid,): with T.Kernel(1, threads=32) as (pid,):
......
...@@ -5,18 +5,16 @@ import torch ...@@ -5,18 +5,16 @@ import torch
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as ( bx,
bx, by,
by, ):
):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......
...@@ -6,7 +6,6 @@ import tilelang.language as T ...@@ -6,7 +6,6 @@ import tilelang.language as T
def merge_if_test(): def merge_if_test():
@T.prim_func @T.prim_func
def main(): def main():
A = T.alloc_fragment((1,), "float16") A = T.alloc_fragment((1,), "float16")
......
...@@ -29,9 +29,9 @@ def matmul( ...@@ -29,9 +29,9 @@ def matmul(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -141,9 +141,9 @@ def matmu_jit_kernel( ...@@ -141,9 +141,9 @@ def matmu_jit_kernel(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -208,6 +208,7 @@ def run_gemm_jit_kernel( ...@@ -208,6 +208,7 @@ def run_gemm_jit_kernel(
def ref_program(A, B): def ref_program(A, B):
import torch import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype)) C = C.to(torch.__getattribute__(out_dtype))
return C return C
......
...@@ -31,9 +31,9 @@ def matmul_kernel_jit( ...@@ -31,9 +31,9 @@ def matmul_kernel_jit(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -96,6 +96,7 @@ def run_gemm_kernel_jit( ...@@ -96,6 +96,7 @@ def run_gemm_kernel_jit(
def ref_program(A, B): def ref_program(A, B):
import torch import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype)) C = C.to(torch.__getattribute__(out_dtype))
return C return C
......
...@@ -28,9 +28,9 @@ def matmul( ...@@ -28,9 +28,9 @@ def matmul(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -138,9 +138,9 @@ def matmu_jit_kernel( ...@@ -138,9 +138,9 @@ def matmu_jit_kernel(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -208,6 +208,7 @@ def run_gemm_jit_kernel( ...@@ -208,6 +208,7 @@ def run_gemm_jit_kernel(
def ref_program(A, B): def ref_program(A, B):
import torch import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(out_dtype) C = C.to(out_dtype)
return C return C
...@@ -235,19 +236,9 @@ def test_gemm_jit_kernel(): ...@@ -235,19 +236,9 @@ def test_gemm_jit_kernel():
) )
def run_cython_kernel_do_bench(M, def run_cython_kernel_do_bench(
N, M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
K, ):
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul( program = matmul(
M, M,
N, N,
...@@ -287,23 +278,12 @@ def run_cython_kernel_do_bench(M, ...@@ -287,23 +278,12 @@ def run_cython_kernel_do_bench(M,
def test_cython_kernel_do_bench(): def test_cython_kernel_do_bench():
run_cython_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, run_cython_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
256, 32, 2)
def run_cython_kernel_multi_stream(
def run_cython_kernel_multi_stream(M, M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
N, ):
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul( program = matmul(
M, M,
N, N,
...@@ -342,23 +322,12 @@ def run_cython_kernel_multi_stream(M, ...@@ -342,23 +322,12 @@ def run_cython_kernel_multi_stream(M,
def test_cython_kernel_multi_stream(): def test_cython_kernel_multi_stream():
run_cython_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", run_cython_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
128, 256, 32, 2)
def run_cython_dynamic_shape(
def run_cython_dynamic_shape(M, M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
N, ):
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul( program = matmul(
M, M,
N, N,
...@@ -398,36 +367,20 @@ def run_cython_dynamic_shape(M, ...@@ -398,36 +367,20 @@ def run_cython_dynamic_shape(M,
matmul_kernel(tensor_a, tensor_b, tensor_c) matmul_kernel(tensor_a, tensor_b, tensor_c)
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
tilelang.testing.torch_assert_close( tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_cython_dynamic_shape(): def test_cython_dynamic_shape():
run_cython_dynamic_shape( run_cython_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cython_dynamic_shape(
T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2)
256, 32, 2)
run_cython_dynamic_shape( def run_cython_dynamic_shape_with_out_idx(
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
"float16", 128, 256, 32, 2) ):
def run_cython_dynamic_shape_with_out_idx(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul( program = matmul(
M, M,
N, N,
...@@ -467,13 +420,11 @@ def run_cython_dynamic_shape_with_out_idx(M, ...@@ -467,13 +420,11 @@ def run_cython_dynamic_shape_with_out_idx(M,
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
tilelang.testing.torch_assert_close( tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_cython_dynamic_shape_with_out_idx(): def test_cython_dynamic_shape_with_out_idx():
run_cython_dynamic_shape_with_out_idx( run_cython_dynamic_shape_with_out_idx(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def matmul_int_variable( def matmul_int_variable(
...@@ -498,10 +449,10 @@ def matmul_int_variable( ...@@ -498,10 +449,10 @@ def matmul_int_variable(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
offset: T.int32, offset: T.int32,
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -525,10 +476,10 @@ def matmul_int_variable( ...@@ -525,10 +476,10 @@ def matmul_int_variable(
return main return main
def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads):
out_dtype, dtypeAccum, num_stages, threads): program = matmul_int_variable(
program = matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads
out_dtype, dtypeAccum, num_stages, threads) )
matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2) matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2)
in_dtype = map_torch_type(in_dtype) in_dtype = map_torch_type(in_dtype)
...@@ -544,8 +495,7 @@ def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B ...@@ -544,8 +495,7 @@ def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B
def test_matmul_int_variable(): def test_matmul_int_variable():
run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", "float32", 0, 128)
"float32", 0, 128)
def matmul_float_variable( def matmul_float_variable(
...@@ -570,10 +520,10 @@ def matmul_float_variable( ...@@ -570,10 +520,10 @@ def matmul_float_variable(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
offset: T.float32, offset: T.float32,
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -597,10 +547,10 @@ def matmul_float_variable( ...@@ -597,10 +547,10 @@ def matmul_float_variable(
return main return main
def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads):
out_dtype, dtypeAccum, num_stages, threads): program = matmul_float_variable(
program = matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads
out_dtype, dtypeAccum, num_stages, threads) )
matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2) matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2)
in_dtype = map_torch_type(in_dtype) in_dtype = map_torch_type(in_dtype)
...@@ -616,8 +566,7 @@ def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans ...@@ -616,8 +566,7 @@ def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans
def test_matmul_float_variable(): def test_matmul_float_variable():
run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", "float32", 0, 128)
"float32", 0, 128)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment