Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
208 additions
and
263 deletions
+208
-263
testing/python/language/test_tilelang_language_negative_index.py
.../python/language/test_tilelang_language_negative_index.py
+1
-2
testing/python/language/test_tilelang_language_parallel.py
testing/python/language/test_tilelang_language_parallel.py
+5
-7
testing/python/language/test_tilelang_language_pipeline.py
testing/python/language/test_tilelang_language_pipeline.py
+19
-29
testing/python/language/test_tilelang_language_ptr.py
testing/python/language/test_tilelang_language_ptr.py
+0
-1
testing/python/language/test_tilelang_language_reduce.py
testing/python/language/test_tilelang_language_reduce.py
+10
-10
testing/python/language/test_tilelang_language_reshape.py
testing/python/language/test_tilelang_language_reshape.py
+31
-23
testing/python/language/test_tilelang_language_ternary.py
testing/python/language/test_tilelang_language_ternary.py
+6
-6
testing/python/language/test_tilelang_language_tma_1d.py
testing/python/language/test_tilelang_language_tma_1d.py
+2
-4
testing/python/language/test_tilelang_language_unroll.py
testing/python/language/test_tilelang_language_unroll.py
+0
-2
testing/python/language/test_tilelang_language_var_init.py
testing/python/language/test_tilelang_language_var_init.py
+5
-7
testing/python/language/test_tilelang_language_vectorize.py
testing/python/language/test_tilelang_language_vectorize.py
+6
-10
testing/python/language/test_tilelang_language_vectorized_cast.py
...python/language/test_tilelang_language_vectorized_cast.py
+5
-6
testing/python/language/test_tilelang_language_view.py
testing/python/language/test_tilelang_language_view.py
+7
-5
testing/python/language/test_tilelang_language_warp_reduce.py
...ing/python/language/test_tilelang_language_warp_reduce.py
+10
-11
testing/python/layout/test_tilelang_layout_fused_replicate.py
...ing/python/layout/test_tilelang_layout_fused_replicate.py
+6
-7
testing/python/math/test_math_bitwise_reduce.py
testing/python/math/test_math_bitwise_reduce.py
+5
-6
testing/python/math/test_math_fast_math.py
testing/python/math/test_math_fast_math.py
+32
-49
testing/python/math/test_math_ieee_math.py
testing/python/math/test_math_ieee_math.py
+26
-33
testing/python/metal/test_metal_codegen.py
testing/python/metal/test_metal_codegen.py
+14
-15
testing/python/primitives/test_tilelang_primitives_mma.py
testing/python/primitives/test_tilelang_primitives_mma.py
+18
-30
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
testing/python/language/test_tilelang_language_negative_index.py
View file @
29051439
...
...
@@ -31,8 +31,7 @@ def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,)
@
T
.
prim_func
def
negative_index_symbolic_before
(
shift
:
T
.
int32
,
A
:
T
.
Buffer
((
16
,),
"float32"
),
B
:
T
.
Buffer
((
16
,),
"float32"
)):
def
negative_index_symbolic_before
(
shift
:
T
.
int32
,
A
:
T
.
Buffer
((
16
,),
"float32"
),
B
:
T
.
Buffer
((
16
,),
"float32"
)):
T
.
func_attr
({
"tir.noalias"
:
True
})
for
i
in
T
.
serial
(
16
):
B
[
i
]
=
A
[
shift
+
i
]
...
...
testing/python/language/test_tilelang_language_parallel.py
View file @
29051439
...
...
@@ -9,11 +9,10 @@ tilelang.testing.set_random_seed()
@
tilelang
.
jit
(
out_idx
=
[
1
])
def
parallel_elementwise_static
(
length
=
256
,
dtype
=
"float32"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
length
,),
dtype
),
B
:
T
.
Tensor
((
length
,),
dtype
),
A
:
T
.
Tensor
((
length
,),
dtype
),
B
:
T
.
Tensor
((
length
,),
dtype
),
):
with
T
.
Kernel
(
1
,
threads
=
length
)
as
_
:
for
i
in
T
.
Parallel
(
length
):
...
...
@@ -24,12 +23,11 @@ def parallel_elementwise_static(length=256, dtype="float32"):
@
tilelang
.
jit
(
out_idx
=
[
1
])
def
parallel_elementwise_dynamic
(
max_len
=
512
,
threads
=
256
,
dtype
=
"float32"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
max_len
,),
dtype
),
B
:
T
.
Tensor
((
max_len
,),
dtype
),
valid_len
:
T
.
int32
,
A
:
T
.
Tensor
((
max_len
,),
dtype
),
B
:
T
.
Tensor
((
max_len
,),
dtype
),
valid_len
:
T
.
int32
,
):
with
T
.
Kernel
(
1
,
threads
=
threads
)
as
_
:
for
i
in
T
.
Parallel
(
max_len
):
...
...
testing/python/language/test_tilelang_language_pipeline.py
View file @
29051439
...
...
@@ -27,9 +27,9 @@ def matmul(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
...
@@ -90,7 +90,8 @@ def run_gemm(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
kernel
.
get_profiler
()
def
ref_program
(
A
,
B
):
...
...
@@ -103,8 +104,8 @@ def run_gemm(
if
in_dtype
==
"float32"
:
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically, -0x1000 meas
A
=
(
(
A
.
view
(
torch
.
int32
)
-
0x1000
)
)
.
view
(
torch
.
float32
)
B
=
(
(
B
.
view
(
torch
.
int32
)
-
0x1000
)
)
.
view
(
torch
.
float32
)
A
=
(
A
.
view
(
torch
.
int32
)
-
0x1000
).
view
(
torch
.
float32
)
B
=
(
B
.
view
(
torch
.
int32
)
-
0x1000
).
view
(
torch
.
float32
)
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
C
.
to
(
torch
.
__getattribute__
(
out_dtype
))
return
C
...
...
@@ -124,27 +125,19 @@ def test_pipeline_order_stage():
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
def
blocksparse_matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
},
)
def
blocksparse_matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
import
tilelang.language
as
T
@
T
.
prim_func
def
block_sparse_matmul
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
...
@@ -183,8 +176,7 @@ def run_blocksparse_matmul(num_stages):
a
=
torch
.
randn
(
M
,
K
).
cuda
().
half
()
b
=
torch
.
randn
(
K
,
N
).
cuda
().
half
()
kernel
=
blocksparse_matmul
(
M
,
N
,
K
,
block_M
=
block_M
,
block_N
=
block_N
,
block_K
=
block_K
,
num_stages
=
num_stages
)
kernel
=
blocksparse_matmul
(
M
,
N
,
K
,
block_M
=
block_M
,
block_N
=
block_N
,
block_K
=
block_K
,
num_stages
=
num_stages
)
print
(
kernel
.
get_kernel_source
())
# Create block mask with desired sparsity
mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
...
...
@@ -200,12 +192,10 @@ def run_blocksparse_matmul(num_stages):
accu
=
torch
.
zeros
((
block_M
,
block_N
),
dtype
=
torch
.
float32
,
device
=
A
.
device
)
for
k
in
range
(
K
//
block_K
):
if
BlockMask
[
i
,
j
,
k
]:
accu
+=
(
A
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
].
to
(
torch
.
float32
)
@
B
[
k
*
block_K
:(
k
+
1
)
*
block_K
,
j
*
block_N
:(
j
+
1
)
*
block_N
].
to
(
torch
.
float32
))
ref_c
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
j
*
block_N
:(
j
+
1
)
*
block_N
]
=
accu
.
to
(
torch
.
float16
)
accu
+=
A
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
].
to
(
torch
.
float32
)
@
B
[
k
*
block_K
:
(
k
+
1
)
*
block_K
,
j
*
block_N
:
(
j
+
1
)
*
block_N
].
to
(
torch
.
float32
)
ref_c
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_N
:
(
j
+
1
)
*
block_N
]
=
accu
.
to
(
torch
.
float16
)
return
ref_c
# Compute the reference result using the naive PyTorch implementation
...
...
testing/python/language/test_tilelang_language_ptr.py
View file @
29051439
...
...
@@ -7,7 +7,6 @@ from tilelang.utils import map_torch_type
def
matmul_test
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
main
(
a_ptr
:
T
.
ptr
,
...
...
testing/python/language/test_tilelang_language_reduce.py
View file @
29051439
...
...
@@ -10,8 +10,8 @@ def _make_shared_reduce(M, N, dtype, reduce_cb):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,),
dtype
),
):
with
T
.
Kernel
(
1
)
as
_
:
A_shared
=
T
.
alloc_shared
((
M
,
N
),
dtype
)
...
...
@@ -35,8 +35,8 @@ def reduce_max_test(M, N, dtype="float16"):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,),
dtype
),
):
with
T
.
Kernel
(
1
)
as
_
:
A_local
=
T
.
alloc_fragment
((
M
,
N
),
dtype
)
...
...
@@ -54,8 +54,8 @@ def reduce_sum_test(M, N, dtype="float32"):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,),
dtype
),
):
with
T
.
Kernel
(
1
)
as
_
:
A_local
=
T
.
alloc_fragment
((
M
,
N
),
dtype
)
...
...
@@ -145,8 +145,8 @@ def reduce_sum_test_clear(M, N, dtype="float32"):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,),
dtype
),
):
with
T
.
Kernel
(
1
,
threads
=
32
)
as
_
:
A_local
=
T
.
alloc_fragment
((
M
,
N
),
dtype
)
...
...
@@ -186,8 +186,8 @@ def reduce_max_test_clear(M, N, dtype="float16"):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,),
dtype
),
):
with
T
.
Kernel
(
1
,
threads
=
32
)
as
_
:
A_local
=
T
.
alloc_fragment
((
M
,
N
),
dtype
)
...
...
testing/python/language/test_tilelang_language_reshape.py
View file @
29051439
...
...
@@ -10,8 +10,8 @@ def reshape_test(N, M, dtype):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
//
M
,
M
),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
//
M
,
M
),
dtype
),
):
with
T
.
Kernel
(
1
)
as
_
:
A_reshaped
=
T
.
reshape
(
A
,
[
N
//
M
,
M
])
...
...
@@ -30,7 +30,8 @@ def run_reshape(N, M, dtype):
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
jit_kernel
.
get_profiler
()
def
ref_program
(
A
):
...
...
@@ -50,8 +51,8 @@ def reshape_test_smem_1d_2_2d(N, M, dtype):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
//
M
,
M
),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
//
M
,
M
),
dtype
),
):
with
T
.
Kernel
(
1
)
as
_
:
A_shared
=
T
.
alloc_shared
((
N
,),
dtype
)
...
...
@@ -74,7 +75,8 @@ def run_reshape_smem_1d_2_2d(N, M, dtype):
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
jit_kernel
.
get_profiler
()
def
ref_program
(
A
):
...
...
@@ -93,8 +95,8 @@ def reshape_test_smem_2d_2_1d(N, M, dtype):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
N
//
M
,
M
),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
//
M
,
M
),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
with
T
.
Kernel
(
1
)
as
_
:
A_shared
=
T
.
alloc_shared
((
N
//
M
,
M
),
dtype
)
...
...
@@ -117,7 +119,8 @@ def run_reshape_smem_2d_2_1d(N, M, dtype):
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
jit_kernel
.
get_profiler
()
def
ref_program
(
A
):
...
...
@@ -136,8 +139,8 @@ def reshape_fragment_test(N, M, dtype):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
N
//
M
,
M
),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
//
M
,
M
),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
with
T
.
Kernel
(
1
,
threads
=
32
)
as
_
:
A_shared
=
T
.
alloc_shared
((
N
//
M
,
M
),
dtype
,
scope
=
"shared"
)
...
...
@@ -161,7 +164,8 @@ def run_reshape_fragment(N, M, dtype):
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
jit_kernel
.
get_profiler
()
def
ref_program
(
A
):
...
...
@@ -181,15 +185,17 @@ def reshape_layout_transform_shared(N, M, dtype):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
N
//
M
,
M
),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
//
M
,
M
),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
with
T
.
Kernel
(
1
,
threads
=
32
)
as
_
:
A_shared
=
T
.
alloc_shared
((
N
//
M
,
M
),
dtype
,
scope
=
"shared"
)
T
.
annotate_layout
({
A_shared
:
make_mma_swizzle_layout
(
A_shared
),
})
T
.
annotate_layout
(
{
A_shared
:
make_mma_swizzle_layout
(
A_shared
),
}
)
T
.
copy
(
A
,
A_shared
)
A_shared_reshape
=
T
.
reshape
(
A_shared
,
[
N
])
T
.
copy
(
A_shared_reshape
,
B
)
...
...
@@ -205,7 +211,8 @@ def run_reshape_layout_transform_shared(N, M, dtype):
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
jit_kernel
.
get_profiler
()
def
ref_program
(
A
):
...
...
@@ -224,8 +231,8 @@ def reduce_after_reshape_test(N, M, dtype):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
//
M
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
//
M
,),
dtype
),
):
with
T
.
Kernel
(
1
,
threads
=
32
)
as
_
:
A_shared
=
T
.
alloc_shared
((
N
,),
dtype
,
scope
=
"shared"
)
...
...
@@ -249,7 +256,8 @@ def run_reduce_after_reshape(N, M, dtype):
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
jit_kernel
.
get_profiler
()
def
ref_program
(
A
):
...
...
@@ -268,8 +276,8 @@ def reshape_shape_mismatch_test(N, M, dtype):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
//
M
,
M
),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
//
M
,
M
),
dtype
),
):
with
T
.
Kernel
(
1
)
as
_
:
A_reshaped
=
T
.
reshape
(
A
,
[
N
//
M
,
M
+
1
])
...
...
testing/python/language/test_tilelang_language_ternary.py
View file @
29051439
...
...
@@ -4,19 +4,19 @@ import torch
import
tilelang.testing
@
tilelang
.
jit
(
out_idx
=
[
1
],)
@
tilelang
.
jit
(
out_idx
=
[
1
],
)
def
tilelang_ternary
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
if
(
by
*
block_M
+
i
)
<
(
M
//
2
)
else
0
)
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
if
(
by
*
block_M
+
i
)
<
(
M
//
2
)
else
0
return
main
...
...
testing/python/language/test_tilelang_language_tma_1d.py
View file @
29051439
...
...
@@ -9,10 +9,8 @@ def ref_program(x, y):
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
elementwise_add
(
M
,
N
,
block_M
,
block_N
,
in_dtype
,
out_dtype
,
threads
):
@
T
.
prim_func
def
elem_add
(
A
:
T
.
Tensor
((
M
,
N
),
in_dtype
),
B
:
T
.
Tensor
((
M
,
N
),
in_dtype
),
C
:
T
.
Tensor
(
(
M
,
N
),
out_dtype
)):
def
elem_add
(
A
:
T
.
Tensor
((
M
,
N
),
in_dtype
),
B
:
T
.
Tensor
((
M
,
N
),
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
in_dtype
)
B_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
in_dtype
)
...
...
@@ -21,7 +19,7 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
T
.
copy
(
A
[
by
*
block_M
,
bx
*
block_N
],
A_shared
)
T
.
copy
(
B
[
by
*
block_M
,
bx
*
block_N
],
B_shared
)
for
(
local_y
,
local_x
)
in
T
.
Parallel
(
block_M
,
block_N
):
for
local_y
,
local_x
in
T
.
Parallel
(
block_M
,
block_N
):
C_local
[
local_y
,
local_x
]
=
A_shared
[
local_y
,
local_x
]
+
B_shared
[
local_y
,
local_x
]
T
.
copy
(
C_local
,
C_shared
)
T
.
copy
(
C_shared
,
C
[
by
*
block_M
,
bx
*
block_N
])
...
...
testing/python/language/test_tilelang_language_unroll.py
View file @
29051439
...
...
@@ -4,7 +4,6 @@ from tilelang import language as T
def
test_unroll_with_step
():
@
T
.
prim_func
def
main
(
A_ptr
:
T
.
handle
):
A
=
T
.
match_buffer
(
A_ptr
,
(
16
,
16
),
dtype
=
"float32"
,
align
=
16
)
...
...
@@ -19,7 +18,6 @@ def test_unroll_with_step():
def
test_unroll_with_unroll_factor
():
@
T
.
prim_func
def
main
(
A_ptr
:
T
.
handle
):
A
=
T
.
match_buffer
(
A_ptr
,
(
16
,
16
),
dtype
=
"float32"
,
align
=
16
)
...
...
testing/python/language/test_tilelang_language_var_init.py
View file @
29051439
...
...
@@ -4,17 +4,15 @@ import tilelang.testing
def
test_var_assign
()
->
None
:
@
tilelang
.
jit
(
out_idx
=-
1
)
def
jit_kernel
():
@
T
.
prim_func
def
test_var_assign
(
A
:
T
.
Tensor
((
2
,),
'
int32
'
)):
def
test_var_assign
(
A
:
T
.
Tensor
((
2
,),
"
int32
"
)):
with
T
.
Kernel
(
1
)
as
_
:
a
=
T
.
alloc_var
(
'
int32
'
,
init
=
1
)
b
=
T
.
alloc_var
(
'
int32
'
,
init
=
a
)
# b gets value of a
a
=
T
.
alloc_var
(
"
int32
"
,
init
=
1
)
b
=
T
.
alloc_var
(
"
int32
"
,
init
=
a
)
# b gets value of a
a
=
2
d
=
T
.
alloc_var
(
'
int32
'
,
init
=
a
)
# c gets new value of a
d
=
T
.
alloc_var
(
"
int32
"
,
init
=
a
)
# c gets new value of a
A
[
0
]
=
b
A
[
1
]
=
d
...
...
@@ -28,5 +26,5 @@ def test_var_assign() -> None:
assert
res
[
1
]
==
2
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
testing/python/language/test_tilelang_language_vectorize.py
View file @
29051439
...
...
@@ -5,11 +5,10 @@ import tilelang.language as T
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_VECTORIZE_256
:
True
})
def
vectorize_test
(
N
,
M
,
stride_A
,
stride_B
):
@
T
.
prim_func
def
main
(
A
:
T
.
StridedTensor
[(
N
,
M
),
(
1
,
stride_A
),
"float32"
],
# noqa: F821
B
:
T
.
StridedTensor
[(
N
,
M
),
(
1
,
stride_B
),
"float32"
],
# noqa: F821
A
:
T
.
StridedTensor
[(
N
,
M
),
(
1
,
stride_A
),
"float32"
],
# noqa: F821
B
:
T
.
StridedTensor
[(
N
,
M
),
(
1
,
stride_B
),
"float32"
],
# noqa: F821
):
with
T
.
Kernel
(
M
//
128
,
threads
=
128
)
as
(
bx
):
tx
=
T
.
get_thread_binding
(
0
)
...
...
@@ -39,9 +38,7 @@ def run_vectorize(N, M, stride_A, stride_B):
code
=
jit_kernel
.
get_kernel_source
()
vectorize_size
=
1
while
vectorize_size
<=
2
and
\
stride_A
%
(
vectorize_size
*
2
)
==
0
and
\
stride_B
%
(
vectorize_size
*
2
)
==
0
:
while
vectorize_size
<=
2
and
stride_A
%
(
vectorize_size
*
2
)
==
0
and
stride_B
%
(
vectorize_size
*
2
)
==
0
:
vectorize_size
*=
2
if
vectorize_size
==
4
:
...
...
@@ -61,12 +58,11 @@ def test_vectorize():
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_VECTORIZE_256
:
True
})
def
vectorize_test_invariant_index
(
N
,
M
,
K
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
[(
N
,
M
),
"float32"
],
# noqa: F821
B
:
T
.
Tensor
[(
N
,
M
),
"float32"
],
# noqa: F821
C
:
T
.
Tensor
[(
N
,
M
//
K
),
"float32"
],
# noqa: F821
A
:
T
.
Tensor
[(
N
,
M
),
"float32"
],
# noqa: F821
B
:
T
.
Tensor
[(
N
,
M
),
"float32"
],
# noqa: F821
C
:
T
.
Tensor
[(
N
,
M
//
K
),
"float32"
],
# noqa: F821
):
with
T
.
Kernel
(
N
//
128
,
threads
=
128
)
as
(
bx
):
tx
=
T
.
get_thread_binding
(
0
)
...
...
testing/python/language/test_tilelang_language_vectorized_cast.py
View file @
29051439
...
...
@@ -17,8 +17,8 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
[(
M
,),
dtype_A
],
# noqa: F821
B
:
T
.
Tensor
[(
M
,),
dtype_B
],
# noqa: F821
A
:
T
.
Tensor
[(
M
,),
dtype_A
],
# noqa: F821
B
:
T
.
Tensor
[(
M
,),
dtype_B
],
# noqa: F821
):
with
T
.
Kernel
(
1
,
threads
=
128
):
T
.
copy
(
A
,
B
)
...
...
@@ -32,8 +32,8 @@ def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
[(
M
,),
dtype_A
],
# noqa: F821
B
:
T
.
Tensor
[(
M
,),
dtype_B
],
# noqa: F821
A
:
T
.
Tensor
[(
M
,),
dtype_A
],
# noqa: F821
B
:
T
.
Tensor
[(
M
,),
dtype_B
],
# noqa: F821
):
with
T
.
Kernel
(
1
,
threads
=
128
):
A_local
=
T
.
alloc_fragment
((
M
,),
dtype_A
)
...
...
@@ -73,8 +73,7 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
code
=
kernel
.
get_kernel_source
()
code_parallel
=
kernel_parallel
.
get_kernel_source
()
assert
check_str
in
code
and
check_str
in
code_parallel
,
\
f
"Cast
{
src_dtype_str
}
to
{
dst_dtype_str
}
with
{
lanes
=
}
is not vectorized!"
assert
check_str
in
code
and
check_str
in
code_parallel
,
f
"Cast
{
src_dtype_str
}
to
{
dst_dtype_str
}
with
{
lanes
=
}
is not vectorized!"
def
test_vectorized_cast
():
...
...
testing/python/language/test_tilelang_language_view.py
View file @
29051439
...
...
@@ -10,6 +10,7 @@ def view_test(N, M, dtype, new_dtype=None):
new_shape
=
[
N
//
M
,
M
]
if
new_dtype
:
from
tvm
import
DataType
dtype_src
=
DataType
(
dtype
)
dtype_dst
=
DataType
(
new_dtype
)
src_bits
=
dtype_src
.
bits
...
...
@@ -19,8 +20,8 @@ def view_test(N, M, dtype, new_dtype=None):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
(
new_shape
,
new_dtype
if
new_dtype
else
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
(
new_shape
,
new_dtype
if
new_dtype
else
dtype
),
):
with
T
.
Kernel
(
1
)
as
_
:
A_viewed
=
T
.
view
(
A
,
new_shape
,
dtype
=
new_dtype
)
...
...
@@ -37,6 +38,7 @@ def run_view(N, M, dtype, new_dtype=None):
def
ref_program
(
A
):
if
new_dtype
:
from
tilelang.utils.tensor
import
map_torch_type
torch_dtype
=
map_torch_type
(
new_dtype
)
return
A
.
view
(
N
//
M
,
M
).
view
(
dtype
=
torch_dtype
)
return
A
.
view
(
N
//
M
,
M
)
...
...
@@ -45,7 +47,6 @@ def run_view(N, M, dtype, new_dtype=None):
def
test_reshape_view
():
# Test view with same dtype
run_view
(
1024
,
32
,
"float32"
)
run_view
(
2048
,
64
,
"float16"
)
...
...
@@ -61,6 +62,7 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
new_shape
=
[
N
//
M
,
M
+
1
]
if
new_dtype
:
from
tvm
import
DataType
dtype_src
=
DataType
(
dtype
)
dtype_dst
=
DataType
(
new_dtype
)
src_bits
=
dtype_src
.
bits
...
...
@@ -70,8 +72,8 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
(
new_shape
,
new_dtype
if
new_dtype
else
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
(
new_shape
,
new_dtype
if
new_dtype
else
dtype
),
):
with
T
.
Kernel
(
1
)
as
_
:
A_viewed
=
T
.
view
(
A
,
new_shape
,
dtype
=
new_dtype
)
...
...
testing/python/language/test_tilelang_language_warp_reduce.py
View file @
29051439
...
...
@@ -7,7 +7,6 @@ import tilelang.language as T
@
tilelang
.
jit
def
get_kernel
(
reduce_op
:
str
,
dtype
:
str
):
assert
reduce_op
in
[
"sum"
,
"max"
,
"min"
,
"bitand"
,
"bitor"
]
@
T
.
prim_func
...
...
@@ -33,16 +32,16 @@ def get_kernel(reduce_op: str, dtype: str):
def
test_warp_reduce_sum
():
a
=
torch
.
randn
((
32
,),
dtype
=
torch
.
float32
,
device
=
'
cuda
'
)
kernel
=
get_kernel
(
'
sum
'
,
'
float32
'
)
a
=
torch
.
randn
((
32
,),
dtype
=
torch
.
float32
,
device
=
"
cuda
"
)
kernel
=
get_kernel
(
"
sum
"
,
"
float32
"
)
ref
=
torch
.
full_like
(
a
,
a
.
sum
())
kernel
(
a
)
torch
.
testing
.
assert_close
(
a
,
ref
)
def
test_warp_reduce_max
():
a
=
torch
.
randn
((
32
,),
dtype
=
torch
.
float32
,
device
=
'
cuda
'
)
kernel
=
get_kernel
(
"max"
,
'
float32
'
)
a
=
torch
.
randn
((
32
,),
dtype
=
torch
.
float32
,
device
=
"
cuda
"
)
kernel
=
get_kernel
(
"max"
,
"
float32
"
)
print
(
kernel
.
get_kernel_source
())
ref
=
torch
.
full_like
(
a
,
a
.
max
())
kernel
(
a
)
...
...
@@ -50,16 +49,16 @@ def test_warp_reduce_max():
def
test_warp_reduce_min
():
a
=
torch
.
randn
((
32
,),
dtype
=
torch
.
float32
,
device
=
'
cuda
'
)
kernel
=
get_kernel
(
"min"
,
'
float32
'
)
a
=
torch
.
randn
((
32
,),
dtype
=
torch
.
float32
,
device
=
"
cuda
"
)
kernel
=
get_kernel
(
"min"
,
"
float32
"
)
ref
=
torch
.
full_like
(
a
,
a
.
min
())
kernel
(
a
)
torch
.
testing
.
assert_close
(
a
,
ref
)
def
test_warp_reduce_bitand
():
a
=
torch
.
randint
(
0
,
100
,
size
=
(
32
,),
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
kernel
=
get_kernel
(
"bitand"
,
'
int32
'
)
a
=
torch
.
randint
(
0
,
100
,
size
=
(
32
,),
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
kernel
=
get_kernel
(
"bitand"
,
"
int32
"
)
ref_val
=
a
[
0
]
for
i
in
range
(
1
,
a
.
shape
[
0
]):
ref_val
=
ref_val
&
a
[
i
]
...
...
@@ -69,8 +68,8 @@ def test_warp_reduce_bitand():
def
test_warp_reduce_bitor
():
a
=
torch
.
randint
(
0
,
100
,
size
=
(
32
,),
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
kernel
=
get_kernel
(
"bitor"
,
'
int32
'
)
a
=
torch
.
randint
(
0
,
100
,
size
=
(
32
,),
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
kernel
=
get_kernel
(
"bitor"
,
"
int32
"
)
ref_val
=
a
[
0
]
for
i
in
range
(
1
,
a
.
shape
[
0
]):
ref_val
=
ref_val
|
a
[
i
]
...
...
testing/python/layout/test_tilelang_layout_fused_replicate.py
View file @
29051439
...
...
@@ -12,17 +12,16 @@ VEC_SIZE = 32
@
tilelang
.
jit
def
fused_index_kernel
(
B
:
int
,
M
:
int
,
N
:
int
,
BLOCK_MN
:
int
,
BLOCK_K
:
int
):
@
T
.
prim_func
def
main
(
a
:
T
.
Buffer
((
B
,
M
,
N
),
"bfloat16"
),
a_out
:
T
.
Buffer
((
B
,
M
,
N
),
"float32"
),
a
:
T
.
Buffer
((
B
,
M
,
N
),
"bfloat16"
),
a_out
:
T
.
Buffer
((
B
,
M
,
N
),
"float32"
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
BLOCK_MN
),
T
.
ceildiv
(
N
,
BLOCK_K
),
B
,
threads
=
128
,
T
.
ceildiv
(
M
,
BLOCK_MN
),
T
.
ceildiv
(
N
,
BLOCK_K
),
B
,
threads
=
128
,
)
as
(
pid_m
,
pid_n
,
pid_b
):
a_fp32_local
=
T
.
alloc_fragment
((
BLOCK_MN
*
BLOCK_K
//
VEC_SIZE
,
VEC_SIZE
),
"float32"
)
offs_m
=
pid_m
*
BLOCK_MN
...
...
testing/python/math/test_math_bitwise_reduce.py
View file @
29051439
...
...
@@ -19,12 +19,11 @@ def bitwise_reduce(
func
,
clear
=
True
,
):
@
T
.
prim_func
def
reduce_func
(
A
:
T
.
Tensor
((
M
,
N
),
"int32"
),
B
:
T
.
Tensor
((
M
),
"int32"
),
Output
:
T
.
Tensor
((
M
),
"int32"
),
A
:
T
.
Tensor
((
M
,
N
),
"int32"
),
B
:
T
.
Tensor
((
M
),
"int32"
),
Output
:
T
.
Tensor
((
M
),
"int32"
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_N
),
"int32"
)
...
...
@@ -64,7 +63,7 @@ def run_single_bitwise_reduce(
row_pattern
=
(
i
&
0xF
)
<<
(
i
%
4
)
# 4-bit patterns shifted by row
# Column-based pattern: different bit positions set based on column
col_pattern
=
(
1
<<
(
j
%
31
)
)
# Single bit set at different positions
col_pattern
=
1
<<
(
j
%
31
)
# Single bit set at different positions
# Combine patterns with XOR to create diverse bit distributions
# Add some deterministic "noise" based on position
...
...
@@ -76,7 +75,7 @@ def run_single_bitwise_reduce(
if
i
%
4
==
0
:
a
[
i
,
j
]
&=
~
(
0x1
<<
(
i
//
4
))
elif
i
%
2
==
0
:
a
[
i
,
j
]
|=
(
0x1
<<
(
i
//
2
)
)
a
[
i
,
j
]
|=
0x1
<<
(
i
//
2
)
if
name
==
"reduce_bitand"
:
expected
=
torch
.
full
((
M
,),
-
1
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
...
...
testing/python/math/test_math_fast_math.py
View file @
29051439
...
...
@@ -7,16 +7,16 @@ import re
def
get_mathop_lines
(
source
,
mathop_name
):
"""Extract lines containing the mathop from CUDA source for debugging"""
lines
=
source
.
split
(
'
\n
'
)
lines
=
source
.
split
(
"
\n
"
)
relevant_lines
=
[]
for
i
,
line
in
enumerate
(
lines
):
if
mathop_name
in
line
and
(
'('
in
line
):
if
mathop_name
in
line
and
(
"("
in
line
):
# Include some context
start
=
max
(
0
,
i
-
1
)
end
=
min
(
len
(
lines
),
i
+
2
)
relevant_lines
.
extend
([
f
"
{
j
}
:
{
lines
[
j
]
}
"
for
j
in
range
(
start
,
end
)])
relevant_lines
.
append
(
"---"
)
return
'
\n
'
.
join
(
relevant_lines
[
-
10
:])
# Show last 10 lines to avoid too much output
return
"
\n
"
.
join
(
relevant_lines
[
-
10
:])
# Show last 10 lines to avoid too much output
def
check_fastmath_usage
(
source
,
mathop_name
,
expect_fastmath
=
False
):
...
...
@@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
fastmath_matches
=
re
.
findall
(
fastmath_pattern
,
source
)
non_fastmath_matches
=
re
.
findall
(
non_fastmath_pattern
,
source
)
print
(
f
"Found
{
len
(
fastmath_matches
)
}
fastmath calls,
{
len
(
non_fastmath_matches
)
}
non-fastmath calls"
)
print
(
f
"Found
{
len
(
fastmath_matches
)
}
fastmath calls,
{
len
(
non_fastmath_matches
)
}
non-fastmath calls"
)
if
len
(
fastmath_matches
)
>
0
:
print
(
f
"Fastmath calls found:
{
fastmath_matches
}
"
)
if
len
(
non_fastmath_matches
)
>
0
:
...
...
@@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name):
check_fastmath_usage
(
source
,
mathop_name
,
expect_fastmath
=
False
)
def
run_single_arg_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
def
run_single_arg_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
"""
Test single-argument mathops.
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
...
...
@@ -65,13 +57,12 @@ def run_single_arg_mathop_test(mathop_name,
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
])
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
])
# Test with FAST_MATH disabled
kernel_no_fastmath
=
tilelang
.
compile
(
...
...
@@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name,
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
False
,
})
},
)
source_no_fastmath
=
kernel_no_fastmath
.
get_kernel_source
()
...
...
@@ -93,28 +85,22 @@ def run_single_arg_mathop_test(mathop_name,
print
(
f
"✓
{
mathop_name
}
compilation and execution test passed"
)
def
run_two_arg_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
def
run_two_arg_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
"""
Test two-argument mathops to ensure they generate non-fastmath CUDA code.
"""
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
C
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
,
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
)
C
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
,
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
)
# Test with FAST_MATH disabled
kernel_no_fastmath
=
tilelang
.
compile
(
...
...
@@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name,
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
False
,
})
},
)
# Test with FAST_MATH enabled
kernel_fastmath
=
tilelang
.
compile
(
...
...
@@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name,
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
source_no_fastmath
=
kernel_no_fastmath
.
get_kernel_source
()
source_fastmath
=
kernel_fastmath
.
get_kernel_source
()
...
...
@@ -171,8 +159,8 @@ def run_abs_test():
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
),
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
...
...
@@ -184,7 +172,8 @@ def run_abs_test():
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
False
,
})
},
)
source
=
kernel
.
get_kernel_source
()
print
(
"
\n
=== Testing abs (maps to fabs) ==="
)
...
...
@@ -199,26 +188,19 @@ def run_abs_test():
print
(
"✓ abs numerical test passed"
)
def
run_fastmath_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
def
run_fastmath_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
"""
Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
"""
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
])
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
])
# Test with FAST_MATH enabled
kernel_fastmath
=
tilelang
.
compile
(
...
...
@@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name,
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
source_fastmath
=
kernel_fastmath
.
get_kernel_source
()
print
(
f
"
\n
=== Testing
{
mathop_name
}
(fastmath version) ==="
)
print
(
"FAST_MATH=True:"
)
# Strip the __ prefix for checking in the CUDA source
cuda_mathop_name
=
mathop_name
.
lstrip
(
'_'
)
cuda_mathop_name
=
mathop_name
.
lstrip
(
"_"
)
check_fastmath_usage
(
source_fastmath
,
cuda_mathop_name
,
expect_fastmath
=
True
)
# Test numerical correctness
...
...
testing/python/math/test_math_ieee_math.py
View file @
29051439
...
...
@@ -5,14 +5,7 @@ import tilelang.testing
import
pytest
def
run_ieee_math_test
(
mathop_name
,
mathop_func
,
rounding_mode
=
"rn"
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
def
run_ieee_math_test
(
mathop_name
,
mathop_func
,
rounding_mode
=
"rn"
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
"""
Test IEEE-compliant math operations with specified rounding modes.
"""
...
...
@@ -22,18 +15,19 @@ def run_ieee_math_test(mathop_name,
@
T
.
prim_func
def
main_func
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
D
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
D
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
D
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
],
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
],
C
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
],
rounding_mode
)
D
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
],
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
],
C
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
],
rounding_mode
,
)
out_idx
=
[
3
]
num_inputs
=
3
...
...
@@ -41,16 +35,15 @@ def run_ieee_math_test(mathop_name,
@
T
.
prim_func
def
main_func
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
C
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
],
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
],
rounding_mode
)
C
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
],
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
],
rounding_mode
)
out_idx
=
[
2
]
num_inputs
=
2
...
...
@@ -58,14 +51,12 @@ def run_ieee_math_test(mathop_name,
@
T
.
prim_func
def
main_func
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
],
rounding_mode
)
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
],
rounding_mode
)
out_idx
=
[
1
]
num_inputs
=
1
...
...
@@ -77,7 +68,8 @@ def run_ieee_math_test(mathop_name,
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
False
,
})
},
)
print
(
f
"
\n
=== Testing
{
mathop_name
}
with rounding mode
{
rounding_mode
}
==="
)
print
(
f
"✓
{
mathop_name
}
compilation test passed"
)
...
...
@@ -194,8 +186,8 @@ def test_ieee_frsqrt_rn_only():
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
128
,
128
),
"float32"
),
B
:
T
.
Tensor
((
128
,
128
),
"float32"
),
A
:
T
.
Tensor
((
128
,
128
),
"float32"
),
B
:
T
.
Tensor
((
128
,
128
),
"float32"
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
128
,
32
),
T
.
ceildiv
(
128
,
32
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
32
,
32
):
...
...
@@ -207,7 +199,8 @@ def test_ieee_frsqrt_rn_only():
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
False
,
})
},
)
print
(
"
\n
=== Testing ieee_frsqrt (rn only) ==="
)
print
(
"✓ ieee_frsqrt compilation test passed"
)
...
...
testing/python/metal/test_metal_codegen.py
View file @
29051439
...
...
@@ -5,18 +5,17 @@ import tilelang.language as T
import
torch
@
tilelang
.
jit
(
execution_backend
=
'
torch
'
)
@
tilelang
.
jit
(
execution_backend
=
"
torch
"
)
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float32"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
gemm
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
,
scope
=
'
shared
'
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
,
scope
=
'
shared
'
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
,
scope
=
"
shared
"
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
,
scope
=
"
shared
"
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
clear
(
C_local
)
...
...
@@ -48,13 +47,13 @@ def assert_gemm(
torch_dtype
=
getattr
(
torch
,
dtype
)
a
,
b
=
None
,
None
if
'
int
'
in
dtype
:
a
=
torch
.
randint
(
100
,
(
M
,
K
),
dtype
=
torch_dtype
,
device
=
'
mps
'
)
b
=
torch
.
randint
(
100
,
(
K
,
N
),
dtype
=
torch_dtype
,
device
=
'
mps
'
)
if
"
int
"
in
dtype
:
a
=
torch
.
randint
(
100
,
(
M
,
K
),
dtype
=
torch_dtype
,
device
=
"
mps
"
)
b
=
torch
.
randint
(
100
,
(
K
,
N
),
dtype
=
torch_dtype
,
device
=
"
mps
"
)
else
:
a
=
torch
.
randn
(
M
,
K
,
dtype
=
torch_dtype
,
device
=
'
mps
'
)
b
=
torch
.
randn
(
K
,
N
,
dtype
=
torch_dtype
,
device
=
'
mps
'
)
c
=
torch
.
zeros
(
M
,
N
,
dtype
=
torch_dtype
,
device
=
'
mps
'
)
a
=
torch
.
randn
(
M
,
K
,
dtype
=
torch_dtype
,
device
=
"
mps
"
)
b
=
torch
.
randn
(
K
,
N
,
dtype
=
torch_dtype
,
device
=
"
mps
"
)
c
=
torch
.
zeros
(
M
,
N
,
dtype
=
torch_dtype
,
device
=
"
mps
"
)
jit_kernel
(
a
,
b
,
c
)
...
...
@@ -70,12 +69,12 @@ def test_gemm_float32():
@
tilelang
.
testing
.
requires_metal
def
test_gemm_float16
():
assert_gemm
(
1024
,
1024
,
1024
,
16
,
16
,
16
,
dtype
=
'
float16
'
,
atol
=
1
)
assert_gemm
(
1024
,
1024
,
1024
,
16
,
16
,
16
,
dtype
=
"
float16
"
,
atol
=
1
)
@
tilelang
.
testing
.
requires_metal
def
test_gemm_int32
():
assert_gemm
(
1024
,
1024
,
1024
,
16
,
16
,
16
,
dtype
=
'
int32
'
,
atol
=
1
)
assert_gemm
(
1024
,
1024
,
1024
,
16
,
16
,
16
,
dtype
=
"
int32
"
,
atol
=
1
)
if
__name__
==
"__main__"
:
...
...
testing/python/primitives/test_tilelang_primitives_mma.py
View file @
29051439
...
...
@@ -27,9 +27,9 @@ def matmul_ssr(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
...
...
@@ -88,7 +88,8 @@ def run_matmul_ssr(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
kernel
.
get_profiler
()
def
ref_program
(
A
,
B
):
...
...
@@ -106,24 +107,9 @@ def run_matmul_ssr(
def
test_gemm_f16f16f16_nt_ssr
():
run_matmul_ssr
(
16
,
16
,
16
,
False
,
True
,
"float16"
,
"float16"
,
"float16"
,
16
,
16
,
16
,
0
,
num_threads
=
32
)
run_matmul_ssr
(
128
,
128
,
128
,
False
,
True
,
"float16"
,
"float16"
,
"float16"
,
32
,
32
,
32
,
0
,
num_threads
=
64
)
run_matmul_ssr
(
1024
,
1024
,
1024
,
False
,
True
,
"float16"
,
"float16"
,
"float16"
,
128
,
128
,
32
,
2
,
num_threads
=
128
)
run_matmul_ssr
(
16
,
16
,
16
,
False
,
True
,
"float16"
,
"float16"
,
"float16"
,
16
,
16
,
16
,
0
,
num_threads
=
32
)
run_matmul_ssr
(
128
,
128
,
128
,
False
,
True
,
"float16"
,
"float16"
,
"float16"
,
32
,
32
,
32
,
0
,
num_threads
=
64
)
run_matmul_ssr
(
1024
,
1024
,
1024
,
False
,
True
,
"float16"
,
"float16"
,
"float16"
,
128
,
128
,
32
,
2
,
num_threads
=
128
)
def
matmul_rsr
(
...
...
@@ -151,9 +137,9 @@ def matmul_rsr(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
...
...
@@ -214,7 +200,8 @@ def run_matmul_rsr(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
kernel
.
get_profiler
()
def
ref_program
(
A
,
B
):
...
...
@@ -276,9 +263,9 @@ def matmul_rrr(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
...
@@ -342,7 +329,8 @@ def run_matmul_rrr(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
kernel
.
get_profiler
()
def
ref_program
(
A
,
B
):
...
...
Prev
1
…
11
12
13
14
15
16
17
18
19
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment