OpenDAS / tilelang · Commits

Commit 29051439 (unverified)
Authored Dec 12, 2025 by Lei Wang · Committed by GitHub, Dec 12, 2025
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
Parent: e84b24bc
Changes: 426
Showing 20 changed files with 208 additions and 263 deletions (+208 -263).
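For context on the change above: moving a project from yapf to ruff's formatter usually means dropping the yapf configuration, adding a ruff format configuration, and reformatting the tree once. The fragment below is only an illustrative sketch, not tilelang's actual configuration; the line length and option values are assumptions.

    # Illustrative pyproject.toml fragment (values are assumptions, not the real tilelang settings)
    [tool.ruff]
    line-length = 120            # assumed column limit

    [tool.ruff.format]
    quote-style = "double"       # consistent with the ' -> " rewrites visible in the hunks below

    # One-shot reformat of the test tree:
    #   ruff format testing/
    # Check-only mode for CI:
    #   ruff format --check testing/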
testing/python/language/test_tilelang_language_negative_index.py   +1  -2
testing/python/language/test_tilelang_language_parallel.py         +5  -7
testing/python/language/test_tilelang_language_pipeline.py         +19 -29
testing/python/language/test_tilelang_language_ptr.py              +0  -1
testing/python/language/test_tilelang_language_reduce.py           +10 -10
testing/python/language/test_tilelang_language_reshape.py          +31 -23
testing/python/language/test_tilelang_language_ternary.py          +6  -6
testing/python/language/test_tilelang_language_tma_1d.py           +2  -4
testing/python/language/test_tilelang_language_unroll.py           +0  -2
testing/python/language/test_tilelang_language_var_init.py         +5  -7
testing/python/language/test_tilelang_language_vectorize.py        +6  -10
testing/python/language/test_tilelang_language_vectorized_cast.py  +5  -6
testing/python/language/test_tilelang_language_view.py             +7  -5
testing/python/language/test_tilelang_language_warp_reduce.py      +10 -11
testing/python/layout/test_tilelang_layout_fused_replicate.py      +6  -7
testing/python/math/test_math_bitwise_reduce.py                    +5  -6
testing/python/math/test_math_fast_math.py                         +32 -49
testing/python/math/test_math_ieee_math.py                         +26 -33
testing/python/metal/test_metal_codegen.py                         +14 -15
testing/python/primitives/test_tilelang_primitives_mma.py          +18 -30
testing/python/language/test_tilelang_language_negative_index.py (view file @ 29051439)
...
@@ -31,8 +31,7 @@ def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,)
 @T.prim_func
 def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"),
                                    B: T.Buffer((16,), "float32")):
     T.func_attr({"tir.noalias": True})
     for i in T.serial(16):
         B[i] = A[shift + i]
...
testing/python/language/test_tilelang_language_parallel.py (view file @ 29051439)
...
@@ -9,11 +9,10 @@ tilelang.testing.set_random_seed()
 @tilelang.jit(out_idx=[1])
 def parallel_elementwise_static(length=256, dtype="float32"):
     @T.prim_func
     def main(
         A: T.Tensor((length,), dtype),
         B: T.Tensor((length,), dtype),
     ):
         with T.Kernel(1, threads=length) as _:
             for i in T.Parallel(length):
...
@@ -24,12 +23,11 @@ def parallel_elementwise_static(length=256, dtype="float32"):
 @tilelang.jit(out_idx=[1])
 def parallel_elementwise_dynamic(max_len=512, threads=256, dtype="float32"):
     @T.prim_func
     def main(
         A: T.Tensor((max_len,), dtype),
         B: T.Tensor((max_len,), dtype),
         valid_len: T.int32,
     ):
         with T.Kernel(1, threads=threads) as _:
             for i in T.Parallel(max_len):
...
testing/python/language/test_tilelang_language_pipeline.py (view file @ 29051439)
...
@@ -27,9 +27,9 @@ def matmul(
     @T.prim_func
     def main(
         A: T.Tensor(A_shape, in_dtype),
         B: T.Tensor(B_shape, in_dtype),
         C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -90,7 +90,8 @@ def run_gemm(
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     profiler = kernel.get_profiler()

     def ref_program(A, B):
...
@@ -103,8 +104,8 @@ def run_gemm(
         if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
            # float32 automatically, -0x1000 meas
-           A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
-           B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
+           A = (A.view(torch.int32) - 0x1000).view(torch.float32)
+           B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C
...
@@ -124,27 +125,19 @@ def test_pipeline_order_stage():
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )


-def blocksparse_matmul(M,
-                       N,
-                       K,
-                       block_M,
-                       block_N,
-                       block_K,
-                       num_stages,
-                       dtype="float16",
-                       accum_dtype="float"):
+def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, dtype="float16", accum_dtype="float"):
     block_mask_shape = (M // block_M, N // block_N, K // block_K)

     import tilelang.language as T

     @T.prim_func
     def block_sparse_matmul(
         A: T.Tensor((M, K), dtype),
         B: T.Tensor((K, N), dtype),
         BlockMask: T.Tensor(block_mask_shape, "bool"),
         C: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -183,8 +176,7 @@ def run_blocksparse_matmul(num_stages):
     a = torch.randn(M, K).cuda().half()
     b = torch.randn(K, N).cuda().half()
     kernel = blocksparse_matmul(
         M, N, K, block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages)
     print(kernel.get_kernel_source())

     # Create block mask with desired sparsity
     mask_shape = (M // block_M, N // block_N, K // block_K)
...
@@ -200,12 +192,10 @@ def run_blocksparse_matmul(num_stages):
                 accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
                 for k in range(K // block_K):
                     if BlockMask[i, j, k]:
-                        accu += (
-                            A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
-                                torch.float32) @ B[k * block_K:(k + 1) * block_K,
-                                                   j * block_N:(j + 1) * block_N].to(torch.float32))
+                        accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[
+                            k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N
+                        ].to(torch.float32)
-                ref_c[i * block_M:(i + 1) * block_M,
-                      j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
+                ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
         return ref_c

     # Compute the reference result using the naive PyTorch implementation
...
testing/python/language/test_tilelang_language_ptr.py (view file @ 29051439)
...
@@ -7,7 +7,6 @@ from tilelang.utils import map_torch_type
 def matmul_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     @T.prim_func
     def main(
         a_ptr: T.ptr,
...
testing/python/language/test_tilelang_language_reduce.py (view file @ 29051439)
...
@@ -10,8 +10,8 @@ def _make_shared_reduce(M, N, dtype, reduce_cb):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M,), dtype),
     ):
         with T.Kernel(1) as _:
             A_shared = T.alloc_shared((M, N), dtype)
...
@@ -35,8 +35,8 @@ def reduce_max_test(M, N, dtype="float16"):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M,), dtype),
     ):
         with T.Kernel(1) as _:
             A_local = T.alloc_fragment((M, N), dtype)
...
@@ -54,8 +54,8 @@ def reduce_sum_test(M, N, dtype="float32"):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M,), dtype),
     ):
         with T.Kernel(1) as _:
             A_local = T.alloc_fragment((M, N), dtype)
...
@@ -145,8 +145,8 @@ def reduce_sum_test_clear(M, N, dtype="float32"):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M,), dtype),
     ):
         with T.Kernel(1, threads=32) as _:
             A_local = T.alloc_fragment((M, N), dtype)
...
@@ -186,8 +186,8 @@ def reduce_max_test_clear(M, N, dtype="float16"):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M,), dtype),
     ):
         with T.Kernel(1, threads=32) as _:
             A_local = T.alloc_fragment((M, N), dtype)
...
testing/python/language/test_tilelang_language_reshape.py (view file @ 29051439)
...
@@ -10,8 +10,8 @@ def reshape_test(N, M, dtype):
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
         B: T.Tensor((N // M, M), dtype),
     ):
         with T.Kernel(1) as _:
             A_reshaped = T.reshape(A, [N // M, M])
...
@@ -30,7 +30,8 @@ def run_reshape(N, M, dtype):
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     profiler = jit_kernel.get_profiler()

     def ref_program(A):
...
@@ -50,8 +51,8 @@ def reshape_test_smem_1d_2_2d(N, M, dtype):
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
         B: T.Tensor((N // M, M), dtype),
     ):
         with T.Kernel(1) as _:
             A_shared = T.alloc_shared((N,), dtype)
...
@@ -74,7 +75,8 @@ def run_reshape_smem_1d_2_2d(N, M, dtype):
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     profiler = jit_kernel.get_profiler()

     def ref_program(A):
...
@@ -93,8 +95,8 @@ def reshape_test_smem_2d_2_1d(N, M, dtype):
     @T.prim_func
     def main(
         A: T.Tensor((N // M, M), dtype),
         B: T.Tensor((N,), dtype),
     ):
         with T.Kernel(1) as _:
             A_shared = T.alloc_shared((N // M, M), dtype)
...
@@ -117,7 +119,8 @@ def run_reshape_smem_2d_2_1d(N, M, dtype):
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     profiler = jit_kernel.get_profiler()

     def ref_program(A):
...
@@ -136,8 +139,8 @@ def reshape_fragment_test(N, M, dtype):
     @T.prim_func
     def main(
         A: T.Tensor((N // M, M), dtype),
         B: T.Tensor((N,), dtype),
     ):
         with T.Kernel(1, threads=32) as _:
             A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
...
@@ -161,7 +164,8 @@ def run_reshape_fragment(N, M, dtype):
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     profiler = jit_kernel.get_profiler()

     def ref_program(A):
...
@@ -181,15 +185,17 @@ def reshape_layout_transform_shared(N, M, dtype):
     @T.prim_func
     def main(
         A: T.Tensor((N // M, M), dtype),
         B: T.Tensor((N,), dtype),
     ):
         with T.Kernel(1, threads=32) as _:
             A_shared = T.alloc_shared((N // M, M), dtype, scope="shared")
-            T.annotate_layout({
-                A_shared: make_mma_swizzle_layout(A_shared),
-            })
+            T.annotate_layout(
+                {
+                    A_shared: make_mma_swizzle_layout(A_shared),
+                }
+            )
             T.copy(A, A_shared)
             A_shared_reshape = T.reshape(A_shared, [N])
             T.copy(A_shared_reshape, B)
...
@@ -205,7 +211,8 @@ def run_reshape_layout_transform_shared(N, M, dtype):
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     profiler = jit_kernel.get_profiler()

     def ref_program(A):
...
@@ -224,8 +231,8 @@ def reduce_after_reshape_test(N, M, dtype):
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
         B: T.Tensor((N // M,), dtype),
     ):
         with T.Kernel(1, threads=32) as _:
             A_shared = T.alloc_shared((N,), dtype, scope="shared")
...
@@ -249,7 +256,8 @@ def run_reduce_after_reshape(N, M, dtype):
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     profiler = jit_kernel.get_profiler()

     def ref_program(A):
...
@@ -268,8 +276,8 @@ def reshape_shape_mismatch_test(N, M, dtype):
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
         B: T.Tensor((N // M, M), dtype),
     ):
         with T.Kernel(1) as _:
             A_reshaped = T.reshape(A, [N // M, M + 1])
...
testing/python/language/test_tilelang_language_ternary.py (view file @ 29051439)
...
@@ -4,19 +4,19 @@ import torch
 import tilelang.testing


-@tilelang.jit(out_idx=[1],)
+@tilelang.jit(
+    out_idx=[1],
+)
 def tilelang_ternary(M, N, block_M, block_N, dtype="float16"):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             for i, j in T.Parallel(block_M, block_N):
-                B[by * block_M + i, bx * block_N + j] = (
-                    A[by * block_M + i, bx * block_N + j] if (by * block_M + i) < (M // 2) else 0
-                )
+                B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j] if (by * block_M + i) < (M // 2) else 0

     return main
...
testing/python/language/test_tilelang_language_tma_1d.py (view file @ 29051439)
...
@@ -9,10 +9,8 @@ def ref_program(x, y):
 @tilelang.jit(out_idx=[-1])
 def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
     @T.prim_func
-    def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor(
-            (M, N), out_dtype)):
+    def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_N), in_dtype)
             B_shared = T.alloc_shared((block_M, block_N), in_dtype)
...
@@ -21,7 +19,7 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
             T.copy(A[by * block_M, bx * block_N], A_shared)
             T.copy(B[by * block_M, bx * block_N], B_shared)
-            for (local_y, local_x) in T.Parallel(block_M, block_N):
+            for local_y, local_x in T.Parallel(block_M, block_N):
                 C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
             T.copy(C_local, C_shared)
             T.copy(C_shared, C[by * block_M, bx * block_N])
...
testing/python/language/test_tilelang_language_unroll.py (view file @ 29051439)
...
@@ -4,7 +4,6 @@ from tilelang import language as T
 def test_unroll_with_step():
     @T.prim_func
     def main(A_ptr: T.handle):
         A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
...
@@ -19,7 +18,6 @@ def test_unroll_with_step():
 def test_unroll_with_unroll_factor():
     @T.prim_func
     def main(A_ptr: T.handle):
         A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
...
testing/python/language/test_tilelang_language_var_init.py (view file @ 29051439)
...
@@ -4,17 +4,15 @@ import tilelang.testing
 def test_var_assign() -> None:
     @tilelang.jit(out_idx=-1)
     def jit_kernel():
         @T.prim_func
-        def test_var_assign(A: T.Tensor((2,), 'int32')):
+        def test_var_assign(A: T.Tensor((2,), "int32")):
             with T.Kernel(1) as _:
-                a = T.alloc_var('int32', init=1)
-                b = T.alloc_var('int32', init=a)  # b gets value of a
+                a = T.alloc_var("int32", init=1)
+                b = T.alloc_var("int32", init=a)  # b gets value of a
                 a = 2
-                d = T.alloc_var('int32', init=a)  # c gets new value of a
+                d = T.alloc_var("int32", init=a)  # c gets new value of a
                 A[0] = b
                 A[1] = d
...
@@ -28,5 +26,5 @@ def test_var_assign() -> None:
     assert res[1] == 2


-if __name__ == '__main__':
+if __name__ == "__main__":
     tilelang.testing.main()
testing/python/language/test_tilelang_language_vectorize.py (view file @ 29051439)
...
@@ -5,11 +5,10 @@ import tilelang.language as T
 @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
 def vectorize_test(N, M, stride_A, stride_B):
     @T.prim_func
     def main(
         A: T.StridedTensor[(N, M), (1, stride_A), "float32"],  # noqa: F821
         B: T.StridedTensor[(N, M), (1, stride_B), "float32"],  # noqa: F821
     ):
         with T.Kernel(M // 128, threads=128) as (bx):
             tx = T.get_thread_binding(0)
...
@@ -39,9 +38,7 @@ def run_vectorize(N, M, stride_A, stride_B):
     code = jit_kernel.get_kernel_source()
     vectorize_size = 1
-    while vectorize_size <= 2 and \
-            stride_A % (vectorize_size * 2) == 0 and \
-            stride_B % (vectorize_size * 2) == 0:
+    while vectorize_size <= 2 and stride_A % (vectorize_size * 2) == 0 and stride_B % (vectorize_size * 2) == 0:
         vectorize_size *= 2
     if vectorize_size == 4:
...
@@ -61,12 +58,11 @@ def test_vectorize():
 @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
 def vectorize_test_invariant_index(N, M, K):
     @T.prim_func
     def main(
         A: T.Tensor[(N, M), "float32"],  # noqa: F821
         B: T.Tensor[(N, M), "float32"],  # noqa: F821
         C: T.Tensor[(N, M // K), "float32"],  # noqa: F821
     ):
         with T.Kernel(N // 128, threads=128) as (bx):
             tx = T.get_thread_binding(0)
...
testing/python/language/test_tilelang_language_vectorized_cast.py (view file @ 29051439)
...
@@ -17,8 +17,8 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
     @T.prim_func
     def main(
         A: T.Tensor[(M,), dtype_A],  # noqa: F821
         B: T.Tensor[(M,), dtype_B],  # noqa: F821
     ):
         with T.Kernel(1, threads=128):
             T.copy(A, B)
...
@@ -32,8 +32,8 @@ def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
     @T.prim_func
     def main(
         A: T.Tensor[(M,), dtype_A],  # noqa: F821
         B: T.Tensor[(M,), dtype_B],  # noqa: F821
     ):
         with T.Kernel(1, threads=128):
             A_local = T.alloc_fragment((M,), dtype_A)
...
@@ -73,8 +73,7 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
     code = kernel.get_kernel_source()
     code_parallel = kernel_parallel.get_kernel_source()
-    assert check_str in code and check_str in code_parallel, \
-        f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
+    assert check_str in code and check_str in code_parallel, f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"


 def test_vectorized_cast():
...
testing/python/language/test_tilelang_language_view.py (view file @ 29051439)
...
@@ -10,6 +10,7 @@ def view_test(N, M, dtype, new_dtype=None):
     new_shape = [N // M, M]
     if new_dtype:
         from tvm import DataType
         dtype_src = DataType(dtype)
         dtype_dst = DataType(new_dtype)
         src_bits = dtype_src.bits
...
@@ -19,8 +20,8 @@ def view_test(N, M, dtype, new_dtype=None):
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
         B: T.Tensor(new_shape, new_dtype if new_dtype else dtype),
     ):
         with T.Kernel(1) as _:
             A_viewed = T.view(A, new_shape, dtype=new_dtype)
...
@@ -37,6 +38,7 @@ def run_view(N, M, dtype, new_dtype=None):
     def ref_program(A):
         if new_dtype:
             from tilelang.utils.tensor import map_torch_type
             torch_dtype = map_torch_type(new_dtype)
             return A.view(N // M, M).view(dtype=torch_dtype)
         return A.view(N // M, M)
...
@@ -45,7 +47,6 @@ def run_view(N, M, dtype, new_dtype=None):
 def test_reshape_view():
     # Test view with same dtype
     run_view(1024, 32, "float32")
     run_view(2048, 64, "float16")
...
@@ -61,6 +62,7 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
     new_shape = [N // M, M + 1]
     if new_dtype:
         from tvm import DataType
         dtype_src = DataType(dtype)
         dtype_dst = DataType(new_dtype)
         src_bits = dtype_src.bits
...
@@ -70,8 +72,8 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
         B: T.Tensor(new_shape, new_dtype if new_dtype else dtype),
     ):
         with T.Kernel(1) as _:
             A_viewed = T.view(A, new_shape, dtype=new_dtype)
...
testing/python/language/test_tilelang_language_warp_reduce.py (view file @ 29051439)
...
@@ -7,7 +7,6 @@ import tilelang.language as T
 @tilelang.jit
 def get_kernel(reduce_op: str, dtype: str):
     assert reduce_op in ["sum", "max", "min", "bitand", "bitor"]

     @T.prim_func
...
@@ -33,16 +32,16 @@ def get_kernel(reduce_op: str, dtype: str):
 def test_warp_reduce_sum():
-    a = torch.randn((32,), dtype=torch.float32, device='cuda')
-    kernel = get_kernel('sum', 'float32')
+    a = torch.randn((32,), dtype=torch.float32, device="cuda")
+    kernel = get_kernel("sum", "float32")
     ref = torch.full_like(a, a.sum())
     kernel(a)
     torch.testing.assert_close(a, ref)


 def test_warp_reduce_max():
-    a = torch.randn((32,), dtype=torch.float32, device='cuda')
-    kernel = get_kernel("max", 'float32')
+    a = torch.randn((32,), dtype=torch.float32, device="cuda")
+    kernel = get_kernel("max", "float32")
     print(kernel.get_kernel_source())
     ref = torch.full_like(a, a.max())
     kernel(a)
...
@@ -50,16 +49,16 @@ def test_warp_reduce_max():
 def test_warp_reduce_min():
-    a = torch.randn((32,), dtype=torch.float32, device='cuda')
-    kernel = get_kernel("min", 'float32')
+    a = torch.randn((32,), dtype=torch.float32, device="cuda")
+    kernel = get_kernel("min", "float32")
     ref = torch.full_like(a, a.min())
     kernel(a)
     torch.testing.assert_close(a, ref)


 def test_warp_reduce_bitand():
-    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
-    kernel = get_kernel("bitand", 'int32')
+    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda")
+    kernel = get_kernel("bitand", "int32")
     ref_val = a[0]
     for i in range(1, a.shape[0]):
         ref_val = ref_val & a[i]
...
@@ -69,8 +68,8 @@ def test_warp_reduce_bitand():
 def test_warp_reduce_bitor():
-    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
-    kernel = get_kernel("bitor", 'int32')
+    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda")
+    kernel = get_kernel("bitor", "int32")
     ref_val = a[0]
     for i in range(1, a.shape[0]):
         ref_val = ref_val | a[i]
...
testing/python/layout/test_tilelang_layout_fused_replicate.py (view file @ 29051439)
...
@@ -12,17 +12,16 @@ VEC_SIZE = 32
 @tilelang.jit
 def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):
     @T.prim_func
     def main(
         a: T.Buffer((B, M, N), "bfloat16"),
         a_out: T.Buffer((B, M, N), "float32"),
     ):
         with T.Kernel(
             T.ceildiv(M, BLOCK_MN),
             T.ceildiv(N, BLOCK_K),
             B,
             threads=128,
         ) as (pid_m, pid_n, pid_b):
             a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32")
             offs_m = pid_m * BLOCK_MN
...
testing/python/math/test_math_bitwise_reduce.py (view file @ 29051439)
...
@@ -19,12 +19,11 @@ def bitwise_reduce(
     func,
     clear=True,
 ):
     @T.prim_func
     def reduce_func(
         A: T.Tensor((M, N), "int32"),
         B: T.Tensor((M), "int32"),
         Output: T.Tensor((M), "int32"),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_N), "int32")
...
@@ -64,7 +63,7 @@ def run_single_bitwise_reduce(
             row_pattern = (i & 0xF) << (i % 4)  # 4-bit patterns shifted by row
             # Column-based pattern: different bit positions set based on column
-            col_pattern = (1 << (j % 31))  # Single bit set at different positions
+            col_pattern = 1 << (j % 31)  # Single bit set at different positions
             # Combine patterns with XOR to create diverse bit distributions
             # Add some deterministic "noise" based on position
...
@@ -76,7 +75,7 @@ def run_single_bitwise_reduce(
             if i % 4 == 0:
                 a[i, j] &= ~(0x1 << (i // 4))
             elif i % 2 == 0:
-                a[i, j] |= (0x1 << (i // 2))
+                a[i, j] |= 0x1 << (i // 2)

     if name == "reduce_bitand":
         expected = torch.full((M,), -1, device="cuda", dtype=torch.int32)
...
testing/python/math/test_math_fast_math.py (view file @ 29051439)
...
@@ -7,16 +7,16 @@ import re
 def get_mathop_lines(source, mathop_name):
     """Extract lines containing the mathop from CUDA source for debugging"""
-    lines = source.split('\n')
+    lines = source.split("\n")
     relevant_lines = []
     for i, line in enumerate(lines):
-        if mathop_name in line and ('(' in line):
+        if mathop_name in line and ("(" in line):
             # Include some context
             start = max(0, i - 1)
             end = min(len(lines), i + 2)
             relevant_lines.extend([f"{j}: {lines[j]}" for j in range(start, end)])
             relevant_lines.append("---")
-    return '\n'.join(relevant_lines[-10:])  # Show last 10 lines to avoid too much output
+    return "\n".join(relevant_lines[-10:])  # Show last 10 lines to avoid too much output


 def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
...
@@ -27,9 +27,7 @@ def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
     fastmath_matches = re.findall(fastmath_pattern, source)
     non_fastmath_matches = re.findall(non_fastmath_pattern, source)

-    print(
-        f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls")
+    print(f"Found {len(fastmath_matches)} fastmath calls, {len(non_fastmath_matches)} non-fastmath calls")
     if len(fastmath_matches) > 0:
         print(f"Fastmath calls found: {fastmath_matches}")
     if len(non_fastmath_matches) > 0:
...
@@ -51,13 +49,7 @@ def check_non_fastmath_usage(source, mathop_name):
     check_fastmath_usage(source, mathop_name, expect_fastmath=False)


-def run_single_arg_mathop_test(mathop_name,
-                               mathop_func,
-                               M=128,
-                               N=128,
-                               block_M=32,
-                               block_N=32,
-                               dtype="float32"):
+def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
     """
     Test single-argument mathops.
     T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
...
@@ -65,13 +57,12 @@ def run_single_arg_mathop_test(mathop_name,
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             for i, j in T.Parallel(block_M, block_N):
-                B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
-                                                                      bx * block_N + j])
+                B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j])

     # Test with FAST_MATH disabled
     kernel_no_fastmath = tilelang.compile(
...
@@ -80,7 +71,8 @@ def run_single_arg_mathop_test(mathop_name,
         target="cuda",
         pass_configs={
             tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
-        })
+        },
+    )
     source_no_fastmath = kernel_no_fastmath.get_kernel_source()
...
@@ -93,28 +85,22 @@ def run_single_arg_mathop_test(mathop_name,
     print(f"✓ {mathop_name} compilation and execution test passed")


-def run_two_arg_mathop_test(mathop_name,
-                            mathop_func,
-                            M=128,
-                            N=128,
-                            block_M=32,
-                            block_N=32,
-                            dtype="float32"):
+def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
     """
     Test two-argument mathops to ensure they generate non-fastmath CUDA code.
     """

     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M, N), dtype),
         C: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             for i, j in T.Parallel(block_M, block_N):
-                C[by * block_M + i,
-                  bx * block_N + j] = mathop_func(
-                      A[by * block_M + i, bx * block_N + j],
-                      B[by * block_M + i, bx * block_N + j]
-                  )
+                C[by * block_M + i, bx * block_N + j] = mathop_func(
+                    A[by * block_M + i, bx * block_N + j],
+                    B[by * block_M + i, bx * block_N + j]
+                )

     # Test with FAST_MATH disabled
     kernel_no_fastmath = tilelang.compile(
...
@@ -123,7 +109,8 @@ def run_two_arg_mathop_test(mathop_name,
         target="cuda",
         pass_configs={
             tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
-        })
+        },
+    )
     # Test with FAST_MATH enabled
     kernel_fastmath = tilelang.compile(
...
@@ -132,7 +119,8 @@ def run_two_arg_mathop_test(mathop_name,
         target="cuda",
         pass_configs={
             tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        })
+        },
+    )
     source_no_fastmath = kernel_no_fastmath.get_kernel_source()
     source_fastmath = kernel_fastmath.get_kernel_source()
...
@@ -171,8 +159,8 @@ def run_abs_test():
     @T.prim_func
     def main(
         A: T.Tensor((M, N), "float32"),
         B: T.Tensor((M, N), "float32"),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             for i, j in T.Parallel(block_M, block_N):
...
@@ -184,7 +172,8 @@ def run_abs_test():
         target="cuda",
         pass_configs={
             tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
-        })
+        },
+    )
     source = kernel.get_kernel_source()

     print("\n=== Testing abs (maps to fabs) ===")
...
@@ -199,26 +188,19 @@ def run_abs_test():
     print("✓ abs numerical test passed")


-def run_fastmath_mathop_test(mathop_name,
-                             mathop_func,
-                             M=128,
-                             N=128,
-                             block_M=32,
-                             block_N=32,
-                             dtype="float32"):
+def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
     """
     Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
     """

     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
         B: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             for i, j in T.Parallel(block_M, block_N):
-                B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
-                                                                      bx * block_N + j])
+                B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j])

     # Test with FAST_MATH enabled
     kernel_fastmath = tilelang.compile(
...
@@ -227,14 +209,15 @@ def run_fastmath_mathop_test(mathop_name,
         target="cuda",
         pass_configs={
             tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        })
+        },
+    )
     source_fastmath = kernel_fastmath.get_kernel_source()

     print(f"\n=== Testing {mathop_name} (fastmath version) ===")
     print("FAST_MATH=True:")
     # Strip the __ prefix for checking in the CUDA source
-    cuda_mathop_name = mathop_name.lstrip('_')
+    cuda_mathop_name = mathop_name.lstrip("_")
     check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True)

     # Test numerical correctness
...
testing/python/math/test_math_ieee_math.py
View file @
29051439
...
@@ -5,14 +5,7 @@ import tilelang.testing
...
@@ -5,14 +5,7 @@ import tilelang.testing
import
pytest
import
pytest
def
run_ieee_math_test
(
mathop_name
,
def
run_ieee_math_test
(
mathop_name
,
mathop_func
,
rounding_mode
=
"rn"
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
mathop_func
,
rounding_mode
=
"rn"
,
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float32"
):
"""
"""
Test IEEE-compliant math operations with specified rounding modes.
Test IEEE-compliant math operations with specified rounding modes.
"""
"""
...
@@ -22,18 +15,19 @@ def run_ieee_math_test(mathop_name,
     @T.prim_func
     def main_func(
-            A: T.Tensor((M, N), dtype),
-            B: T.Tensor((M, N), dtype),
-            C: T.Tensor((M, N), dtype),
-            D: T.Tensor((M, N), dtype),
+        A: T.Tensor((M, N), dtype),
+        B: T.Tensor((M, N), dtype),
+        C: T.Tensor((M, N), dtype),
+        D: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             for i, j in T.Parallel(block_M, block_N):
-                D[by * block_M + i,
-                  bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
-                                                  B[by * block_M + i, bx * block_N + j],
-                                                  C[by * block_M + i,
-                                                    bx * block_N + j], rounding_mode)
+                D[by * block_M + i, bx * block_N + j] = mathop_func(
+                    A[by * block_M + i, bx * block_N + j],
+                    B[by * block_M + i, bx * block_N + j],
+                    C[by * block_M + i, bx * block_N + j],
+                    rounding_mode,
+                )

     out_idx = [3]
     num_inputs = 3
...
@@ -41,16 +35,15 @@ def run_ieee_math_test(mathop_name,
     @T.prim_func
     def main_func(
-            A: T.Tensor((M, N), dtype),
-            B: T.Tensor((M, N), dtype),
-            C: T.Tensor((M, N), dtype),
+        A: T.Tensor((M, N), dtype),
+        B: T.Tensor((M, N), dtype),
+        C: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             for i, j in T.Parallel(block_M, block_N):
-                C[by * block_M + i,
-                  bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
-                                                  B[by * block_M + i,
-                                                    bx * block_N + j], rounding_mode)
+                C[by * block_M + i, bx * block_N + j] = mathop_func(
+                    A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j], rounding_mode
+                )

     out_idx = [2]
     num_inputs = 2
...
@@ -58,14 +51,12 @@ def run_ieee_math_test(mathop_name,
     @T.prim_func
     def main_func(
-            A: T.Tensor((M, N), dtype),
-            B: T.Tensor((M, N), dtype),
+        A: T.Tensor((M, N), dtype),
+        B: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             for i, j in T.Parallel(block_M, block_N):
-                B[by * block_M + i,
-                  bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
-                                                  rounding_mode)
+                B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j], rounding_mode)

     out_idx = [1]
     num_inputs = 1
...
@@ -77,7 +68,8 @@ def run_ieee_math_test(mathop_name,
         target="cuda",
         pass_configs={
             tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
-        })
+        },
+    )

     print(f"\n=== Testing {mathop_name} with rounding mode {rounding_mode} ===")
     print(f"✓ {mathop_name} compilation test passed")
...
@@ -194,8 +186,8 @@ def test_ieee_frsqrt_rn_only():
     @T.prim_func
     def main(
-            A: T.Tensor((128, 128), "float32"),
-            B: T.Tensor((128, 128), "float32"),
+        A: T.Tensor((128, 128), "float32"),
+        B: T.Tensor((128, 128), "float32"),
     ):
         with T.Kernel(T.ceildiv(128, 32), T.ceildiv(128, 32), threads=128) as (bx, by):
             for i, j in T.Parallel(32, 32):
...
@@ -207,7 +199,8 @@ def test_ieee_frsqrt_rn_only():
         target="cuda",
         pass_configs={
             tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
-        })
+        },
+    )

     print("\n=== Testing ieee_frsqrt (rn only) ===")
     print("✓ ieee_frsqrt compilation test passed")
...
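The hunks above are representative of the whole commit: yapf aligned wrapped call arguments under the opening parenthesis and double-indented function parameters, while ruff format (black-style) breaks after the opening parenthesis, puts one argument per line with a magic trailing comma, and dedents the closing parenthesis. A minimal, self-contained sketch (not taken from the commit; the names below are made up for illustration) of the two layouts:

def mathop_func(a, b, c, rounding_mode):  # stand-in for the intrinsic under test
    return a + b + c


def old_yapf_layout():
    d = mathop_func(1.0, 2.0,
                    3.0,
                    "rn")  # arguments aligned to the call site, paren hugs the last one
    return d


def new_ruff_layout():
    d = mathop_func(
        1.0,
        2.0,
        3.0,
        "rn",
    )  # one argument per line, trailing comma, dedented closing paren
    return d


assert old_yapf_layout() == new_ruff_layout()  # only the layout differs, not the result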
testing/python/metal/test_metal_codegen.py
View file @
29051439
...
@@ -5,18 +5,17 @@ import tilelang.language as T
 import torch


-@tilelang.jit(execution_backend='torch')
+@tilelang.jit(execution_backend="torch")
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"):
     @T.prim_func
     def gemm(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((K, N), dtype),
-            C: T.Tensor((M, N), dtype),
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((K, N), dtype),
+        C: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
-            A_shared = T.alloc_shared((block_M, block_K), dtype, scope='shared')
-            B_shared = T.alloc_shared((block_K, block_N), dtype, scope='shared')
+            A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
+            B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             T.clear(C_local)
...
@@ -48,13 +47,13 @@ def assert_gemm(
     torch_dtype = getattr(torch, dtype)
     a, b = None, None
-    if 'int' in dtype:
-        a = torch.randint(100, (M, K), dtype=torch_dtype, device='mps')
-        b = torch.randint(100, (K, N), dtype=torch_dtype, device='mps')
+    if "int" in dtype:
+        a = torch.randint(100, (M, K), dtype=torch_dtype, device="mps")
+        b = torch.randint(100, (K, N), dtype=torch_dtype, device="mps")
     else:
-        a = torch.randn(M, K, dtype=torch_dtype, device='mps')
-        b = torch.randn(K, N, dtype=torch_dtype, device='mps')
-    c = torch.zeros(M, N, dtype=torch_dtype, device='mps')
+        a = torch.randn(M, K, dtype=torch_dtype, device="mps")
+        b = torch.randn(K, N, dtype=torch_dtype, device="mps")
+    c = torch.zeros(M, N, dtype=torch_dtype, device="mps")
     jit_kernel(a, b, c)
...
@@ -70,12 +69,12 @@ def test_gemm_float32():
 @tilelang.testing.requires_metal
 def test_gemm_float16():
-    assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='float16', atol=1)
+    assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="float16", atol=1)


 @tilelang.testing.requires_metal
 def test_gemm_int32():
-    assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype='int32', atol=1)
+    assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="int32", atol=1)


 if __name__ == "__main__":
...
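The metal hunks are dominated by quote normalization: ruff format rewrites single-quoted string literals to double quotes. A hypothetical helper (not part of the commit, assuming `ruff` is installed and picks up the repository's configuration; the function name and the `testing/` default path are illustrative) showing how this formatting could be checked or applied locally:

import subprocess
import sys


def ruff_format(paths=("testing/",), check_only=True):
    """Run `ruff format` over `paths`; with check_only=True, only report drift."""
    cmd = ["ruff", "format"]
    if check_only:
        cmd.append("--check")
    cmd.extend(paths)
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        sys.stderr.write(result.stdout + result.stderr)
    return result.returncode == 0


if __name__ == "__main__":
    sys.exit(0 if ruff_format() else 1)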
testing/python/primitives/test_tilelang_primitives_mma.py
View file @
29051439
...
@@ -27,9 +27,9 @@ def matmul_ssr(
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
...
@@ -88,7 +88,8 @@ def run_matmul_ssr(
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     profiler = kernel.get_profiler()

     def ref_program(A, B):
...
@@ -106,24 +107,9 @@ def run_matmul_ssr(
 def test_gemm_f16f16f16_nt_ssr():
-    run_matmul_ssr(
-        16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32)
-    run_matmul_ssr(
-        128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64)
-    run_matmul_ssr(
-        1024, 1024, 1024, False, True, "float16", "float16", "float16", 128, 128, 32, 2, num_threads=128)
+    run_matmul_ssr(16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32)
+    run_matmul_ssr(128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64)
+    run_matmul_ssr(1024, 1024, 1024, False, True, "float16", "float16", "float16", 128, 128, 32, 2, num_threads=128)


 def matmul_rsr(
...
@@ -151,9 +137,9 @@ def matmul_rsr(
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
...
@@ -214,7 +200,8 @@ def run_matmul_rsr(
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     profiler = kernel.get_profiler()

     def ref_program(A, B):
...
@@ -276,9 +263,9 @@ def matmul_rrr(
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -342,7 +329,8 @@ def run_matmul_rrr(
         pass_configs={
             tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
             tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     profiler = kernel.get_profiler()

     def ref_program(A, B):
...
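The recurring `})` to `},` plus `)` change in these hunks is ruff format closing each bracket pair on its own line once a call has been split. A minimal sketch of the resulting call shape; `compile_kernel` is a stub defined here for illustration and the string keys are placeholders, since the real call site and the tilelang.PassConfigKey values are collapsed out of the hunk:

def compile_kernel(func, out_idx, target, pass_configs):
    return (func, out_idx, target, pass_configs)  # placeholder for the real compile step


kernel = compile_kernel(
    "main",
    out_idx=[2],
    target="cuda",
    pass_configs={
        "disable_tma_lower": True,        # placeholder key names
        "disable_warp_specialized": True,
    },
)  # the dict closes with "}," and the call closes with ")" on its own line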