Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
228 additions
and
344 deletions
+228
-344
testing/python/carver/test_tilelang_carver_recommend_hints.py
...ing/python/carver/test_tilelang_carver_recommend_hints.py
+5
-12
testing/python/components/test_storage_rewrite_detect_inplace.py
.../python/components/test_storage_rewrite_detect_inplace.py
+2
-1
testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py
...nts/test_tilelang_pass_config_disable_warp_specialized.py
+5
-4
testing/python/cpu/test_tilelang_cpu_gemm.py
testing/python/cpu/test_tilelang_cpu_gemm.py
+6
-8
testing/python/debug/test_device_assert.py
testing/python/debug/test_device_assert.py
+0
-2
testing/python/debug/test_tilelang_debug_print.py
testing/python/debug/test_tilelang_debug_print.py
+18
-19
testing/python/dynamic/test_tilelang_dynamic_symbolic.py
testing/python/dynamic/test_tilelang_dynamic_symbolic.py
+27
-54
testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py
...ng/python/dynamic/test_tilelang_dynamic_symbolic_bench.py
+21
-33
testing/python/fastmath/test_mathops_fastmath.py
testing/python/fastmath/test_mathops_fastmath.py
+32
-49
testing/python/issue/test_tilelang_issue_1001.py
testing/python/issue/test_tilelang_issue_1001.py
+7
-6
testing/python/issue/test_tilelang_issue_1008.py
testing/python/issue/test_tilelang_issue_1008.py
+12
-10
testing/python/issue/test_tilelang_issue_1115.py
testing/python/issue/test_tilelang_issue_1115.py
+11
-13
testing/python/issue/test_tilelang_issue_1198.py
testing/python/issue/test_tilelang_issue_1198.py
+9
-5
testing/python/issue/test_tilelang_issue_814.py
testing/python/issue/test_tilelang_issue_814.py
+2
-3
testing/python/issue/test_tilelang_issue_830.py
testing/python/issue/test_tilelang_issue_830.py
+0
-2
testing/python/issue/test_tilelang_issue_96.py
testing/python/issue/test_tilelang_issue_96.py
+7
-9
testing/python/issue/test_tilelang_issue_merge_if.py
testing/python/issue/test_tilelang_issue_merge_if.py
+0
-1
testing/python/jit/test_tilelang_jit_callback.py
testing/python/jit/test_tilelang_jit_callback.py
+7
-6
testing/python/jit/test_tilelang_jit_gemm.py
testing/python/jit/test_tilelang_jit_gemm.py
+4
-3
testing/python/jit/test_tilelang_jit_gemm_cython.py
testing/python/jit/test_tilelang_jit_gemm_cython.py
+53
-104
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
testing/python/carver/test_tilelang_carver_recommend_hints.py
View file @
29051439
...
@@ -4,10 +4,7 @@ from tilelang.carver.arch import auto_infer_current_arch
...
@@ -4,10 +4,7 @@ from tilelang.carver.arch import auto_infer_current_arch
from
typing
import
List
from
typing
import
List
def
run_general_reduction_recommend_hints
(
structure
:
str
=
"SSR"
,
def
run_general_reduction_recommend_hints
(
structure
:
str
=
"SSR"
,
shape
:
List
[
int
]
=
None
,
dtype
:
str
=
"float16"
,
topk
:
int
=
20
):
shape
:
List
[
int
]
=
None
,
dtype
:
str
=
"float16"
,
topk
:
int
=
20
):
arch
=
auto_infer_current_arch
()
arch
=
auto_infer_current_arch
()
carve_template
=
carver
.
GeneralReductionTemplate
(
carve_template
=
carver
.
GeneralReductionTemplate
(
structure
=
structure
,
structure
=
structure
,
...
@@ -28,9 +25,7 @@ def test_general_reduction_recommend_hints():
...
@@ -28,9 +25,7 @@ def test_general_reduction_recommend_hints():
run_general_reduction_recommend_hints
(
"SRS"
,
[
1024
,
1024
,
1024
],
"float16"
)
run_general_reduction_recommend_hints
(
"SRS"
,
[
1024
,
1024
,
1024
],
"float16"
)
def
run_elementwise_recommend_hints
(
shape
:
List
[
int
]
=
None
,
def
run_elementwise_recommend_hints
(
shape
:
List
[
int
]
=
None
,
dtype
:
str
=
"float16"
,
topk
:
int
=
20
):
dtype
:
str
=
"float16"
,
topk
:
int
=
20
):
arch
=
auto_infer_current_arch
()
arch
=
auto_infer_current_arch
()
carve_template
=
carver
.
ElementwiseTemplate
(
carve_template
=
carver
.
ElementwiseTemplate
(
shape
=
shape
,
shape
=
shape
,
...
@@ -81,11 +76,9 @@ def test_matmul_recommend_hints():
...
@@ -81,11 +76,9 @@ def test_matmul_recommend_hints():
run_matmul_recommend_hints
(
1024
,
1024
,
1024
,
"float16"
,
"float32"
,
"float16"
)
run_matmul_recommend_hints
(
1024
,
1024
,
1024
,
"float16"
,
"float32"
,
"float16"
)
def
run_gemv_recommend_hints
(
N
:
int
=
1024
,
def
run_gemv_recommend_hints
(
K
:
int
=
1024
,
N
:
int
=
1024
,
K
:
int
=
1024
,
in_dtype
:
str
=
"float16"
,
out_dtype
:
str
=
"float16"
,
accum_dtype
:
str
=
"float16"
in_dtype
:
str
=
"float16"
,
):
out_dtype
:
str
=
"float16"
,
accum_dtype
:
str
=
"float16"
):
arch
=
auto_infer_current_arch
()
arch
=
auto_infer_current_arch
()
carve_template
=
carver
.
GEMVTemplate
(
carve_template
=
carver
.
GEMVTemplate
(
N
=
N
,
N
=
N
,
...
...
testing/python/components/test_storage_rewrite_detect_inplace.py
View file @
29051439
...
@@ -23,7 +23,8 @@ def _compile_kernel_without_inplace():
...
@@ -23,7 +23,8 @@ def _compile_kernel_without_inplace():
@
tilelang
.
jit
(
@
tilelang
.
jit
(
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_STORAGE_REWRITE_DETECT_INPLACE
:
True
,
tilelang
.
PassConfigKey
.
TL_STORAGE_REWRITE_DETECT_INPLACE
:
True
,
},)
},
)
def
_compile_kernel_with_inplace
():
def
_compile_kernel_with_inplace
():
num_tokens
=
T
.
symbolic
(
"num_tokens"
)
num_tokens
=
T
.
symbolic
(
"num_tokens"
)
...
...
testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py
View file @
29051439
...
@@ -26,9 +26,9 @@ def matmul(
...
@@ -26,9 +26,9 @@ def matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -88,7 +88,8 @@ def run_gemm(
...
@@ -88,7 +88,8 @@ def run_gemm(
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
disable_warp_specialized
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
disable_warp_specialized
,
})
},
)
profiler
=
kernel
.
get_profiler
()
profiler
=
kernel
.
get_profiler
()
def
ref_program
(
A
,
B
):
def
ref_program
(
A
,
B
):
...
...
testing/python/cpu/test_tilelang_cpu_gemm.py
View file @
29051439
...
@@ -10,9 +10,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
...
@@ -10,9 +10,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
@
T
.
prim_func
@
T
.
prim_func
def
matmul
(
def
matmul
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
is_cpu
=
True
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
is_cpu
=
True
)
as
(
bx
,
by
):
A_local
=
T
.
alloc_local
((
block_M
,
block_K
),
dtype
)
A_local
=
T
.
alloc_local
((
block_M
,
block_K
),
dtype
)
...
@@ -31,7 +31,6 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
...
@@ -31,7 +31,6 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
# )
# )
for
ko
in
T
.
Pipelined
(
K
//
block_K
,
num_stages
=
num_stages
):
for
ko
in
T
.
Pipelined
(
K
//
block_K
,
num_stages
=
num_stages
):
T
.
copy
(
A
[
by
*
block_M
,
ko
*
block_K
],
A_local
)
T
.
copy
(
A
[
by
*
block_M
,
ko
*
block_K
],
A_local
)
# Or Copy with Parallel
# Or Copy with Parallel
...
@@ -62,14 +61,13 @@ def test_matmul_codegen():
...
@@ -62,14 +61,13 @@ def test_matmul_codegen():
def
test_matmul_compile
():
def
test_matmul_compile
():
def
matmul_jit_test
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul_jit_test
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
# a simple kernel just for jit test
# a simple kernel just for jit test
@
T
.
prim_func
@
T
.
prim_func
def
matmul
(
def
matmul
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
is_cpu
=
True
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
is_cpu
=
True
)
as
(
bx
,
by
):
A_local
=
T
.
alloc_local
((
block_M
,
block_K
),
dtype
)
A_local
=
T
.
alloc_local
((
block_M
,
block_K
),
dtype
)
...
...
testing/python/debug/test_device_assert.py
View file @
29051439
...
@@ -7,7 +7,6 @@ import tilelang.language as T
...
@@ -7,7 +7,6 @@ import tilelang.language as T
# TODO(dyq) It intentionally triggers a device-side assert so we can't include this in CI
# TODO(dyq) It intentionally triggers a device-side assert so we can't include this in CI
# Please run manually when you want to verify that device_assert actually traps on GPU.
# Please run manually when you want to verify that device_assert actually traps on GPU.
def
_manual_device_assert_triggered
():
def
_manual_device_assert_triggered
():
@
T
.
prim_func
@
T
.
prim_func
def
program
():
def
program
():
with
T
.
Kernel
(
threads
=
128
):
with
T
.
Kernel
(
threads
=
128
):
...
@@ -20,7 +19,6 @@ def _manual_device_assert_triggered():
...
@@ -20,7 +19,6 @@ def _manual_device_assert_triggered():
def
test_device_assert_no_trigger
():
def
test_device_assert_no_trigger
():
@
T
.
prim_func
@
T
.
prim_func
def
program
():
def
program
():
with
T
.
Kernel
(
threads
=
128
):
with
T
.
Kernel
(
threads
=
128
):
...
...
testing/python/debug/test_tilelang_debug_print.py
View file @
29051439
...
@@ -6,7 +6,6 @@ import tilelang.language as T
...
@@ -6,7 +6,6 @@ import tilelang.language as T
def
debug_print_buffer
(
M
=
16
,
N
=
16
,
dtype
=
"float16"
):
def
debug_print_buffer
(
M
=
16
,
N
=
16
,
dtype
=
"float16"
):
@
T
.
prim_func
@
T
.
prim_func
def
program
(
Q
:
T
.
Tensor
((
M
,
N
),
dtype
)):
def
program
(
Q
:
T
.
Tensor
((
M
,
N
),
dtype
)):
with
T
.
Kernel
(
4
,
4
,
2
,
threads
=
128
*
2
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
4
,
4
,
2
,
threads
=
128
*
2
)
as
(
bx
,
by
,
bz
):
...
@@ -19,24 +18,24 @@ def debug_print_buffer(M=16, N=16, dtype="float16"):
...
@@ -19,24 +18,24 @@ def debug_print_buffer(M=16, N=16, dtype="float16"):
def
test_debug_print_buffer
():
def
test_debug_print_buffer
():
debug_print_buffer
(
dtype
=
'
bool
'
)
debug_print_buffer
(
dtype
=
"
bool
"
)
debug_print_buffer
(
dtype
=
'
int8
'
)
debug_print_buffer
(
dtype
=
"
int8
"
)
debug_print_buffer
(
dtype
=
'
int16
'
)
debug_print_buffer
(
dtype
=
"
int16
"
)
debug_print_buffer
(
dtype
=
'
int32
'
)
debug_print_buffer
(
dtype
=
"
int32
"
)
debug_print_buffer
(
dtype
=
'
int64
'
)
debug_print_buffer
(
dtype
=
"
int64
"
)
debug_print_buffer
(
dtype
=
'
uint8
'
)
debug_print_buffer
(
dtype
=
"
uint8
"
)
debug_print_buffer
(
dtype
=
'
uint16
'
)
debug_print_buffer
(
dtype
=
"
uint16
"
)
debug_print_buffer
(
dtype
=
'
uint32
'
)
debug_print_buffer
(
dtype
=
"
uint32
"
)
debug_print_buffer
(
dtype
=
'
uint64
'
)
debug_print_buffer
(
dtype
=
"
uint64
"
)
debug_print_buffer
(
dtype
=
'
float16
'
)
debug_print_buffer
(
dtype
=
"
float16
"
)
debug_print_buffer
(
dtype
=
'
float32
'
)
debug_print_buffer
(
dtype
=
"
float32
"
)
debug_print_buffer
(
dtype
=
'
float64
'
)
debug_print_buffer
(
dtype
=
"
float64
"
)
debug_print_buffer
(
dtype
=
'
bfloat16
'
)
debug_print_buffer
(
dtype
=
"
bfloat16
"
)
debug_print_buffer
(
dtype
=
'
float8_e4m3
'
)
debug_print_buffer
(
dtype
=
"
float8_e4m3
"
)
debug_print_buffer
(
dtype
=
'
float8_e4m3fn
'
)
debug_print_buffer
(
dtype
=
"
float8_e4m3fn
"
)
debug_print_buffer
(
dtype
=
'
float8_e4m3fnuz
'
)
debug_print_buffer
(
dtype
=
"
float8_e4m3fnuz
"
)
debug_print_buffer
(
dtype
=
'
float8_e5m2
'
)
debug_print_buffer
(
dtype
=
"
float8_e5m2
"
)
debug_print_buffer
(
dtype
=
'
float8_e5m2fnuz
'
)
debug_print_buffer
(
dtype
=
"
float8_e5m2fnuz
"
)
def
debug_print_buffer_conditional
(
M
=
16
,
N
=
16
):
def
debug_print_buffer_conditional
(
M
=
16
,
N
=
16
):
...
...
testing/python/dynamic/test_tilelang_dynamic_symbolic.py
View file @
29051439
...
@@ -5,7 +5,7 @@ import tilelang.testing
...
@@ -5,7 +5,7 @@ import tilelang.testing
from
tvm
import
DataType
from
tvm
import
DataType
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tilelang.intrinsics.utils
import
get_swizzle_layout
from
tilelang.intrinsics.utils
import
get_swizzle_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
)
from
tilelang.intrinsics.mma_macro_generator
import
TensorCoreIntrinEmitter
tilelang
.
testing
.
set_random_seed
(
0
)
tilelang
.
testing
.
set_random_seed
(
0
)
...
@@ -96,12 +96,11 @@ def tl_matmul_macro(
...
@@ -96,12 +96,11 @@ def tl_matmul_macro(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
...
@@ -109,10 +108,12 @@ def tl_matmul_macro(
...
@@ -109,10 +108,12 @@ def tl_matmul_macro(
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size
),
in_dtype
)
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size
),
accum_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
A_shared
:
make_swizzle_layout
(
A_shared
),
{
B_shared
:
make_swizzle_layout
(
B_shared
),
A_shared
:
make_swizzle_layout
(
A_shared
),
})
B_shared
:
make_swizzle_layout
(
B_shared
),
}
)
# Improve L2 Cache
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
)
T
.
use_swizzle
(
panel_size
=
10
)
...
@@ -120,7 +121,6 @@ def tl_matmul_macro(
...
@@ -120,7 +121,6 @@ def tl_matmul_macro(
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
for
ko
in
T
.
Pipelined
((
K
//
block_K
),
num_stages
=
stage
):
# Load A into shared memory
# Load A into shared memory
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
for
i
,
k
in
T
.
Parallel
(
block_M
,
block_K
):
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
A_shared
[
i
,
k
]
=
A
[
by
*
block_M
+
i
,
ko
*
block_K
+
k
]
...
@@ -130,7 +130,6 @@ def tl_matmul_macro(
...
@@ -130,7 +130,6 @@ def tl_matmul_macro(
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
B_shared
[
j
,
k
]
=
B
[
bx
*
block_N
+
j
,
ko
*
block_K
+
k
]
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
mma_emitter
.
ldmatrix_a
(
A_local
,
A_local
,
...
@@ -207,8 +206,7 @@ def tl_matmul_block(
...
@@ -207,8 +206,7 @@ def tl_matmul_block(
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
)):
(
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
...
@@ -306,8 +304,7 @@ def tl_matmul_block_all_dynamic(
...
@@ -306,8 +304,7 @@ def tl_matmul_block_all_dynamic(
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
)):
(
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
...
@@ -417,7 +414,7 @@ def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
...
@@ -417,7 +414,7 @@ def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
)
)
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_DYNAMIC_TAIL_SPLIT
:
dynamic_alignment
!=
0
,
tilelang
.
PassConfigKey
.
TL_DISABLE_DYNAMIC_TAIL_SPLIT
:
dynamic_alignment
!=
0
,
tilelang
.
PassConfigKey
.
TL_DYNAMIC_ALIGNMENT
:
dynamic_alignment
tilelang
.
PassConfigKey
.
TL_DYNAMIC_ALIGNMENT
:
dynamic_alignment
,
}
}
if
M
%
64
==
0
or
N
%
64
==
0
or
K
%
64
!=
0
:
if
M
%
64
==
0
or
N
%
64
==
0
or
K
%
64
!=
0
:
# workaround for hopper tma lower pass
# workaround for hopper tma lower pass
...
@@ -462,55 +459,31 @@ def test_assert_tl_matmul_macro():
...
@@ -462,55 +459,31 @@ def test_assert_tl_matmul_macro():
def
test_assert_tl_matmul_block
():
def
test_assert_tl_matmul_block
():
assert_tl_matmul_block_correctness
(
128
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
assert_tl_matmul_block_correctness
(
128
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
64
,
64
,
32
)
assert_tl_matmul_block_correctness
(
67
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
assert_tl_matmul_block_correctness
(
67
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
assert_tl_matmul_block_correctness
(
36
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
64
,
64
,
32
)
assert_tl_matmul_block_correctness
(
36
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
def
test_assert_tl_matmul_block_all_dynamic
():
def
test_assert_tl_matmul_block_all_dynamic
():
assert_tl_matmul_block_all_dynamic_correctness
(
128
,
128
,
128
,
False
,
False
,
"float16"
,
assert_tl_matmul_block_all_dynamic_correctness
(
128
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
"float16"
,
"float16"
,
64
,
64
,
32
)
assert_tl_matmul_block_all_dynamic_correctness
(
67
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
assert_tl_matmul_block_all_dynamic_correctness
(
67
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
assert_tl_matmul_block_all_dynamic_correctness
(
36
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
"float16"
,
64
,
64
,
32
)
assert_tl_matmul_block_all_dynamic_correctness
(
36
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
)
def
test_assert_tl_matmul_block_all_dynamic_with_pass_config
():
def
test_assert_tl_matmul_block_all_dynamic_with_pass_config
():
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config
(
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config
(
128
,
128
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
8
128
,
)
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
8
)
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config
(
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config
(
64
,
64
,
128
,
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
8
128
,
)
128
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
8
)
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config
(
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config
(
64
,
128
,
60
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
4
)
64
,
128
,
60
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
4
)
# Tail split is enabled with dynamic alignment 0
# Tail split is enabled with dynamic alignment 0
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config
(
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config
(
64
,
128
,
64
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
0
)
64
,
128
,
64
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
64
,
64
,
32
,
dynamic_alignment
=
0
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py
View file @
29051439
...
@@ -25,10 +25,8 @@ def tl_matmul_block_static(
...
@@ -25,10 +25,8 @@ def tl_matmul_block_static(
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
)):
(
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -137,10 +135,8 @@ def tl_matmul_block_dynamic_m(
...
@@ -137,10 +135,8 @@ def tl_matmul_block_dynamic_m(
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
)):
(
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -247,10 +243,8 @@ def tl_matmul_block_dynamic_mn(
...
@@ -247,10 +243,8 @@ def tl_matmul_block_dynamic_mn(
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
)):
(
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -357,10 +351,8 @@ def tl_matmul_block_dynamic_mnk(
...
@@ -357,10 +351,8 @@ def tl_matmul_block_dynamic_mnk(
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
)):
(
M
,
N
),
out_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
num_threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -445,8 +437,7 @@ def assert_tl_matmul_block_dynamic_mnk(
...
@@ -445,8 +437,7 @@ def assert_tl_matmul_block_dynamic_mnk(
def
run_assert_tl_matmul_block_static
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
):
def
run_assert_tl_matmul_block_static
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
):
assert_tl_matmul_block_static
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
False
,
False
,
"float16"
,
assert_tl_matmul_block_static
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
False
,
False
,
"float16"
,
"float16"
,
"float32"
)
"float16"
,
"float32"
)
def
run_assert_tl_matmul_block_dynamic_m
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
):
def
run_assert_tl_matmul_block_dynamic_m
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
):
...
@@ -462,10 +453,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
...
@@ -462,10 +453,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
"float16"
,
"float16"
,
"float16"
,
"float16"
,
"float32"
,
"float32"
,
pass_configs
=
{
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
True
,
"tl.dynamic_alignment"
:
8
},
"tl.disable_dynamic_tail_split"
:
True
,
)
"tl.dynamic_alignment"
:
8
})
assert_tl_matmul_block_dynamic_m
(
assert_tl_matmul_block_dynamic_m
(
M
,
M
,
N
,
N
,
...
@@ -478,7 +467,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
...
@@ -478,7 +467,8 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
"float16"
,
"float16"
,
"float16"
,
"float16"
,
"float32"
,
"float32"
,
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
False
})
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
False
},
)
def
run_assert_tl_matmul_block_dynamic_mn
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
):
def
run_assert_tl_matmul_block_dynamic_mn
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
):
...
@@ -494,10 +484,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
...
@@ -494,10 +484,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
"float16"
,
"float16"
,
"float16"
,
"float16"
,
"float32"
,
"float32"
,
pass_configs
=
{
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
True
,
"tl.dynamic_alignment"
:
8
},
"tl.disable_dynamic_tail_split"
:
True
,
)
"tl.dynamic_alignment"
:
8
})
assert_tl_matmul_block_dynamic_mn
(
assert_tl_matmul_block_dynamic_mn
(
M
,
M
,
N
,
N
,
...
@@ -510,7 +498,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
...
@@ -510,7 +498,8 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
"float16"
,
"float16"
,
"float16"
,
"float16"
,
"float32"
,
"float32"
,
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
False
})
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
False
},
)
def
run_assert_tl_matmul_block_dynamic_mnk
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
):
def
run_assert_tl_matmul_block_dynamic_mnk
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
):
...
@@ -526,10 +515,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
...
@@ -526,10 +515,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
"float16"
,
"float16"
,
"float16"
,
"float16"
,
"float32"
,
"float32"
,
pass_configs
=
{
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
True
,
"tl.dynamic_alignment"
:
4
},
"tl.disable_dynamic_tail_split"
:
True
,
)
"tl.dynamic_alignment"
:
4
})
assert_tl_matmul_block_dynamic_mnk
(
assert_tl_matmul_block_dynamic_mnk
(
M
,
M
,
N
,
N
,
...
@@ -542,7 +529,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
...
@@ -542,7 +529,8 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
"float16"
,
"float16"
,
"float16"
,
"float16"
,
"float32"
,
"float32"
,
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
False
})
pass_configs
=
{
"tl.disable_dynamic_tail_split"
:
False
},
)
def
test_all
():
def
test_all
():
...
...
testing/python/fastmath/test_mathops_fastmath.py
View file @
29051439
...
@@ -7,16 +7,16 @@ import re
...
@@ -7,16 +7,16 @@ import re
def
get_mathop_lines
(
source
,
mathop_name
):
def
get_mathop_lines
(
source
,
mathop_name
):
"""Extract lines containing the mathop from CUDA source for debugging"""
"""Extract lines containing the mathop from CUDA source for debugging"""
lines
=
source
.
split
(
'
\n
'
)
lines
=
source
.
split
(
"
\n
"
)
relevant_lines
=
[]
relevant_lines
=
[]
for
i
,
line
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
mathop_name
in
line
and
(
'('
in
line
):
if
mathop_name
in
line
and
(
"("
in
line
):
# Include some context
# Include some context
start
=
max
(
0
,
i
-
1
)
start
=
max
(
0
,
i
-
1
)
end
=
min
(
len
(
lines
),
i
+
2
)
end
=
min
(
len
(
lines
),
i
+
2
)
relevant_lines
.
extend
([
f
"
{
j
}
:
{
lines
[
j
]
}
"
for
j
in
range
(
start
,
end
)])
relevant_lines
.
extend
([
f
"
{
j
}
:
{
lines
[
j
]
}
"
for
j
in
range
(
start
,
end
)])
relevant_lines
.
append
(
"---"
)
relevant_lines
.
append
(
"---"
)
return
'
\n
'
.
join
(
relevant_lines
[
-
10
:])
# Show last 10 lines to avoid too much output
return
"
\n
"
.
join
(
relevant_lines
[
-
10
:])
# Show last 10 lines to avoid too much output
def
check_fastmath_usage
(
source
,
mathop_name
,
expect_fastmath
=
False
):
def
check_fastmath_usage
(
source
,
mathop_name
,
expect_fastmath
=
False
):
...
@@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
...
@@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
fastmath_matches
=
re
.
findall
(
fastmath_pattern
,
source
)
fastmath_matches
=
re
.
findall
(
fastmath_pattern
,
source
)
non_fastmath_matches
=
re
.
findall
(
non_fastmath_pattern
,
source
)
non_fastmath_matches
=
re
.
findall
(
non_fastmath_pattern
,
source
)
print
(
print
(
f
"Found
{
len
(
fastmath_matches
)
}
fastmath calls,
{
len
(
non_fastmath_matches
)
}
non-fastmath calls"
)
f
"Found
{
len
(
fastmath_matches
)
}
fastmath calls,
{
len
(
non_fastmath_matches
)
}
non-fastmath calls"
)
if
len
(
fastmath_matches
)
>
0
:
if
len
(
fastmath_matches
)
>
0
:
print
(
f
"Fastmath calls found:
{
fastmath_matches
}
"
)
print
(
f
"Fastmath calls found:
{
fastmath_matches
}
"
)
if
len
(
non_fastmath_matches
)
>
0
:
if
len
(
non_fastmath_matches
)
>
0
:
...
@@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name):
...
@@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name):
check_fastmath_usage
(
source
,
mathop_name
,
expect_fastmath
=
False
)
check_fastmath_usage
(
source
,
mathop_name
,
expect_fastmath
=
False
)
def
run_single_arg_mathop_test
(
mathop_name
,
def
run_single_arg_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
"""
"""
Test single-argument mathops.
Test single-argument mathops.
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
...
@@ -65,13 +57,12 @@ def run_single_arg_mathop_test(mathop_name,
...
@@ -65,13 +57,12 @@ def run_single_arg_mathop_test(mathop_name,
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
])
bx
*
block_N
+
j
])
# Test with FAST_MATH disabled
# Test with FAST_MATH disabled
kernel_no_fastmath
=
tilelang
.
compile
(
kernel_no_fastmath
=
tilelang
.
compile
(
...
@@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name,
...
@@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name,
target
=
"cuda"
,
target
=
"cuda"
,
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
False
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
False
,
})
},
)
source_no_fastmath
=
kernel_no_fastmath
.
get_kernel_source
()
source_no_fastmath
=
kernel_no_fastmath
.
get_kernel_source
()
...
@@ -93,28 +85,22 @@ def run_single_arg_mathop_test(mathop_name,
...
@@ -93,28 +85,22 @@ def run_single_arg_mathop_test(mathop_name,
print
(
f
"✓
{
mathop_name
}
compilation and execution test passed"
)
print
(
f
"✓
{
mathop_name
}
compilation and execution test passed"
)
def
run_two_arg_mathop_test
(
mathop_name
,
def
run_two_arg_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
"""
"""
Test two-argument mathops to ensure they generate non-fastmath CUDA code.
Test two-argument mathops to ensure they generate non-fastmath CUDA code.
"""
"""
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
C
[
by
*
block_M
+
i
,
C
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
,
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
,
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
)
)
# Test with FAST_MATH disabled
# Test with FAST_MATH disabled
kernel_no_fastmath
=
tilelang
.
compile
(
kernel_no_fastmath
=
tilelang
.
compile
(
...
@@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name,
...
@@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name,
target
=
"cuda"
,
target
=
"cuda"
,
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
False
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
False
,
})
},
)
# Test with FAST_MATH enabled
# Test with FAST_MATH enabled
kernel_fastmath
=
tilelang
.
compile
(
kernel_fastmath
=
tilelang
.
compile
(
...
@@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name,
...
@@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name,
target
=
"cuda"
,
target
=
"cuda"
,
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
source_no_fastmath
=
kernel_no_fastmath
.
get_kernel_source
()
source_no_fastmath
=
kernel_no_fastmath
.
get_kernel_source
()
source_fastmath
=
kernel_fastmath
.
get_kernel_source
()
source_fastmath
=
kernel_fastmath
.
get_kernel_source
()
...
@@ -171,8 +159,8 @@ def run_abs_test():
...
@@ -171,8 +159,8 @@ def run_abs_test():
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
A
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
),
B
:
T
.
Tensor
((
M
,
N
),
"float32"
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
...
@@ -184,7 +172,8 @@ def run_abs_test():
...
@@ -184,7 +172,8 @@ def run_abs_test():
target
=
"cuda"
,
target
=
"cuda"
,
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
False
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
False
,
})
},
)
source
=
kernel
.
get_kernel_source
()
source
=
kernel
.
get_kernel_source
()
print
(
"
\n
=== Testing abs (maps to fabs) ==="
)
print
(
"
\n
=== Testing abs (maps to fabs) ==="
)
...
@@ -199,26 +188,19 @@ def run_abs_test():
...
@@ -199,26 +188,19 @@ def run_abs_test():
print
(
"✓ abs numerical test passed"
)
print
(
"✓ abs numerical test passed"
)
def
run_fastmath_mathop_test
(
mathop_name
,
def
run_fastmath_mathop_test
(
mathop_name
,
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
mathop_func
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
"""
"""
Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
"""
"""
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
mathop_func
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
])
bx
*
block_N
+
j
])
# Test with FAST_MATH enabled
# Test with FAST_MATH enabled
kernel_fastmath
=
tilelang
.
compile
(
kernel_fastmath
=
tilelang
.
compile
(
...
@@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name,
...
@@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name,
target
=
"cuda"
,
target
=
"cuda"
,
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
source_fastmath
=
kernel_fastmath
.
get_kernel_source
()
source_fastmath
=
kernel_fastmath
.
get_kernel_source
()
print
(
f
"
\n
=== Testing
{
mathop_name
}
(fastmath version) ==="
)
print
(
f
"
\n
=== Testing
{
mathop_name
}
(fastmath version) ==="
)
print
(
"FAST_MATH=True:"
)
print
(
"FAST_MATH=True:"
)
# Strip the __ prefix for checking in the CUDA source
# Strip the __ prefix for checking in the CUDA source
cuda_mathop_name
=
mathop_name
.
lstrip
(
'_'
)
cuda_mathop_name
=
mathop_name
.
lstrip
(
"_"
)
check_fastmath_usage
(
source_fastmath
,
cuda_mathop_name
,
expect_fastmath
=
True
)
check_fastmath_usage
(
source_fastmath
,
cuda_mathop_name
,
expect_fastmath
=
True
)
# Test numerical correctness
# Test numerical correctness
...
...
testing/python/issue/test_tilelang_issue_1001.py
View file @
29051439
...
@@ -8,14 +8,15 @@ from tilelang import language as T
...
@@ -8,14 +8,15 @@ from tilelang import language as T
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
},)
},
)
def
_cumsum_view_infer_layout
(
hidden
):
def
_cumsum_view_infer_layout
(
hidden
):
num_tokens
=
T
.
dynamic
(
'
num_tokens
'
)
num_tokens
=
T
.
dynamic
(
"
num_tokens
"
)
@
T
.
prim_func
@
T
.
prim_func
def
buggy_kernel
(
x
:
T
.
Tensor
[(
num_tokens
,
hidden
),
'
float
'
]):
def
buggy_kernel
(
x
:
T
.
Tensor
[(
num_tokens
,
hidden
),
"
float
"
]):
with
T
.
Kernel
(
num_tokens
,
threads
=
128
)
as
pid
:
with
T
.
Kernel
(
num_tokens
,
threads
=
128
)
as
pid
:
smem
=
T
.
alloc_shared
((
hidden
,),
dtype
=
'
float
'
)
smem
=
T
.
alloc_shared
((
hidden
,),
dtype
=
"
float
"
)
T
.
copy
(
x
[
pid
,
:],
smem
)
T
.
copy
(
x
[
pid
,
:],
smem
)
T
.
cumsum
(
T
.
view
(
smem
,
(
1
,
hidden
)),
dim
=
1
)
T
.
cumsum
(
T
.
view
(
smem
,
(
1
,
hidden
)),
dim
=
1
)
...
@@ -24,10 +25,10 @@ def _cumsum_view_infer_layout(hidden):
...
@@ -24,10 +25,10 @@ def _cumsum_view_infer_layout(hidden):
def
test_cumsum_view_infer_layout
():
def
test_cumsum_view_infer_layout
():
hidden
=
128
hidden
=
128
x
=
torch
.
randn
(
1
,
hidden
,
device
=
'
cuda
'
,
dtype
=
torch
.
float
)
x
=
torch
.
randn
(
1
,
hidden
,
device
=
"
cuda
"
,
dtype
=
torch
.
float
)
kernel
=
_cumsum_view_infer_layout
(
hidden
)
kernel
=
_cumsum_view_infer_layout
(
hidden
)
kernel
(
x
)
kernel
(
x
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
tilelang
.
testing
.
main
()
testing/python/issue/test_tilelang_issue_1008.py
View file @
29051439
...
@@ -8,12 +8,13 @@ from tilelang import language as T
...
@@ -8,12 +8,13 @@ from tilelang import language as T
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
},)
},
)
def
_fill_with_static_region_kernel
():
def
_fill_with_static_region_kernel
():
num_tokens
=
T
.
symbolic
(
'
num_tokens
'
)
num_tokens
=
T
.
symbolic
(
"
num_tokens
"
)
@
T
.
prim_func
@
T
.
prim_func
def
buggy_kernel
(
x
:
T
.
Tensor
[(
num_tokens
,),
'
int64
'
]):
# noqa: F821
def
buggy_kernel
(
x
:
T
.
Tensor
[(
num_tokens
,),
"
int64
"
]):
# noqa: F821
with
T
.
Kernel
(
num_tokens
,
threads
=
128
)
as
_
:
with
T
.
Kernel
(
num_tokens
,
threads
=
128
)
as
_
:
T
.
fill
(
x
[
0
:
128
],
0
)
T
.
fill
(
x
[
0
:
128
],
0
)
...
@@ -24,14 +25,15 @@ def _fill_with_static_region_kernel():
...
@@ -24,14 +25,15 @@ def _fill_with_static_region_kernel():
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
},)
},
)
def
_fill_with_dynamic_region_kernel
():
def
_fill_with_dynamic_region_kernel
():
num_tokens
=
T
.
symbolic
(
'
num_tokens
'
)
num_tokens
=
T
.
symbolic
(
"
num_tokens
"
)
@
T
.
prim_func
@
T
.
prim_func
def
buggy_kernel
(
x
:
T
.
Tensor
[(
num_tokens
,),
'
int64
'
]):
# noqa: F821
def
buggy_kernel
(
x
:
T
.
Tensor
[(
num_tokens
,),
"
int64
"
]):
# noqa: F821
with
T
.
Kernel
(
num_tokens
,
threads
=
128
)
as
_
:
with
T
.
Kernel
(
num_tokens
,
threads
=
128
)
as
_
:
a
,
b
=
T
.
alloc_var
(
'
int
'
),
T
.
alloc_var
(
'
int
'
)
a
,
b
=
T
.
alloc_var
(
"
int
"
),
T
.
alloc_var
(
"
int
"
)
T
.
fill
(
x
[
a
:
b
],
0
)
T
.
fill
(
x
[
a
:
b
],
0
)
return
buggy_kernel
return
buggy_kernel
...
@@ -39,15 +41,15 @@ def _fill_with_dynamic_region_kernel():
...
@@ -39,15 +41,15 @@ def _fill_with_dynamic_region_kernel():
def
test_fill_with_static_region_kernel
():
def
test_fill_with_static_region_kernel
():
kernel
=
_fill_with_static_region_kernel
()
kernel
=
_fill_with_static_region_kernel
()
x
=
torch
.
zeros
((
256
,),
dtype
=
torch
.
int64
,
device
=
'
cuda
'
)
x
=
torch
.
zeros
((
256
,),
dtype
=
torch
.
int64
,
device
=
"
cuda
"
)
kernel
(
x
)
kernel
(
x
)
def
test_fill_with_dynamic_region_kernel
():
def
test_fill_with_dynamic_region_kernel
():
kernel
=
_fill_with_dynamic_region_kernel
()
kernel
=
_fill_with_dynamic_region_kernel
()
x
=
torch
.
zeros
((
256
,),
dtype
=
torch
.
int64
,
device
=
'
cuda
'
)
x
=
torch
.
zeros
((
256
,),
dtype
=
torch
.
int64
,
device
=
"
cuda
"
)
kernel
(
x
)
kernel
(
x
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
tilelang
.
testing
.
main
()
testing/python/issue/test_tilelang_issue_1115.py
View file @
29051439
...
@@ -4,25 +4,23 @@ import tilelang.language as T
...
@@ -4,25 +4,23 @@ import tilelang.language as T
def
test_int64_address
():
def
test_int64_address
():
@
tilelang
.
jit
@
tilelang
.
jit
def
set_cache_kernel
(
def
set_cache_kernel
(
S
,
S
,
D
,
D
,
pos_ty
=
'
int64
'
,
pos_ty
=
"
int64
"
,
dtype
=
"float32"
,
dtype
=
"float32"
,
):
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
pos
:
T
pos
:
T
.
Tensor
(
.
Tensor
(
[
[
S
,
S
,
],
pos_ty
],
pos_ty
,
),
# type: ignore `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32`
),
# type: ignore `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32`
value
:
T
.
Tensor
([
S
,
D
],
dtype
),
# type: ignore
value
:
T
.
Tensor
([
S
,
D
],
dtype
),
# type: ignore
cache
:
T
.
Tensor
([
S
,
D
],
dtype
),
# type: ignore
cache
:
T
.
Tensor
([
S
,
D
],
dtype
),
# type: ignore
):
):
with
T
.
Kernel
(
S
,
threads
=
128
)
as
bx
:
with
T
.
Kernel
(
S
,
threads
=
128
)
as
bx
:
slot
=
pos
[
bx
]
slot
=
pos
[
bx
]
...
@@ -34,11 +32,11 @@ def test_int64_address():
...
@@ -34,11 +32,11 @@ def test_int64_address():
D
=
2
D
=
2
S
=
10
S
=
10
cache
=
torch
.
rand
((
S
,
D
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
cache
=
torch
.
rand
((
S
,
D
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
value
=
torch
.
rand
((
S
,
D
),
device
=
'
cuda
'
,
dtype
=
torch
.
float32
)
value
=
torch
.
rand
((
S
,
D
),
device
=
"
cuda
"
,
dtype
=
torch
.
float32
)
pos_int64
=
torch
.
arange
(
S
,
device
=
'
cuda
'
,
dtype
=
torch
.
int64
)
pos_int64
=
torch
.
arange
(
S
,
device
=
"
cuda
"
,
dtype
=
torch
.
int64
)
pos_int32
=
torch
.
arange
(
S
,
device
=
'
cuda
'
,
dtype
=
torch
.
int32
)
pos_int32
=
torch
.
arange
(
S
,
device
=
"
cuda
"
,
dtype
=
torch
.
int32
)
kernel_int64
=
set_cache_kernel
(
S
,
D
,
'
int64
'
)
kernel_int64
=
set_cache_kernel
(
S
,
D
,
"
int64
"
)
kernel_int32
=
set_cache_kernel
(
S
,
D
,
'
int32
'
)
kernel_int32
=
set_cache_kernel
(
S
,
D
,
"
int32
"
)
kernel_int64
(
pos_int64
,
value
,
cache
)
kernel_int64
(
pos_int64
,
value
,
cache
)
torch
.
testing
.
assert_close
(
cache
,
value
)
torch
.
testing
.
assert_close
(
cache
,
value
)
kernel_int32
(
pos_int32
,
value
,
cache
)
kernel_int32
(
pos_int32
,
value
,
cache
)
...
...
testing/python/issue/test_tilelang_issue_1198.py
View file @
29051439
...
@@ -3,13 +3,17 @@ import tilelang.language as T
...
@@ -3,13 +3,17 @@ import tilelang.language as T
def
test_issue_1198
():
def
test_issue_1198
():
@
T
.
prim_func
@
T
.
prim_func
def
foo
(
x
:
T
.
Buffer
([
def
foo
(
32
,
x
:
T
.
Buffer
(
],
"int32"
)):
[
32
,
],
"int32"
,
),
):
pass
pass
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
tilelang
.
testing
.
main
()
tilelang
.
testing
.
main
()
testing/python/issue/test_tilelang_issue_814.py
View file @
29051439
...
@@ -6,11 +6,10 @@ import torch
...
@@ -6,11 +6,10 @@ import torch
@
tilelang
.
jit
@
tilelang
.
jit
def
_tmp_var_kernel
(
N
,
block_N
,
dtype
=
"float"
):
def
_tmp_var_kernel
(
N
,
block_N
,
dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
kernel
(
def
kernel
(
A
:
T
.
Tensor
((
N
,),
dtype
),
A
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
B
:
T
.
Tensor
((
N
,),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
128
)
as
bx
:
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
128
)
as
bx
:
for
i
in
T
.
Parallel
(
block_N
):
for
i
in
T
.
Parallel
(
block_N
):
...
...
testing/python/issue/test_tilelang_issue_830.py
View file @
29051439
...
@@ -8,7 +8,6 @@ import tilelang.language as T
...
@@ -8,7 +8,6 @@ import tilelang.language as T
@
tilelang
.
jit
@
tilelang
.
jit
def
_empty_kernel
():
def
_empty_kernel
():
@
T
.
prim_func
@
T
.
prim_func
def
empty_kernel
():
def
empty_kernel
():
with
T
.
Kernel
(
1
,
threads
=
32
)
as
thread_idx
:
with
T
.
Kernel
(
1
,
threads
=
32
)
as
thread_idx
:
...
@@ -51,7 +50,6 @@ def test_empty_with_dead_code_kernel():
...
@@ -51,7 +50,6 @@ def test_empty_with_dead_code_kernel():
@
tilelang
.
jit
@
tilelang
.
jit
def
_empty_kernel_with_binding_variants
(
use_tuple_binding
:
bool
=
False
):
def
_empty_kernel_with_binding_variants
(
use_tuple_binding
:
bool
=
False
):
@
T
.
prim_func
@
T
.
prim_func
def
kernel_with_tuple_kernel_binding
():
def
kernel_with_tuple_kernel_binding
():
with
T
.
Kernel
(
1
,
threads
=
32
)
as
(
pid
,):
with
T
.
Kernel
(
1
,
threads
=
32
)
as
(
pid
,):
...
...
testing/python/issue/test_tilelang_issue_96.py
View file @
29051439
...
@@ -5,18 +5,16 @@ import torch
...
@@ -5,18 +5,16 @@ import torch
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
B
:
T
.
Tensor
((
N
,
K
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
bx
,
by
,
by
,
):
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_N
,
block_K
),
dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
testing/python/issue/test_tilelang_issue_merge_if.py
View file @
29051439
...
@@ -6,7 +6,6 @@ import tilelang.language as T
...
@@ -6,7 +6,6 @@ import tilelang.language as T
def
merge_if_test
():
def
merge_if_test
():
@
T
.
prim_func
@
T
.
prim_func
def
main
():
def
main
():
A
=
T
.
alloc_fragment
((
1
,),
"float16"
)
A
=
T
.
alloc_fragment
((
1
,),
"float16"
)
...
...
testing/python/jit/test_tilelang_jit_callback.py
View file @
29051439
...
@@ -29,9 +29,9 @@ def matmul(
...
@@ -29,9 +29,9 @@ def matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -141,9 +141,9 @@ def matmu_jit_kernel(
...
@@ -141,9 +141,9 @@ def matmu_jit_kernel(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -208,6 +208,7 @@ def run_gemm_jit_kernel(
...
@@ -208,6 +208,7 @@ def run_gemm_jit_kernel(
def
ref_program
(
A
,
B
):
def
ref_program
(
A
,
B
):
import
torch
import
torch
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
C
.
to
(
torch
.
__getattribute__
(
out_dtype
))
C
=
C
.
to
(
torch
.
__getattribute__
(
out_dtype
))
return
C
return
C
...
...
testing/python/jit/test_tilelang_jit_gemm.py
View file @
29051439
...
@@ -31,9 +31,9 @@ def matmul_kernel_jit(
...
@@ -31,9 +31,9 @@ def matmul_kernel_jit(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -96,6 +96,7 @@ def run_gemm_kernel_jit(
...
@@ -96,6 +96,7 @@ def run_gemm_kernel_jit(
def
ref_program
(
A
,
B
):
def
ref_program
(
A
,
B
):
import
torch
import
torch
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
C
.
to
(
torch
.
__getattribute__
(
out_dtype
))
C
=
C
.
to
(
torch
.
__getattribute__
(
out_dtype
))
return
C
return
C
...
...
testing/python/jit/test_tilelang_jit_gemm_cython.py
View file @
29051439
...
@@ -28,9 +28,9 @@ def matmul(
...
@@ -28,9 +28,9 @@ def matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -138,9 +138,9 @@ def matmu_jit_kernel(
...
@@ -138,9 +138,9 @@ def matmu_jit_kernel(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -208,6 +208,7 @@ def run_gemm_jit_kernel(
...
@@ -208,6 +208,7 @@ def run_gemm_jit_kernel(
def
ref_program
(
A
,
B
):
def
ref_program
(
A
,
B
):
import
torch
import
torch
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
torch
.
matmul
(
A
.
to
(
torch
.
float
),
B
.
to
(
torch
.
float
))
C
=
C
.
to
(
out_dtype
)
C
=
C
.
to
(
out_dtype
)
return
C
return
C
...
@@ -235,19 +236,9 @@ def test_gemm_jit_kernel():
...
@@ -235,19 +236,9 @@ def test_gemm_jit_kernel():
)
)
def
run_cython_kernel_do_bench
(
M
,
def
run_cython_kernel_do_bench
(
N
,
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
K
,
):
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
program
=
matmul
(
program
=
matmul
(
M
,
M
,
N
,
N
,
...
@@ -287,23 +278,12 @@ def run_cython_kernel_do_bench(M,
...
@@ -287,23 +278,12 @@ def run_cython_kernel_do_bench(M,
def
test_cython_kernel_do_bench
():
def
test_cython_kernel_do_bench
():
run_cython_kernel_do_bench
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
run_cython_kernel_do_bench
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
256
,
32
,
2
)
def
run_cython_kernel_multi_stream
(
def
run_cython_kernel_multi_stream
(
M
,
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
N
,
):
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
program
=
matmul
(
program
=
matmul
(
M
,
M
,
N
,
N
,
...
@@ -342,23 +322,12 @@ def run_cython_kernel_multi_stream(M,
...
@@ -342,23 +322,12 @@ def run_cython_kernel_multi_stream(M,
def
test_cython_kernel_multi_stream
():
def
test_cython_kernel_multi_stream
():
run_cython_kernel_multi_stream
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
run_cython_kernel_multi_stream
(
512
,
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
128
,
256
,
32
,
2
)
def
run_cython_dynamic_shape
(
def
run_cython_dynamic_shape
(
M
,
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
N
,
):
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
program
=
matmul
(
program
=
matmul
(
M
,
M
,
N
,
N
,
...
@@ -398,36 +367,20 @@ def run_cython_dynamic_shape(M,
...
@@ -398,36 +367,20 @@ def run_cython_dynamic_shape(M,
matmul_kernel
(
tensor_a
,
tensor_b
,
tensor_c
)
matmul_kernel
(
tensor_a
,
tensor_b
,
tensor_c
)
tensor_ref_c
=
torch
.
matmul
(
tensor_a
.
to
(
torch
.
float
),
tensor_b
.
to
(
torch
.
float
)).
to
(
out_dtype
)
tensor_ref_c
=
torch
.
matmul
(
tensor_a
.
to
(
torch
.
float
),
tensor_b
.
to
(
torch
.
float
)).
to
(
out_dtype
)
tilelang
.
testing
.
torch_assert_close
(
tilelang
.
testing
.
torch_assert_close
(
tensor_c
,
tensor_ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
tensor_c
,
tensor_ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
def
test_cython_dynamic_shape
():
def
test_cython_dynamic_shape
():
run_cython_dynamic_shape
(
run_cython_dynamic_shape
(
T
.
dynamic
(
"m"
),
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
T
.
dynamic
(
"m"
),
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
run_cython_dynamic_shape
(
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
run_cython_dynamic_shape
(
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
run_cython_dynamic_shape
(
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
T
.
dynamic
(
"k"
),
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
256
,
32
,
2
)
run_cython_dynamic_shape
(
def
run_cython_dynamic_shape_with_out_idx
(
T
.
dynamic
(
"m"
),
T
.
dynamic
(
"n"
),
T
.
dynamic
(
"k"
),
False
,
False
,
"float16"
,
"float16"
,
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
"float16"
,
128
,
256
,
32
,
2
)
):
def
run_cython_dynamic_shape_with_out_idx
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
block_M
,
block_N
,
block_K
,
num_stages
=
3
,
num_threads
=
128
):
program
=
matmul
(
program
=
matmul
(
M
,
M
,
N
,
N
,
...
@@ -467,13 +420,11 @@ def run_cython_dynamic_shape_with_out_idx(M,
...
@@ -467,13 +420,11 @@ def run_cython_dynamic_shape_with_out_idx(M,
tensor_ref_c
=
torch
.
matmul
(
tensor_a
.
to
(
torch
.
float
),
tensor_b
.
to
(
torch
.
float
)).
to
(
out_dtype
)
tensor_ref_c
=
torch
.
matmul
(
tensor_a
.
to
(
torch
.
float
),
tensor_b
.
to
(
torch
.
float
)).
to
(
out_dtype
)
tilelang
.
testing
.
torch_assert_close
(
tilelang
.
testing
.
torch_assert_close
(
tensor_c
,
tensor_ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
tensor_c
,
tensor_ref_c
,
atol
=
1e-2
,
rtol
=
1e-2
,
max_mismatched_ratio
=
0.05
)
def
test_cython_dynamic_shape_with_out_idx
():
def
test_cython_dynamic_shape_with_out_idx
():
run_cython_dynamic_shape_with_out_idx
(
run_cython_dynamic_shape_with_out_idx
(
T
.
dynamic
(
"m"
),
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
T
.
dynamic
(
"m"
),
1024
,
768
,
False
,
False
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
)
def
matmul_int_variable
(
def
matmul_int_variable
(
...
@@ -498,10 +449,10 @@ def matmul_int_variable(
...
@@ -498,10 +449,10 @@ def matmul_int_variable(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
offset
:
T
.
int32
,
offset
:
T
.
int32
,
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -525,10 +476,10 @@ def matmul_int_variable(
...
@@ -525,10 +476,10 @@ def matmul_int_variable(
return
main
return
main
def
run_matmul_int_variable
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
def
run_matmul_int_variable
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
num_stages
,
threads
):
out_dtype
,
dtypeAccum
,
num_stages
,
threads
):
program
=
matmul_int_variable
(
program
=
matmul_int_variable
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
num_stages
,
threads
out_dtype
,
dtypeAccum
,
num_stages
,
threads
)
)
matmul_kernel
=
tilelang
.
compile
(
program
,
execution_backend
=
"cython"
,
out_idx
=
2
)
matmul_kernel
=
tilelang
.
compile
(
program
,
execution_backend
=
"cython"
,
out_idx
=
2
)
in_dtype
=
map_torch_type
(
in_dtype
)
in_dtype
=
map_torch_type
(
in_dtype
)
...
@@ -544,8 +495,7 @@ def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B
...
@@ -544,8 +495,7 @@ def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B
def
test_matmul_int_variable
():
def
test_matmul_int_variable
():
run_matmul_int_variable
(
1024
,
1024
,
1024
,
128
,
128
,
32
,
False
,
False
,
"float16"
,
"float16"
,
run_matmul_int_variable
(
1024
,
1024
,
1024
,
128
,
128
,
32
,
False
,
False
,
"float16"
,
"float16"
,
"float32"
,
0
,
128
)
"float32"
,
0
,
128
)
def
matmul_float_variable
(
def
matmul_float_variable
(
...
@@ -570,10 +520,10 @@ def matmul_float_variable(
...
@@ -570,10 +520,10 @@ def matmul_float_variable(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
offset
:
T
.
float32
,
offset
:
T
.
float32
,
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -597,10 +547,10 @@ def matmul_float_variable(
...
@@ -597,10 +547,10 @@ def matmul_float_variable(
return
main
return
main
def
run_matmul_float_variable
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
def
run_matmul_float_variable
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
num_stages
,
threads
):
out_dtype
,
dtypeAccum
,
num_stages
,
threads
):
program
=
matmul_float_variable
(
program
=
matmul_float_variable
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
trans_A
,
trans_B
,
in_dtype
,
out_dtype
,
dtypeAccum
,
num_stages
,
threads
out_dtype
,
dtypeAccum
,
num_stages
,
threads
)
)
matmul_kernel
=
tilelang
.
compile
(
program
,
execution_backend
=
"cython"
,
out_idx
=
2
)
matmul_kernel
=
tilelang
.
compile
(
program
,
execution_backend
=
"cython"
,
out_idx
=
2
)
in_dtype
=
map_torch_type
(
in_dtype
)
in_dtype
=
map_torch_type
(
in_dtype
)
...
@@ -616,8 +566,7 @@ def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans
...
@@ -616,8 +566,7 @@ def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans
def
test_matmul_float_variable
():
def
test_matmul_float_variable
():
run_matmul_float_variable
(
1024
,
1024
,
1024
,
128
,
128
,
32
,
False
,
False
,
"float16"
,
"float16"
,
run_matmul_float_variable
(
1024
,
1024
,
1024
,
128
,
128
,
32
,
False
,
False
,
"float16"
,
"float16"
,
"float32"
,
0
,
128
)
"float32"
,
0
,
128
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment