Unverified Commit 29051439 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
......@@ -4,10 +4,7 @@ from tilelang.carver.arch import auto_infer_current_arch
from typing import List
def run_general_reduction_recommend_hints(structure: str = "SSR",
shape: List[int] = None,
dtype: str = "float16",
topk: int = 20):
def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[int] = None, dtype: str = "float16", topk: int = 20):
arch = auto_infer_current_arch()
carve_template = carver.GeneralReductionTemplate(
structure=structure,
......@@ -28,9 +25,7 @@ def test_general_reduction_recommend_hints():
run_general_reduction_recommend_hints("SRS", [1024, 1024, 1024], "float16")
def run_elementwise_recommend_hints(shape: List[int] = None,
dtype: str = "float16",
topk: int = 20):
def run_elementwise_recommend_hints(shape: List[int] = None, dtype: str = "float16", topk: int = 20):
arch = auto_infer_current_arch()
carve_template = carver.ElementwiseTemplate(
shape=shape,
......@@ -81,11 +76,9 @@ def test_matmul_recommend_hints():
run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float32", "float16")
def run_gemv_recommend_hints(N: int = 1024,
K: int = 1024,
in_dtype: str = "float16",
out_dtype: str = "float16",
accum_dtype: str = "float16"):
def run_gemv_recommend_hints(
N: int = 1024, K: int = 1024, in_dtype: str = "float16", out_dtype: str = "float16", accum_dtype: str = "float16"
):
arch = auto_infer_current_arch()
carve_template = carver.GEMVTemplate(
N=N,
......
......@@ -23,7 +23,8 @@ def _compile_kernel_without_inplace():
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE: True,
},)
},
)
def _compile_kernel_with_inplace():
num_tokens = T.symbolic("num_tokens")
......
......@@ -26,9 +26,9 @@ def matmul(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -88,7 +88,8 @@ def run_gemm(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: disable_warp_specialized,
})
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
......
......@@ -10,9 +10,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
@T.prim_func
def matmul(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by):
A_local = T.alloc_local((block_M, block_K), dtype)
......@@ -31,7 +31,6 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
# )
for ko in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, ko * block_K], A_local)
# Or Copy with Parallel
......@@ -62,14 +61,13 @@ def test_matmul_codegen():
def test_matmul_compile():
def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
# a simple kernel just for jit test
@T.prim_func
def matmul(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by):
A_local = T.alloc_local((block_M, block_K), dtype)
......
......@@ -7,7 +7,6 @@ import tilelang.language as T
# TODO(dyq) It intentionally triggers a device-side assert so we can't include this in CI
# Please run manually when you want to verify that device_assert actually traps on GPU.
def _manual_device_assert_triggered():
@T.prim_func
def program():
with T.Kernel(threads=128):
......@@ -20,7 +19,6 @@ def _manual_device_assert_triggered():
def test_device_assert_no_trigger():
@T.prim_func
def program():
with T.Kernel(threads=128):
......
......@@ -6,7 +6,6 @@ import tilelang.language as T
def debug_print_buffer(M=16, N=16, dtype="float16"):
@T.prim_func
def program(Q: T.Tensor((M, N), dtype)):
with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
......@@ -19,24 +18,24 @@ def debug_print_buffer(M=16, N=16, dtype="float16"):
def test_debug_print_buffer():
debug_print_buffer(dtype='bool')
debug_print_buffer(dtype='int8')
debug_print_buffer(dtype='int16')
debug_print_buffer(dtype='int32')
debug_print_buffer(dtype='int64')
debug_print_buffer(dtype='uint8')
debug_print_buffer(dtype='uint16')
debug_print_buffer(dtype='uint32')
debug_print_buffer(dtype='uint64')
debug_print_buffer(dtype='float16')
debug_print_buffer(dtype='float32')
debug_print_buffer(dtype='float64')
debug_print_buffer(dtype='bfloat16')
debug_print_buffer(dtype='float8_e4m3')
debug_print_buffer(dtype='float8_e4m3fn')
debug_print_buffer(dtype='float8_e4m3fnuz')
debug_print_buffer(dtype='float8_e5m2')
debug_print_buffer(dtype='float8_e5m2fnuz')
debug_print_buffer(dtype="bool")
debug_print_buffer(dtype="int8")
debug_print_buffer(dtype="int16")
debug_print_buffer(dtype="int32")
debug_print_buffer(dtype="int64")
debug_print_buffer(dtype="uint8")
debug_print_buffer(dtype="uint16")
debug_print_buffer(dtype="uint32")
debug_print_buffer(dtype="uint64")
debug_print_buffer(dtype="float16")
debug_print_buffer(dtype="float32")
debug_print_buffer(dtype="float64")
debug_print_buffer(dtype="bfloat16")
debug_print_buffer(dtype="float8_e4m3")
debug_print_buffer(dtype="float8_e4m3fn")
debug_print_buffer(dtype="float8_e4m3fnuz")
debug_print_buffer(dtype="float8_e5m2")
debug_print_buffer(dtype="float8_e5m2fnuz")
def debug_print_buffer_conditional(M=16, N=16):
......
......@@ -5,7 +5,7 @@ import tilelang.testing
from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics.utils import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (TensorCoreIntrinEmitter)
from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter
tilelang.testing.set_random_seed(0)
......@@ -96,12 +96,11 @@ def tl_matmul_macro(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -109,10 +108,12 @@ def tl_matmul_macro(
B_local = T.alloc_local((warp_cols * local_size), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -120,7 +121,6 @@ def tl_matmul_macro(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -130,7 +130,6 @@ def tl_matmul_macro(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
......@@ -207,8 +206,7 @@ def tl_matmul_block(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor(
(M, N), out_dtype)):
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
......@@ -306,8 +304,7 @@ def tl_matmul_block_all_dynamic(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor(
(M, N), out_dtype)):
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
......@@ -417,7 +414,7 @@ def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
)
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_DYNAMIC_TAIL_SPLIT: dynamic_alignment != 0,
tilelang.PassConfigKey.TL_DYNAMIC_ALIGNMENT: dynamic_alignment
tilelang.PassConfigKey.TL_DYNAMIC_ALIGNMENT: dynamic_alignment,
}
if M % 64 == 0 or N % 64 == 0 or K % 64 != 0:
# workaround for hopper tma lower pass
......@@ -462,55 +459,31 @@ def test_assert_tl_matmul_macro():
def test_assert_tl_matmul_block():
assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16", "float16",
64, 64, 32)
assert_tl_matmul_block_correctness(67, 128, 128, False, False, "float16", "float16", "float16",
64, 64, 32)
assert_tl_matmul_block_correctness(36, 128, 128, False, False, "float16", "float16", "float16",
64, 64, 32)
assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
assert_tl_matmul_block_correctness(67, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
assert_tl_matmul_block_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
def test_assert_tl_matmul_block_all_dynamic():
assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16",
"float16", "float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16",
"float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16",
"float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
def test_assert_tl_matmul_block_all_dynamic_with_pass_config():
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
128,
128,
128,
False,
False,
"float16",
"float16",
"float16",
64,
64,
32,
dynamic_alignment=8)
128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=8
)
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
64,
128,
128,
False,
False,
"float16",
"float16",
"float16",
64,
64,
32,
dynamic_alignment=8)
64, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=8
)
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
64, 128, 60, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=4)
64, 128, 60, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=4
)
# Tail split is enabled with dynamic alignment 0
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
64, 128, 64, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=0)
64, 128, 64, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=0
)
if __name__ == "__main__":
......
......@@ -25,10 +25,8 @@ def tl_matmul_block_static(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor(
(M, N), out_dtype)):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -137,10 +135,8 @@ def tl_matmul_block_dynamic_m(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor(
(M, N), out_dtype)):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -247,10 +243,8 @@ def tl_matmul_block_dynamic_mn(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor(
(M, N), out_dtype)):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -357,10 +351,8 @@ def tl_matmul_block_dynamic_mnk(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor(
(M, N), out_dtype)):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -445,8 +437,7 @@ def assert_tl_matmul_block_dynamic_mnk(
def run_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16",
"float16", "float32")
assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16", "float16", "float32")
def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
......@@ -462,10 +453,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
"float16",
"float16",
"float32",
pass_configs={
"tl.disable_dynamic_tail_split": True,
"tl.dynamic_alignment": 8
})
pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8},
)
assert_tl_matmul_block_dynamic_m(
M,
N,
......@@ -478,7 +467,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
"float16",
"float16",
"float32",
pass_configs={"tl.disable_dynamic_tail_split": False})
pass_configs={"tl.disable_dynamic_tail_split": False},
)
def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
......@@ -494,10 +484,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
"float16",
"float16",
"float32",
pass_configs={
"tl.disable_dynamic_tail_split": True,
"tl.dynamic_alignment": 8
})
pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8},
)
assert_tl_matmul_block_dynamic_mn(
M,
N,
......@@ -510,7 +498,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
"float16",
"float16",
"float32",
pass_configs={"tl.disable_dynamic_tail_split": False})
pass_configs={"tl.disable_dynamic_tail_split": False},
)
def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
......@@ -526,10 +515,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
"float16",
"float16",
"float32",
pass_configs={
"tl.disable_dynamic_tail_split": True,
"tl.dynamic_alignment": 4
})
pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 4},
)
assert_tl_matmul_block_dynamic_mnk(
M,
N,
......@@ -542,7 +529,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
"float16",
"float16",
"float32",
pass_configs={"tl.disable_dynamic_tail_split": False})
pass_configs={"tl.disable_dynamic_tail_split": False},
)
def test_all():
......
......@@ -7,16 +7,16 @@ import re
def get_mathop_lines(source, mathop_name):
"""Extract lines containing the mathop from CUDA source for debugging"""
lines = source.split('\n')
lines = source.split("\n")
relevant_lines = []
for i, line in enumerate(lines):
if mathop_name in line and ('(' in line):
if mathop_name in line and ("(" in line):
# Include some context
start = max(0, i - 1)
end = min(len(lines), i + 2)
relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)])
relevant_lines.append("---")
return '\n'.join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output
return "\n".join(relevant_lines[-10:]) # Show last 10 lines to avoid too much output
def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
......@@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
fastmath_matches = re.findall(fastmath_pattern, source)
non_fastmath_matches = re.findall(non_fastmath_pattern, source)
print(
f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls"
)
print(f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls")
if len(fastmath_matches) > 0:
print(f"Fastmath calls found: {fastmath_matches}")
if len(non_fastmath_matches) > 0:
......@@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name):
check_fastmath_usage(source, mathop_name, expect_fastmath=False)
def run_single_arg_mathop_test(mathop_name,
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
"""
Test single-argument mathops.
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
......@@ -65,13 +57,12 @@ def run_single_arg_mathop_test(mathop_name,
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
bx * block_N + j])
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j])
# Test with FAST_MATH disabled
kernel_no_fastmath = tilelang.compile(
......@@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
},
)
source_no_fastmath = kernel_no_fastmath.get_kernel_source()
......@@ -93,28 +85,22 @@ def run_single_arg_mathop_test(mathop_name,
print(f"✓ {mathop_name} compilation and execution test passed")
def run_two_arg_mathop_test(mathop_name,
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
"""
Test two-argument mathops to ensure they generate non-fastmath CUDA code.
"""
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i,
bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
B[by * block_M + i, bx * block_N + j])
C[by * block_M + i, bx * block_N + j] = mathop_func(
A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j]
)
# Test with FAST_MATH disabled
kernel_no_fastmath = tilelang.compile(
......@@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
},
)
# Test with FAST_MATH enabled
kernel_fastmath = tilelang.compile(
......@@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
source_no_fastmath = kernel_no_fastmath.get_kernel_source()
source_fastmath = kernel_fastmath.get_kernel_source()
......@@ -171,8 +159,8 @@ def run_abs_test():
@T.prim_func
def main(
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
......@@ -184,7 +172,8 @@ def run_abs_test():
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
})
},
)
source = kernel.get_kernel_source()
print("\n=== Testing abs (maps to fabs) ===")
......@@ -199,26 +188,19 @@ def run_abs_test():
print("✓ abs numerical test passed")
def run_fastmath_mathop_test(mathop_name,
mathop_func,
M=128,
N=128,
block_M=32,
block_N=32,
dtype="float32"):
def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
"""
Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
"""
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
bx * block_N + j])
B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j])
# Test with FAST_MATH enabled
kernel_fastmath = tilelang.compile(
......@@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name,
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
source_fastmath = kernel_fastmath.get_kernel_source()
print(f"\n=== Testing {mathop_name} (fastmath version) ===")
print("FAST_MATH=True:")
# Strip the __ prefix for checking in the CUDA source
cuda_mathop_name = mathop_name.lstrip('_')
cuda_mathop_name = mathop_name.lstrip("_")
check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True)
# Test numerical correctness
......
......@@ -8,14 +8,15 @@ from tilelang import language as T
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},)
},
)
def _cumsum_view_infer_layout(hidden):
num_tokens = T.dynamic('num_tokens')
num_tokens = T.dynamic("num_tokens")
@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']):
def buggy_kernel(x: T.Tensor[(num_tokens, hidden), "float"]):
with T.Kernel(num_tokens, threads=128) as pid:
smem = T.alloc_shared((hidden,), dtype='float')
smem = T.alloc_shared((hidden,), dtype="float")
T.copy(x[pid, :], smem)
T.cumsum(T.view(smem, (1, hidden)), dim=1)
......@@ -24,10 +25,10 @@ def _cumsum_view_infer_layout(hidden):
def test_cumsum_view_infer_layout():
hidden = 128
x = torch.randn(1, hidden, device='cuda', dtype=torch.float)
x = torch.randn(1, hidden, device="cuda", dtype=torch.float)
kernel = _cumsum_view_infer_layout(hidden)
kernel(x)
if __name__ == '__main__':
if __name__ == "__main__":
tilelang.testing.main()
......@@ -8,12 +8,13 @@ from tilelang import language as T
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},)
},
)
def _fill_with_static_region_kernel():
num_tokens = T.symbolic('num_tokens')
num_tokens = T.symbolic("num_tokens")
@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']): # noqa: F821
def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821
with T.Kernel(num_tokens, threads=128) as _:
T.fill(x[0:128], 0)
......@@ -24,14 +25,15 @@ def _fill_with_static_region_kernel():
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},)
},
)
def _fill_with_dynamic_region_kernel():
num_tokens = T.symbolic('num_tokens')
num_tokens = T.symbolic("num_tokens")
@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']): # noqa: F821
def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821
with T.Kernel(num_tokens, threads=128) as _:
a, b = T.alloc_var('int'), T.alloc_var('int')
a, b = T.alloc_var("int"), T.alloc_var("int")
T.fill(x[a:b], 0)
return buggy_kernel
......@@ -39,15 +41,15 @@ def _fill_with_dynamic_region_kernel():
def test_fill_with_static_region_kernel():
kernel = _fill_with_static_region_kernel()
x = torch.zeros((256,), dtype=torch.int64, device='cuda')
x = torch.zeros((256,), dtype=torch.int64, device="cuda")
kernel(x)
def test_fill_with_dynamic_region_kernel():
kernel = _fill_with_dynamic_region_kernel()
x = torch.zeros((256,), dtype=torch.int64, device='cuda')
x = torch.zeros((256,), dtype=torch.int64, device="cuda")
kernel(x)
if __name__ == '__main__':
if __name__ == "__main__":
tilelang.testing.main()
......@@ -4,25 +4,23 @@ import tilelang.language as T
def test_int64_address():
@tilelang.jit
def set_cache_kernel(
S,
D,
pos_ty='int64',
pos_ty="int64",
dtype="float32",
):
@T.prim_func
def main(
pos: T
.Tensor(
pos: T.Tensor(
[
S,
], pos_ty
],
pos_ty,
), # type: ignore `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32`
value: T.Tensor([S, D], dtype), # type: ignore
cache: T.Tensor([S, D], dtype), # type: ignore
value: T.Tensor([S, D], dtype), # type: ignore
cache: T.Tensor([S, D], dtype), # type: ignore
):
with T.Kernel(S, threads=128) as bx:
slot = pos[bx]
......@@ -34,11 +32,11 @@ def test_int64_address():
D = 2
S = 10
cache = torch.rand((S, D), device="cuda", dtype=torch.float32)
value = torch.rand((S, D), device='cuda', dtype=torch.float32)
pos_int64 = torch.arange(S, device='cuda', dtype=torch.int64)
pos_int32 = torch.arange(S, device='cuda', dtype=torch.int32)
kernel_int64 = set_cache_kernel(S, D, 'int64')
kernel_int32 = set_cache_kernel(S, D, 'int32')
value = torch.rand((S, D), device="cuda", dtype=torch.float32)
pos_int64 = torch.arange(S, device="cuda", dtype=torch.int64)
pos_int32 = torch.arange(S, device="cuda", dtype=torch.int32)
kernel_int64 = set_cache_kernel(S, D, "int64")
kernel_int32 = set_cache_kernel(S, D, "int32")
kernel_int64(pos_int64, value, cache)
torch.testing.assert_close(cache, value)
kernel_int32(pos_int32, value, cache)
......
......@@ -3,13 +3,17 @@ import tilelang.language as T
def test_issue_1198():
@T.prim_func
def foo(x: T.Buffer([
32,
], "int32")):
def foo(
x: T.Buffer(
[
32,
],
"int32",
),
):
pass
if __name__ == '__main__':
if __name__ == "__main__":
tilelang.testing.main()
......@@ -6,11 +6,10 @@ import torch
@tilelang.jit
def _tmp_var_kernel(N, block_N, dtype="float"):
@T.prim_func
def kernel(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:
for i in T.Parallel(block_N):
......
......@@ -8,7 +8,6 @@ import tilelang.language as T
@tilelang.jit
def _empty_kernel():
@T.prim_func
def empty_kernel():
with T.Kernel(1, threads=32) as thread_idx:
......@@ -51,7 +50,6 @@ def test_empty_with_dead_code_kernel():
@tilelang.jit
def _empty_kernel_with_binding_variants(use_tuple_binding: bool = False):
@T.prim_func
def kernel_with_tuple_kernel_binding():
with T.Kernel(1, threads=32) as (pid,):
......
......@@ -5,18 +5,16 @@ import torch
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
bx,
by,
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
bx,
by,
):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......
......@@ -6,7 +6,6 @@ import tilelang.language as T
def merge_if_test():
@T.prim_func
def main():
A = T.alloc_fragment((1,), "float16")
......
......@@ -29,9 +29,9 @@ def matmul(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -141,9 +141,9 @@ def matmu_jit_kernel(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -208,6 +208,7 @@ def run_gemm_jit_kernel(
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
......
......@@ -31,9 +31,9 @@ def matmul_kernel_jit(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -96,6 +96,7 @@ def run_gemm_kernel_jit(
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
......
......@@ -28,9 +28,9 @@ def matmul(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -138,9 +138,9 @@ def matmu_jit_kernel(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -208,6 +208,7 @@ def run_gemm_jit_kernel(
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(out_dtype)
return C
......@@ -235,19 +236,9 @@ def test_gemm_jit_kernel():
)
def run_cython_kernel_do_bench(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
def run_cython_kernel_do_bench(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
......@@ -287,23 +278,12 @@ def run_cython_kernel_do_bench(M,
def test_cython_kernel_do_bench():
run_cython_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
def run_cython_kernel_multi_stream(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
run_cython_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def run_cython_kernel_multi_stream(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
......@@ -342,23 +322,12 @@ def run_cython_kernel_multi_stream(M,
def test_cython_kernel_multi_stream():
run_cython_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16",
128, 256, 32, 2)
def run_cython_dynamic_shape(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
run_cython_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def run_cython_dynamic_shape(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
......@@ -398,36 +367,20 @@ def run_cython_dynamic_shape(M,
matmul_kernel(tensor_a, tensor_b, tensor_c)
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
tilelang.testing.torch_assert_close(
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_cython_dynamic_shape():
run_cython_dynamic_shape(
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cython_dynamic_shape(
T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
run_cython_dynamic_shape(
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16",
"float16", 128, 256, 32, 2)
def run_cython_dynamic_shape_with_out_idx(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
run_cython_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def run_cython_dynamic_shape_with_out_idx(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
......@@ -467,13 +420,11 @@ def run_cython_dynamic_shape_with_out_idx(M,
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
tilelang.testing.torch_assert_close(
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_cython_dynamic_shape_with_out_idx():
run_cython_dynamic_shape_with_out_idx(
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cython_dynamic_shape_with_out_idx(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def matmul_int_variable(
......@@ -498,10 +449,10 @@ def matmul_int_variable(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
offset: T.int32,
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
offset: T.int32,
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -525,10 +476,10 @@ def matmul_int_variable(
return main
def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
out_dtype, dtypeAccum, num_stages, threads):
program = matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
out_dtype, dtypeAccum, num_stages, threads)
def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads):
program = matmul_int_variable(
M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads
)
matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2)
in_dtype = map_torch_type(in_dtype)
......@@ -544,8 +495,7 @@ def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B
def test_matmul_int_variable():
run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16",
"float32", 0, 128)
run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", "float32", 0, 128)
def matmul_float_variable(
......@@ -570,10 +520,10 @@ def matmul_float_variable(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
offset: T.float32,
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
offset: T.float32,
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -597,10 +547,10 @@ def matmul_float_variable(
return main
def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
out_dtype, dtypeAccum, num_stages, threads):
program = matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
out_dtype, dtypeAccum, num_stages, threads)
def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads):
program = matmul_float_variable(
M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, num_stages, threads
)
matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=2)
in_dtype = map_torch_type(in_dtype)
......@@ -616,8 +566,7 @@ def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans
def test_matmul_float_variable():
run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16",
"float32", 0, 128)
run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", "float32", 0, 128)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment