Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
228 additions
and
344 deletions
+228
-344
testing/python/carver/test_tilelang_carver_recommend_hints.py
...ing/python/carver/test_tilelang_carver_recommend_hints.py
+5
-12
testing/python/components/test_storage_rewrite_detect_inplace.py
.../python/components/test_storage_rewrite_detect_inplace.py
+2
-1
testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py
...nts/test_tilelang_pass_config_disable_warp_specialized.py
+5
-4
testing/python/cpu/test_tilelang_cpu_gemm.py
testing/python/cpu/test_tilelang_cpu_gemm.py
+6
-8
testing/python/debug/test_device_assert.py
testing/python/debug/test_device_assert.py
+0
-2
testing/python/debug/test_tilelang_debug_print.py
testing/python/debug/test_tilelang_debug_print.py
+18
-19
testing/python/dynamic/test_tilelang_dynamic_symbolic.py
testing/python/dynamic/test_tilelang_dynamic_symbolic.py
+27
-54
testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py
...ng/python/dynamic/test_tilelang_dynamic_symbolic_bench.py
+21
-33
testing/python/fastmath/test_mathops_fastmath.py
testing/python/fastmath/test_mathops_fastmath.py
+32
-49
testing/python/issue/test_tilelang_issue_1001.py
testing/python/issue/test_tilelang_issue_1001.py
+7
-6
testing/python/issue/test_tilelang_issue_1008.py
testing/python/issue/test_tilelang_issue_1008.py
+12
-10
testing/python/issue/test_tilelang_issue_1115.py
testing/python/issue/test_tilelang_issue_1115.py
+11
-13
testing/python/issue/test_tilelang_issue_1198.py
testing/python/issue/test_tilelang_issue_1198.py
+9
-5
testing/python/issue/test_tilelang_issue_814.py
testing/python/issue/test_tilelang_issue_814.py
+2
-3
testing/python/issue/test_tilelang_issue_830.py
testing/python/issue/test_tilelang_issue_830.py
+0
-2
testing/python/issue/test_tilelang_issue_96.py
testing/python/issue/test_tilelang_issue_96.py
+7
-9
testing/python/issue/test_tilelang_issue_merge_if.py
testing/python/issue/test_tilelang_issue_merge_if.py
+0
-1
testing/python/jit/test_tilelang_jit_callback.py
testing/python/jit/test_tilelang_jit_callback.py
+7
-6
testing/python/jit/test_tilelang_jit_gemm.py
testing/python/jit/test_tilelang_jit_gemm.py
+4
-3
testing/python/jit/test_tilelang_jit_gemm_cython.py
testing/python/jit/test_tilelang_jit_gemm_cython.py
+53
-104
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
testing/python/carver/test_tilelang_carver_recommend_hints.py
View file @
29051439
...
...
@@ -4,10 +4,7 @@ from tilelang.carver.arch import auto_infer_current_arch
from
typing
import
List
def
run_general_reduction_recommend_hints
(
structure
:
str
=
"SSR"
,
shape
:
List
[
int
]
=
None
,
dtype
:
str
=
"float16"
,
topk
:
int
=
20
):
def
run_general_reduction_recommend_hints
(
structure
:
str
=
"SSR"
,
shape
:
List
[
int
]
=
None
,
dtype
:
str
=
"float16"
,
topk
:
int
=
20
):
arch
=
auto_infer_current_arch
()
carve_template
=
carver
.
GeneralReductionTemplate
(
structure
=
structure
,
...
...
@@ -28,9 +25,7 @@ def test_general_reduction_recommend_hints():
run_general_reduction_recommend_hints
(
"SRS"
,
[
1024
,
1024
,
1024
],
"float16"
)
def
run_elementwise_recommend_hints
(
shape
:
List
[
int
]
=
None
,
dtype
:
str
=
"float16"
,
topk
:
int
=
20
):
def
run_elementwise_recommend_hints
(
shape
:
List
[
int
]
=
None
,
dtype
:
str
=
"float16"
,
topk
:
int
=
20
):
arch
=
auto_infer_current_arch
()
carve_template
=
carver
.
ElementwiseTemplate
(
shape
=
shape
,
...
...
@@ -81,11 +76,9 @@ def test_matmul_recommend_hints():
run_matmul_recommend_hints
(
1024
,
1024
,
1024
,
"float16"
,
"float32"
,
"float16"
)
def
run_gemv_recommend_hints
(
N
:
int
=
1024
,
K
:
int
=
1024
,
in_dtype
:
str
=
"float16"
,
out_dtype
:
str
=
"float16"
,
accum_dtype
:
str
=
"float16"
):
def
run_gemv_recommend_hints
(
N
:
int
=
1024
,
K
:
int
=
1024
,
in_dtype
:
str
=
"float16"
,
out_dtype
:
str
=
"float16"
,
accum_dtype
:
str
=
"float16"
):
arch
=
auto_infer_current_arch
()
carve_template
=
carver
.
GEMVTemplate
(
N
=
N
,
...
...
testing/python/components/test_storage_rewrite_detect_inplace.py
View file @
29051439
...
...
@@ -23,7 +23,8 @@ def _compile_kernel_without_inplace():
@
tilelang
.
jit
(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_STORAGE_REWRITE_DETECT_INPLACE
:
True
,
},)
},
)
def
_compile_kernel_with_inplace
():
num_tokens
=
T
.
symbolic
(
"num_tokens"
)
...
...
testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py
View file @
29051439
...
...
@@ -26,9 +26,9 @@ def matmul(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
...
@@ -88,7 +88,8 @@ def run_gemm(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
disable_warp_specialized
,
})
},
)
profiler
=
kernel
.
get_profiler
()
def
ref_program
(
A
,
B
):
...
...
testing/python/cpu/test_tilelang_cpu_gemm.py
View file @
29051439
...
...
@@ -10,9 +10,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
@
T
.
prim_func
def
matmul
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
is_cpu
=
True
)
as
(
bx
,
by
):
A_local
=
T
.
alloc_local
((
block_M
,
block_K
),
dtype
)
...
...
@@ -31,7 +31,6 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
# )
for
ko
in
T
.
Pipelined
(
K
//
block_K
,
num_stages
=
num_stages
):
T
.
copy
(
A
[
by
*
block_M
,
ko
*
block_K
],
A_local
)
# Or Copy with Parallel
...
...
@@ -62,14 +61,13 @@ def test_matmul_codegen():
def
test_matmul_compile
():
def
matmul_jit_test
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
# a simple kernel just for jit test
@
T
.
prim_func
def
matmul
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
is_cpu
=
True
)
as
(
bx
,
by
):
A_local
=
T
.
alloc_local
((
block_M
,
block_K
),
dtype
)
...
...
testing/python/debug/test_device_assert.py
View file @
29051439
...
...
@@ -7,7 +7,6 @@ import tilelang.language as T
# TODO(dyq) It intentionally triggers a device-side assert so we can't include this in CI
# Please run manually when you want to verify that device_assert actually traps on GPU.
def
_manual_device_assert_triggered
():
@
T
.
prim_func
def
program
():
with
T
.
Kernel
(
threads
=
128
):
...
...
@@ -20,7 +19,6 @@ def _manual_device_assert_triggered():
def
test_device_assert_no_trigger
():
@
T
.
prim_func
def
program
():
with
T
.
Kernel
(
threads
=
128
):
...
...
testing/python/debug/test_tilelang_debug_print.py
View file @
29051439
...
...
@@ -6,7 +6,6 @@ import tilelang.language as T
def
debug_print_buffer
(
M
=
16
,
N
=
16
,
dtype
=
"float16"
):
@
T
.
prim_func
def
program
(
Q
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
4
,
4
,
2
,
threads
=
128
*
2
)
as
(
bx
,
by
,
bz
):
...
...
@@ -19,24 +18,24 @@ def debug_print_buffer(M=16, N=16, dtype="float16"):
def
test_debug_print_buffer
():
debug_print_buffer
(
dtype
=
'
bool
'
)
debug_print_buffer
(
dtype
=
'
int8
'
)
debug_print_buffer
(
dtype
=
'
int16
'
)
debug_print_buffer
(
dtype
=
'
int32
'
)
debug_print_buffer
(
dtype
=
'
int64
'
)
debug_print_buffer
(
dtype
=
'
uint8
'
)
debug_print_buffer
(
dtype
=
'
uint16
'
)
debug_print_buffer
(
dtype
=
'
uint32
'
)
debug_print_buffer
(
dtype
=
'
uint64
'
)
debug_print_buffer
(
dtype
=
'
float16
'
)
debug_print_buffer
(
dtype
=
'
float32
'
)
debug_print_buffer
(
dtype
=
'
float64
'
)
debug_print_buffer
(
dtype
=
'
bfloat16
'
)
debug_print_buffer
(
dtype
=
'
float8_e4m3
'
)
debug_print_buffer
(
dtype
=
'
float8_e4m3fn
'
)
debug_print_buffer
(
dtype
=
'
float8_e4m3fnuz
'
)
debug_print_buffer
(
dtype
=
'
float8_e5m2
'
)
debug_print_buffer
(
dtype
=
'
float8_e5m2fnuz
'
)
debug_print_buffer
(
dtype
=
"
bool
"
)
debug_print_buffer
(
dtype
=
"
int8
"
)
debug_print_buffer
(
dtype
=
"
int16
"
)
debug_print_buffer
(
dtype
=
"
int32
"
)
debug_print_buffer
(
dtype
=
"
int64
"
)
debug_print_buffer
(
dtype
=
"
uint8
"
)
debug_print_buffer
(
dtype
=
"
uint16
"
)
debug_print_buffer
(
dtype
=
"
uint32
"
)
debug_print_buffer
(
dtype
=
"
uint64
"
)
debug_print_buffer
(
dtype
=
"
float16
"
)
debug_print_buffer
(
dtype
=
"
float32
"
)
debug_print_buffer
(
dtype
=
"
float64
"
)
debug_print_buffer
(
dtype
=
"
bfloat16
"
)
debug_print_buffer
(
dtype
=
"
float8_e4m3
"
)
debug_print_buffer
(
dtype
=
"
float8_e4m3fn
"
)
debug_print_buffer
(
dtype
=
"
float8_e4m3fnuz
"
)
debug_print_buffer
(
dtype
=
"
float8_e5m2
"
)
debug_print_buffer
(
dtype
=
"
float8_e5m2fnuz
"
)
def
debug_print_buffer_conditional
(
M
=
16
,
N
=
16
):
...
...
testing/python/dynamic/test_tilelang_dynamic_symbolic.py
View file @
29051439
...
...
@@ -5,7 +5,7 @@ import tilelang.testing
from
tvm
import
DataType
import
tilelang.language
as
T
from
tilelang.intrinsics.utils
import
get_swizzle_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
)
from
tilelang.intrinsics.mma_macro_generator
import
TensorCoreIntrinEmitter
tilelang
.
testing
.
set_random_seed
(
0
)
...
...
@@ -96,12 +96,11 @@ def tl_matmul_macro(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
...
...
@@ -109,10 +108,12 @@ def tl_matmul_macro(
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size
),
accum_dtype
)
T
.
annotate_layout
({
A_shared
:
make_swizzle_layout
(
A_shared
),
B_shared
:
make_swizzle_layout
(
B_shared
),
})
T
.
annotate_layout
(
{
A_shared
:
make_swizzle_layout
(
A_shared
),
B_shared
:
make_swizzle_layout
(
B_shared
),
}
)
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
)
...
...
@@ -120,7 +121,6 @@ def tl_matmul_macro(
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
...
...
@@ -130,7 +130,6 @@ def tl_matmul_macro(
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
A_local
,
...
...
@@ -207,8 +206,7 @@ def tl_matmul_block(
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
(
(
M
,
N
),
out_dtype
)):
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
...
...
@@ -306,8 +304,7 @@ def tl_matmul_block_all_dynamic(
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
(
(
M
,
N
),
out_dtype
)):
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
...
...
@@ -417,7 +414,7 @@ def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
)
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_DYNAMIC_TAIL_SPLIT
:
dynamic_alignment
!=
0
,
tilelang
.
PassConfigKey
.
TL_DYNAMIC_ALIGNMENT
:
dynamic_alignment
tilelang
.
PassConfigKey
.
TL_DYNAMIC_ALIGNMENT
:
dynamic_alignment
,
}
if
M
%
64
==
0
or
N
%
64
==
0
or
K
%
64
!=
0
:
# workaround for hopper tma lower pass
...
...
@@ -462,55 +459,31 @@ def test_assert_tl_matmul_macro():
def
test_assert_tl_matmul_block
():
assert_tl_matmul_block_correctness
(
128
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
assert_tl_matmul_block_correctness
(
67
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
assert_tl_matmul_block_correctness
(
36
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
assert_tl_matmul_block_correctness
(
128
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
assert_tl_matmul_block_correctness
(
67
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
assert_tl_matmul_block_correctness
(
36
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
def
test_assert_tl_matmul_block_all_dynamic
():
assert_tl_matmul_block_all_dynamic_correctness
(
128
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
assert_tl_matmul_block_all_dynamic_correctness
(
67
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
assert_tl_matmul_block_all_dynamic_correctness
(
36
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
assert_tl_matmul_block_all_dynamic_correctness
(
128
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
assert_tl_matmul_block_all_dynamic_correctness
(
67
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
assert_tl_matmul_block_all_dynamic_correctness
(
36
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
def
test_assert_tl_matmul_block_all_dynamic_with_pass_config
():
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config
(
128
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
8
)
128
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
8
)
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config
(
64
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
8
)
64
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
8
)
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config
(
64
,
128
,
60
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
4
)
64
,
128
,
60
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
4
)
# Tail split is enabled with dynamic alignment 0
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config
(
64
,
128
,
64
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
0
)
64
,
128
,
64
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
0
)
if
__name__
==
"__main__"
:
...
...
testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py
View file @
29051439
...
...
@@ -25,10 +25,8 @@ def tl_matmul_block_static(
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
(
(
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -137,10 +135,8 @@ def tl_matmul_block_dynamic_m(
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
(
(
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -247,10 +243,8 @@ def tl_matmul_block_dynamic_mn(
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
(
(
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -357,10 +351,8 @@ def tl_matmul_block_dynamic_mnk(
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
(
(
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -445,8 +437,7 @@ def assert_tl_matmul_block_dynamic_mnk(
def
run_assert_tl_matmul_block_static
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
):
assert_tl_matmul_block_static
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
False
,
False
,
"float16"
,
"float16"
,
"float32"
)
assert_tl_matmul_block_static
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
False
,
False
,
"float16"
,
"float16"
,
"float32"
)
def
run_assert_tl_matmul_block_dynamic_m
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
):
...
...
@@ -462,10 +453,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
"float16"
,
"float16"
,
"float32"
,
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
True
,
"tl.dynamic_alignment"
:
8
})
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
True
,
"tl.dynamic_alignment"
:
8
},
)
assert_tl_matmul_block_dynamic_m
(
M
,
N
,
...
...
@@ -478,7 +467,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
"float16"
,
"float16"
,
"float32"
,
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
False
})
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
False
},
)
def
run_assert_tl_matmul_block_dynamic_mn
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
):
...
...
@@ -494,10 +484,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
"float16"
,
"float16"
,
"float32"
,
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
True
,
"tl.dynamic_alignment"
:
8
})
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
True
,
"tl.dynamic_alignment"
:
8
},
)
assert_tl_matmul_block_dynamic_mn
(
M
,
N
,
...
...
@@ -510,7 +498,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
"float16"
,
"float16"
,
"float32"
,
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
False
})
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
False
},
)
def
run_assert_tl_matmul_block_dynamic_mnk
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
):
...
...
@@ -526,10 +515,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
"float16"
,
"float16"
,
"float32"
,
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
True
,
"tl.dynamic_alignment"
:
4
})
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
True
,
"tl.dynamic_alignment"
:
4
},
)
assert_tl_matmul_block_dynamic_mnk
(
M
,
N
,
...
...
@@ -542,7 +529,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
"float16"
,
"float16"
,
"float32"
,
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
False
})
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
False
},
)
def
test_all
():
...
...
testing/python/fastmath/test_mathops_fastmath.py
View file @
29051439
...
...
@@ -7,16 +7,16 @@ import re
def
get_mathop_lines
(
source
,
mathop_name
):
"""Extract lines containing the mathop from CUDA source for debugging"""
lines
=
source
.
split
(
'
\n
'
)
lines
=
source
.
split
(
"
\n
"
)
relevant_lines
=
[]
for
i
,
line
in
enumerate
(
lines
):
if
mathop_name
in
line
and
(
'('
in
line
):
if
mathop_name
in
line
and
(
"("
in
line
):
# Include some context
start
=
max
(
0
,
i
-
1
)
end
=
min
(
len
(
lines
),
i
+
2
)
relevant_lines
.
extend
([
f
"
{
j
}
:
{
lines
[
j
]
}
"
for
j
in
range
(
start
,
end
)])
relevant_lines
.
append
(
"---"
)
return
'
\n
'
.
join
(
relevant_lines
[
-
10
:])
# Show last 10 lines to avoid too much output
return
"
\n
"
.
join
(
relevant_lines
[
-
10
:])
# Show last 10 lines to avoid too much output
def
check_fastmath_usage
(
source
,
mathop_name
,
expect_fastmath
=
False
):
...
...
@@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
fastmath_matches
=
re
.
findall
(
fastmath_pattern
,
source
)
non_fastmath_matches
=
re
.
findall
(
non_fastmath_pattern
,
source
)
print
(
f
"Found
{
len
(
fastmath_matches
)
}
fastmath calls,
{
len
(
non_fastmath_matches
)
}
non-fastmath calls"
)
print
(
f
"Found
{
len
(
fastmath_matches
)
}
fastmath calls,
{
len
(
non_fastmath_matches
)
}
non-fastmath calls"
)
if
len
(
fastmath_matches
)
>
0
:
print
(
f
"Fastmath calls found:
{
fastmath_matches
}
"
)
if
len
(
non_fastmath_matches
)
>
0
:
...
...
@@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name):
check_fastmath_usage
(
source
,
mathop_name
,
expect_fastmath
=
False
)
def
run_single_arg_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
def
run_single_arg_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
"""
Test single-argument mathops.
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
...
...
@@ -65,13 +57,12 @@ def run_single_arg_mathop_test(mathop_name,
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
])
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
])
# Test with FAST_MATH disabled
kernel_no_fastmath
=
tilelang
.
compile
(
...
...
@@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name,
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
False
,
})
},
)
source_no_fastmath
=
kernel_no_fastmath
.
get_kernel_source
()
...
...
@@ -93,28 +85,22 @@ def run_single_arg_mathop_test(mathop_name,
print
(
f
"✓
{
mathop_name
}
compilation and execution test passed"
)
def
run_two_arg_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
def
run_two_arg_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
"""
Test two-argument mathops to ensure they generate non-fastmath CUDA code.
"""
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
C
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
,
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
)
C
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
,
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
)
# Test with FAST_MATH disabled
kernel_no_fastmath
=
tilelang
.
compile
(
...
...
@@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name,
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
False
,
})
},
)
# Test with FAST_MATH enabled
kernel_fastmath
=
tilelang
.
compile
(
...
...
@@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name,
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
source_no_fastmath
=
kernel_no_fastmath
.
get_kernel_source
()
source_fastmath
=
kernel_fastmath
.
get_kernel_source
()
...
...
@@ -171,8 +159,8 @@ def run_abs_test():
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
),
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
...
...
@@ -184,7 +172,8 @@ def run_abs_test():
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
False
,
})
},
)
source
=
kernel
.
get_kernel_source
()
print
(
"
\n
=== Testing abs (maps to fabs) ==="
)
...
...
@@ -199,26 +188,19 @@ def run_abs_test():
print
(
"✓ abs numerical test passed"
)
def
run_fastmath_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
def
run_fastmath_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
"""
Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
"""
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
])
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
])
# Test with FAST_MATH enabled
kernel_fastmath
=
tilelang
.
compile
(
...
...
@@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name,
target
=
"cuda"
,
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
source_fastmath
=
kernel_fastmath
.
get_kernel_source
()
print
(
f
"
\n
=== Testing
{
mathop_name
}
(fastmath version) ==="
)
print
(
"FAST_MATH=True:"
)
# Strip the __ prefix for checking in the CUDA source
cuda_mathop_name
=
mathop_name
.
lstrip
(
'_'
)
cuda_mathop_name
=
mathop_name
.
lstrip
(
"_"
)
check_fastmath_usage
(
source_fastmath
,
cuda_mathop_name
,
expect_fastmath
=
True
)
# Test numerical correctness
...
...
testing/python/issue/test_tilelang_issue_1001.py
View file @
29051439
...
...
@@ -8,14 +8,15 @@ from tilelang import language as T
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
},)
},
)
def
_cumsum_view_infer_layout
(
hidden
):
num_tokens
=
T
.
dynamic
(
'
num_tokens
'
)
num_tokens
=
T
.
dynamic
(
"
num_tokens
"
)
@
T
.
prim_func
def
buggy_kernel
(
x
:
T
.
Tensor
[(
num_tokens
,
hidden
),
'
float
'
]):
def
buggy_kernel
(
x
:
T
.
Tensor
[(
num_tokens
,
hidden
),
"
float
"
]):
with
T
.
Kernel
(
num_tokens
,
threads
=
128
)
as
pid
:
smem
=
T
.
alloc_shared
((
hidden
,),
dtype
=
'
float
'
)
smem
=
T
.
alloc_shared
((
hidden
,),
dtype
=
"
float
"
)
T
.
copy
(
x
[
pid
,
:],
smem
)
T
.
cumsum
(
T
.
view
(
smem
,
(
1
,
hidden
)),
dim
=
1
)
...
...
@@ -24,10 +25,10 @@ def _cumsum_view_infer_layout(hidden):
def
test_cumsum_view_infer_layout
():
hidden
=
128
x
=
torch
.
randn
(
1
,
hidden
,
device
=
'
cuda
'
,
dtype
=
torch
.
float
)
x
=
torch
.
randn
(
1
,
hidden
,
device
=
"
cuda
"
,
dtype
=
torch
.
float
)
kernel
=
_cumsum_view_infer_layout
(
hidden
)
kernel
(
x
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
testing/python/issue/test_tilelang_issue_1008.py
View file @
29051439
...
...
@@ -8,12 +8,13 @@ from tilelang import language as T
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
},)
},
)
def
_fill_with_static_region_kernel
():
num_tokens
=
T
.
symbolic
(
'
num_tokens
'
)
num_tokens
=
T
.
symbolic
(
"
num_tokens
"
)
@
T
.
prim_func
def
buggy_kernel
(
x
:
T
.
Tensor
[(
num_tokens
,),
'
int64
'
]):
# noqa: F821
def
buggy_kernel
(
x
:
T
.
Tensor
[(
num_tokens
,),
"
int64
"
]):
# noqa: F821
with
T
.
Kernel
(
num_tokens
,
threads
=
128
)
as
_
:
T
.
fill
(
x
[
0
:
128
],
0
)
...
...
@@ -24,14 +25,15 @@ def _fill_with_static_region_kernel():
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
},)
},
)
def
_fill_with_dynamic_region_kernel
():
num_tokens
=
T
.
symbolic
(
'
num_tokens
'
)
num_tokens
=
T
.
symbolic
(
"
num_tokens
"
)
@
T
.
prim_func
def
buggy_kernel
(
x
:
T
.
Tensor
[(
num_tokens
,),
'
int64
'
]):
# noqa: F821
def
buggy_kernel
(
x
:
T
.
Tensor
[(
num_tokens
,),
"
int64
"
]):
# noqa: F821
with
T
.
Kernel
(
num_tokens
,
threads
=
128
)
as
_
:
a
,
b
=
T
.
alloc_var
(
'
int
'
),
T
.
alloc_var
(
'
int
'
)
a
,
b
=
T
.
alloc_var
(
"
int
"
),
T
.
alloc_var
(
"
int
"
)
T
.
fill
(
x
[
a
:
b
],
0
)
return
buggy_kernel
...
...
@@ -39,15 +41,15 @@ def _fill_with_dynamic_region_kernel():
def
test_fill_with_static_region_kernel
():
kernel
=
_fill_with_static_region_kernel
()
x
=
torch
.
zeros
((
256
,),
dtype
=
torch
.
int64
,
device
=
'
cuda
'
)
x
=
torch
.
zeros
((
256
,),
dtype
=
torch
.
int64
,
device
=
"
cuda
"
)
kernel
(
x
)
def
test_fill_with_dynamic_region_kernel
():
kernel
=
_fill_with_dynamic_region_kernel
()
x
=
torch
.
zeros
((
256
,),
dtype
=
torch
.
int64
,
device
=
'
cuda
'
)
x
=
torch
.
zeros
((
256
,),
dtype
=
torch
.
int64
,
device
=
"
cuda
"
)
kernel
(
x
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
testing/python/issue/test_tilelang_issue_1115.py
View file @
29051439
...
...
@@ -4,25 +4,23 @@ import tilelang.language as T
def
test_int64_address
():
@
tilelang
.
jit
def
set_cache_kernel
(
S
,
D
,
pos_ty
=
'
int64
'
,
pos_ty
=
"
int64
"
,
dtype
=
"float32"
,
):
@
T
.
prim_func
def
main
(
pos
:
T
.
Tensor
(
pos
:
T
.
Tensor
(
[
S
,
],
pos_ty
],
pos_ty
,
),
# type: ignore `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32`
value
:
T
.
Tensor
([
S
,
D
],
dtype
),
# type: ignore
cache
:
T
.
Tensor
([
S
,
D
],
dtype
),
# type: ignore
value
:
T
.
Tensor
([
S
,
D
],
dtype
),
# type: ignore
cache
:
T
.
Tensor
([
S
,
D
],
dtype
),
# type: ignore
):
with
T
.
Kernel
(
S
,
threads
=
128
)
as
bx
:
slot
=
pos
[
bx
]
...
...
@@ -34,11 +32,11 @@ def test_int64_address():
D
=
2
S
=
10
cache
=
torch
.
rand
((
S
,
D
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
value
=
torch
.
rand
((
S
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float32
)
pos_int64
=
torch
.
arange
(
S
,
device
=
'
cuda
'
,
dtype
=
torch
.
int64
)
pos_int32
=
torch
.
arange
(
S
,
device
=
'
cuda
'
,
dtype
=
torch
.
int32
)
kernel_int64
=
set_cache_kernel
(
S
,
D
,
'
int64
'
)
kernel_int32
=
set_cache_kernel
(
S
,
D
,
'
int32
'
)
value
=
torch
.
rand
((
S
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float32
)
pos_int64
=
torch
.
arange
(
S
,
device
=
"
cuda
"
,
dtype
=
torch
.
int64
)
pos_int32
=
torch
.
arange
(
S
,
device
=
"
cuda
"
,
dtype
=
torch
.
int32
)
kernel_int64
=
set_cache_kernel
(
S
,
D
,
"
int64
"
)
kernel_int32
=
set_cache_kernel
(
S
,
D
,
"
int32
"
)
kernel_int64
(
pos_int64
,
value
,
cache
)
torch
.
testing
.
assert_close
(
cache
,
value
)
kernel_int32
(
pos_int32
,
value
,
cache
)
...
...
testing/python/issue/test_tilelang_issue_1198.py
View file @
29051439
...
...
@@ -3,13 +3,17 @@ import tilelang.language as T
def
test_issue_1198
():
@
T
.
prim_func
def
foo
(
x
:
T
.
Buffer
([
32
,
],
"int32"
)):
def
foo
(
x
:
T
.
Buffer
(
[
32
,
],
"int32"
,
),
):
pass
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
testing/python/issue/test_tilelang_issue_814.py
View file @
29051439
...
...
@@ -6,11 +6,10 @@ import torch
@
tilelang
.
jit
def
_tmp_var_kernel
(
N
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
def
kernel
(
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
128
)
as
bx
:
for
i
in
T
.
Parallel
(
block_N
):
...
...
testing/python/issue/test_tilelang_issue_830.py
View file @
29051439
...
...
@@ -8,7 +8,6 @@ import tilelang.language as T
@
tilelang
.
jit
def
_empty_kernel
():
@
T
.
prim_func
def
empty_kernel
():
with
T
.
Kernel
(
1
,
threads
=
32
)
as
thread_idx
:
...
...
@@ -51,7 +50,6 @@ def test_empty_with_dead_code_kernel():
@
tilelang
.
jit
def
_empty_kernel_with_binding_variants
(
use_tuple_binding
:
bool
=
False
):
@
T
.
prim_func
def
kernel_with_tuple_kernel_binding
():
with
T
.
Kernel
(
1
,
threads
=
32
)
as
(
pid
,):
...
...
testing/python/issue/test_tilelang_issue_96.py
View file @
29051439
...
...
@@ -5,18 +5,16 @@ import torch
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
,
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
,
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
testing/python/issue/test_tilelang_issue_merge_if.py
View file @
29051439
...
...
@@ -6,7 +6,6 @@ import tilelang.language as T
def
merge_if_test
():
@
T
.
prim_func
def
main
():
A
=
T
.
alloc_fragment
((
1
,),
"float16"
)
...
...
testing/python/jit/test_tilelang_jit_callback.py
View file @
29051439
...
...
@@ -29,9 +29,9 @@ def matmul(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
...
@@ -141,9 +141,9 @@ def matmu_jit_kernel(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
...
@@ -208,6 +208,7 @@ def run_gemm_jit_kernel(
def
ref_program
(
A
,
B
):
import
torch
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
C
.
to
(
torch
.
__getattribute__
(
out_dtype
))
return
C
...
...
testing/python/jit/test_tilelang_jit_gemm.py
View file @
29051439
...
...
@@ -31,9 +31,9 @@ def matmul_kernel_jit(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
...
@@ -96,6 +96,7 @@ def run_gemm_kernel_jit(
def
ref_program
(
A
,
B
):
import
torch
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
C
.
to
(
torch
.
__getattribute__
(
out_dtype
))
return
C
...
...
testing/python/jit/test_tilelang_jit_gemm_cython.py
View file @
29051439
...
...
@@ -28,9 +28,9 @@ def matmul(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
...
@@ -138,9 +138,9 @@ def matmu_jit_kernel(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
...
@@ -208,6 +208,7 @@ def run_gemm_jit_kernel(
def
ref_program
(
A
,
B
):
import
torch
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
C
.
to
(
out_dtype
)
return
C
...
...
@@ -235,19 +236,9 @@ def test_gemm_jit_kernel():
)
def
run_cython_kernel_do_bench
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
def
run_cython_kernel_do_bench
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
program
=
matmul
(
M
,
N
,
...
...
@@ -287,23 +278,12 @@ def run_cython_kernel_do_bench(M,
def
test_cython_kernel_do_bench
():
run_cython_kernel_do_bench
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
def
run_cython_kernel_multi_stream
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
run_cython_kernel_do_bench
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
def
run_cython_kernel_multi_stream
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
program
=
matmul
(
M
,
N
,
...
...
@@ -342,23 +322,12 @@ def run_cython_kernel_multi_stream(M,
def
test_cython_kernel_multi_stream
():
run_cython_kernel_multi_stream
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
def
run_cython_dynamic_shape
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
run_cython_kernel_multi_stream
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
def
run_cython_dynamic_shape
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
program
=
matmul
(
M
,
N
,
...
...
@@ -398,36 +367,20 @@ def run_cython_dynamic_shape(M,
matmul_kernel
(
tensor_a
,
tensor_b
,
tensor_c
)
tensor_ref_c
=
torch
.
matmul
(
tensor_a
.
to
(
torch
.
float
),
tensor_b
.
to
(
torch
.
float
)).
to
(
out_dtype
)
tilelang
.
testing
.
torch_assert_close
(
tensor_c
,
tensor_ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
tilelang
.
testing
.
torch_assert_close
(
tensor_c
,
tensor_ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
def
test_cython_dynamic_shape
():
run_cython_dynamic_shape
(
T
.
dynamic
(
"m"
),
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
run_cython_dynamic_shape
(
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
run_cython_dynamic_shape
(
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
T
.
dynamic
(
"k"
),
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
def
run_cython_dynamic_shape_with_out_idx
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
run_cython_dynamic_shape
(
T
.
dynamic
(
"m"
),
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
run_cython_dynamic_shape
(
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
run_cython_dynamic_shape
(
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
T
.
dynamic
(
"k"
),
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
def
run_cython_dynamic_shape_with_out_idx
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
program
=
matmul
(
M
,
N
,
...
...
@@ -467,13 +420,11 @@ def run_cython_dynamic_shape_with_out_idx(M,
tensor_ref_c
=
torch
.
matmul
(
tensor_a
.
to
(
torch
.
float
),
tensor_b
.
to
(
torch
.
float
)).
to
(
out_dtype
)
tilelang
.
testing
.
torch_assert_close
(
tensor_c
,
tensor_ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
tilelang
.
testing
.
torch_assert_close
(
tensor_c
,
tensor_ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
def
test_cython_dynamic_shape_with_out_idx
():
run_cython_dynamic_shape_with_out_idx
(
T
.
dynamic
(
"m"
),
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
run_cython_dynamic_shape_with_out_idx
(
T
.
dynamic
(
"m"
),
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
def
matmul_int_variable
(
...
...
@@ -498,10 +449,10 @@ def matmul_int_variable(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
offset
:
T
.
int32
,
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
offset
:
T
.
int32
,
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
...
@@ -525,10 +476,10 @@ def matmul_int_variable(
return
main
def
run_matmul_int_variable
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
num_stages
,
threads
):
program
=
matmul_int_variable
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
num_stages
,
threads
)
def
run_matmul_int_variable
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
num_stages
,
threads
):
program
=
matmul_int_variable
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
num_stages
,
threads
)
matmul_kernel
=
tilelang
.
compile
(
program
,
execution_backend
=
"cython"
,
out_idx
=
2
)
in_dtype
=
map_torch_type
(
in_dtype
)
...
...
@@ -544,8 +495,7 @@ def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B
def
test_matmul_int_variable
():
run_matmul_int_variable
(
1024
,
1024
,
1024
,
128
,
128
,
32
,
False
,
False
,
"float16"
,
"float16"
,
"float32"
,
0
,
128
)
run_matmul_int_variable
(
1024
,
1024
,
1024
,
128
,
128
,
32
,
False
,
False
,
"float16"
,
"float16"
,
"float32"
,
0
,
128
)
def
matmul_float_variable
(
...
...
@@ -570,10 +520,10 @@ def matmul_float_variable(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
offset
:
T
.
float32
,
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
offset
:
T
.
float32
,
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
...
@@ -597,10 +547,10 @@ def matmul_float_variable(
return
main
def
run_matmul_float_variable
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
num_stages
,
threads
):
program
=
matmul_float_variable
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
num_stages
,
threads
)
def
run_matmul_float_variable
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
num_stages
,
threads
):
program
=
matmul_float_variable
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
num_stages
,
threads
)
matmul_kernel
=
tilelang
.
compile
(
program
,
execution_backend
=
"cython"
,
out_idx
=
2
)
in_dtype
=
map_torch_type
(
in_dtype
)
...
...
@@ -616,8 +566,7 @@ def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans
def
test_matmul_float_variable
():
run_matmul_float_variable
(
1024
,
1024
,
1024
,
128
,
128
,
32
,
False
,
False
,
"float16"
,
"float16"
,
"float32"
,
0
,
128
)
run_matmul_float_variable
(
1024
,
1024
,
1024
,
128
,
128
,
32
,
False
,
False
,
"float16"
,
"float16"
,
"float32"
,
0
,
128
)
if
__name__
==
"__main__"
:
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment