OpenDAS / tilelang · Commits

Commit 29051439 (Unverified)
Authored Dec 12, 2025 by Lei Wang; committed by GitHub, Dec 12, 2025

[Lint] Phaseout Yapf format and embrace ruff format (#1417)
Parent: e84b24bc

The full commit touches 426 files; the 20 test files listed below are shown here, totalling 434 additions and 441 deletions (+434, -441). The churn is mechanical formatting from the yapf-to-ruff switch: single quotes become double quotes, long argument lists are exploded one argument per line with trailing commas, closing brackets move onto their own lines, slice colons between non-trivial expressions gain surrounding spaces, and blank lines immediately after function signatures are removed.
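Those patterns can be reproduced directly: ruff reads a file, reprints it in Black-compatible style, and writes it back. A minimal sketch (mine, not part of the commit; it assumes ruff is installed, and the tilelang names are only format input, never executed):

# Pipe a yapf-era snippet through `ruff format -` ("-" = read stdin, print stdout).
import subprocess

yapf_era = """\
metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
T.annotate_layout({
    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
})
"""

result = subprocess.run(
    ["ruff", "format", "-"], input=yapf_era, capture_output=True, text=True, check=True
)
print(result.stdout)
# Expected shape of the output (quote normalization plus non-hugged brackets):
#   metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
#   T.annotate_layout(
#       {
#           E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
#       }
#   )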
Changed files:

  +3   -4    testing/python/profiler/test_tilelang_profiler.py
  +36  -26   testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
  +34  -57   testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
  +61  -65   testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
  +2   -10   testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py
  +2   -5    testing/python/transform/test_tilelang_transform_cluster_planning.py
  +31  -35   testing/python/transform/test_tilelang_transform_config_index_bitwidth.py
  +36  -20   testing/python/transform/test_tilelang_transform_inject_fence_proxy.py
  +28  -24   testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py
  +51  -43   testing/python/transform/test_tilelang_transform_layout_inference.py
  +30  -24   testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py
  +6   -2    testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py
  +1   -4    testing/python/transform/test_tilelang_transform_let_inline.py
  +4   -13   testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py
  +39  -33   testing/python/transform/test_tilelang_transform_lower_tile_op.py
  +10  -13   testing/python/transform/test_tilelang_transform_make_packed_api.py
  +32  -26   testing/python/transform/test_tilelang_transform_multi_version_buffer.py
  +9   -13   testing/python/transform/test_tilelang_transform_pipeline_planning.py
  +6   -7    testing/python/transform/test_tilelang_transform_simplify.py
  +13  -17   testing/python/transform/test_tilelang_transform_thread_sync.py
testing/python/profiler/test_tilelang_profiler.py

@@ -4,12 +4,11 @@ import tilelang.language as T
 @tilelang.jit(out_idx=[-1])
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     @T.prim_func
     def gemm(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((K, N), dtype),
-            C: T.Tensor((M, N), dtype),
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((K, N), dtype),
+        C: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_K), dtype)
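The hunk above is the commit's most common shape: a parameter list that already ended in a trailing comma is re-indented one level shallower. A standalone sketch (mine, not from the repo) of the "magic trailing comma" rule behind it:

# With a trailing comma after the last parameter, Black-style formatters keep
# one parameter per line even though the signature would fit on one line;
# deleting the comma lets them collapse it to "def gemm(a, b, c):".
def gemm(
    a: int,
    b: int,
    c: int,
):
    return a + b + c


print(gemm(1, 2, 3))  # 6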
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py

@@ -27,9 +27,9 @@ def matmul(
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)

@@ -89,7 +89,8 @@ def run_gemm_ss(
         pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

@@ -159,9 +160,9 @@ def matmul_rs(
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)

@@ -169,9 +170,11 @@ def matmul_rs(
             A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             T.clear(C_local)
-            T.annotate_layout({
-                A_shared: tilelang.layout.make_swizzled_layout(A_shared),
-            })
+            T.annotate_layout(
+                {
+                    A_shared: tilelang.layout.make_swizzled_layout(A_shared),
+                }
+            )
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 if trans_A:
                     T.copy(A[k * block_K, by * block_M], A_shared)

@@ -225,7 +228,8 @@ def run_gemm_rs(
         pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

     def ref_program(A, B):

@@ -294,9 +298,9 @@ def matmul_sr(
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)

@@ -304,9 +308,11 @@ def matmul_sr(
             B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             T.clear(C_local)
-            T.annotate_layout({
-                B_shared: tilelang.layout.make_swizzled_layout(B_shared),
-            })
+            T.annotate_layout(
+                {
+                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+                }
+            )
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 if trans_A:
                     T.copy(A[k * block_K, by * block_M], A_shared)

@@ -360,7 +366,8 @@ def run_gemm_sr(
         pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

     def ref_program(A, B):

@@ -430,9 +437,9 @@ def matmul_rr(
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)

@@ -441,10 +448,12 @@ def matmul_rr(
             B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             T.clear(C_local)
-            T.annotate_layout({
-                A_shared: tilelang.layout.make_swizzled_layout(A_shared),
-                B_shared: tilelang.layout.make_swizzled_layout(B_shared),
-            })
+            T.annotate_layout(
+                {
+                    A_shared: tilelang.layout.make_swizzled_layout(A_shared),
+                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+                }
+            )
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 if trans_A:
                     T.copy(A[k * block_K, by * block_M], A_shared)

@@ -499,7 +508,8 @@ def run_gemm_rr(
         pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

     def ref_program(A, B):
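Every run_gemm_* helper above ends the same way: the `})` that used to close both the pass_configs dict and the enclosing call is split, because ruff format dedents a multi-line call's closing parenthesis onto its own line and gives the last argument a trailing comma. A sketch with stand-in names (compile_fn and the string keys are hypothetical, chosen only to mirror the call shape):

def compile_fn(program, out_idx, pass_configs):
    # Stand-in for the real compile entry point; just echoes its inputs.
    return (program, out_idx, pass_configs)


kernel = compile_fn(
    "program",
    out_idx=[2],
    pass_configs={
        "TL_DISABLE_TMA_LOWER": True,
        "TL_DISABLE_WARP_SPECIALIZED": True,
    },
)
print(kernel)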
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py

@@ -20,27 +20,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
             low, high = (0, 4) if is_unsigned else (-2, 2)
         else:
             low, high = (0, 128) if is_unsigned else (-64, 64)
-        A = randint_semi_sparse(
-            M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device='cuda',
-            transposed=trans_A)
-        B = torch.randint(
-            size=(N, K) if trans_B else (K, N), low=low, high=high,
-            dtype=map_torch_type(in_dtype), device='cuda')
+        A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A)
+        B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda")
     else:
-        A = randn_semi_sparse(
-            M, K, dtype=torch.float32, device='cuda', transposed=trans_A).to(map_torch_type(in_dtype))
-        B = torch.randn((N, K) if trans_B else (K, N), device='cuda',
-                        dtype=torch.float32).to(map_torch_type(in_dtype))
+        A = randn_semi_sparse(M, K, dtype=torch.float32, device="cuda", transposed=trans_A).to(map_torch_type(in_dtype))
+        B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype))
     return A, B

@@ -69,24 +53,22 @@ def matmul_sp_sm90(
     @T.prim_func
     def main(
-            A_sparse: T.Tensor(A_sparse_shape, in_dtype),
-            E: T.Tensor((M, K // E_factor), 'uint8'),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
+        E: T.Tensor((M, K // E_factor), "uint8"),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype)
-            E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8')
+            E_shared = T.alloc_shared((block_M, block_K // E_factor), "uint8")
             C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.annotate_layout({
-                E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
-                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
-            })
+            T.annotate_layout(
+                {
+                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
+                }
+            )
             T.disable_warp_group_reg_alloc()
             T.clear(C_frag)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):

@@ -121,7 +103,7 @@ def matmul_sp_sm80(
     trans_B,
 ):
     is_8_bit = "8" in in_dtype
-    metadata_dtype = 'int32' if is_8_bit else 'int16'
+    metadata_dtype = "int32" if is_8_bit else "int16"
     E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype]
     A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
     B_shape = (K, N) if not trans_B else (N, K)

@@ -132,20 +114,22 @@ def matmul_sp_sm80(
     @T.prim_func
     def main(
-            A_sparse: T.Tensor(A_sparse_shape, in_dtype),
-            E: T.Tensor((M, K // E_factor), metadata_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
+        E: T.Tensor((M, K // E_factor), metadata_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype)
             E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
             C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.annotate_layout({
-                E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
-                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
-            })
+            T.annotate_layout(
+                {
+                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
+                }
+            )
             T.clear(C_frag)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 T.copy(E[by * block_M, k * block_K // E_factor], E_shared)

@@ -216,7 +200,7 @@ def run_gemm_sp(
     C = _matmul(A, B)
-    if 'float8' in in_dtype:
+    if "float8" in in_dtype:
         diff = calc_diff(C_sp, C)
         assert diff < 1e-3, f"{diff=}"
     else:

@@ -332,15 +316,11 @@ def test_gemm_sp_sm90():
     run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128)
     run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128)
-    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False,
-                     True)
-    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True,
-                     False)
-    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True,
-                     True)
-    run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128,
-                     False, True)
+    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True)
+    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, False)
+    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, True)
+    run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, True)
     run_gemm_sp_sm90(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True)

@@ -352,12 +332,9 @@ def test_gemm_sp_sm80():
     run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32)
     run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128)
-    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False,
-                     True)
-    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False,
-                     True)
-    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False,
-                     True)
+    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, True)
+    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, True)
+    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True)
     run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128)
     run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128)
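Most single-character churn in this file is quote normalization: ruff format rewrites 'cuda' as "cuda", which is behavior-preserving because the quote style is not part of the string's value. The accuracy check also uses a self-documenting f-string. A tiny sketch (mine, not from the repo):

device = "cuda"
assert device == 'cuda'  # same string value regardless of quote style

diff = 0.00042
print(f"{diff=}")  # f"{name=}" prints expression and value: diff=0.00042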
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py

@@ -34,20 +34,22 @@ def matmul(
     @T.prim_func
     def main(
-            A_sparse: T.Tensor(A_sparse_shape, in_dtype),
-            E: T.Tensor((M, K // E_factor), metadata_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
+        E: T.Tensor((M, K // E_factor), metadata_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype)
             E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
             C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.annotate_layout({
-                E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
-                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
-            })
+            T.annotate_layout(
+                {
+                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
+                }
+            )
             T.clear(C_frag)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 T.copy(E[by * block_M, k * block_K // E_factor], E_shared)

@@ -80,7 +82,7 @@ def run_gemm_ss(
     num_stages=3,
     num_threads=128,
 ):
-    metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
+    metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
     program = matmul(
         M,
         N,

@@ -105,7 +107,8 @@ def run_gemm_ss(
         pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )

     A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
     A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")

@@ -142,26 +145,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
             low, high = (0, 4) if is_unsigned else (-2, 2)
         else:
             low, high = (0, 128) if is_unsigned else (-64, 64)
-        A = randint_semi_sparse(
-            M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device='cuda',
-            transposed=trans_A)
-        B = torch.randint(
-            size=(N, K) if trans_B else (K, N), low=low, high=high,
-            dtype=map_torch_type(in_dtype), device='cuda')
+        A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A)
+        B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda")
     else:
-        A = randn_semi_sparse(M, K, dtype=map_torch_type(in_dtype), device='cuda',
-                              transposed=trans_A)
-        B = torch.randn((N, K) if trans_B else (K, N), device='cuda',
-                        dtype=torch.float32).to(map_torch_type(in_dtype))
+        A = randn_semi_sparse(M, K, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A)
+        B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype))
     return A, B

@@ -184,8 +172,7 @@ def test_gemm_ss():
     run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
     # float8 tests
-    run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128,
-                64, 2)
+    run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
     run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
     # tfloat32 test

@@ -222,10 +209,10 @@ def matmul_rs(
     @T.prim_func
     def main(
-            A_sparse: T.Tensor(A_sparse_shape, in_dtype),
-            E: T.Tensor((M, K // E_factor), metadata_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
+        E: T.Tensor((M, K // E_factor), metadata_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)

@@ -233,11 +220,13 @@ def matmul_rs(
             E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
             A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
             C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.annotate_layout({
-                A_shared: tilelang.layout.make_swizzled_layout(A_shared),
-                E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
-                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
-            })
+            T.annotate_layout(
+                {
+                    A_shared: tilelang.layout.make_swizzled_layout(A_shared),
+                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
+                }
+            )
             T.clear(C_frag)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 T.copy(E[by * block_M, k * block_K // E_factor], E_shared)

@@ -271,7 +260,7 @@ def run_gemm_rs(
     num_stages=3,
     num_threads=128,
 ):
-    metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
+    metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
     program = matmul_rs(
         M,
         N,

@@ -296,7 +285,8 @@ def run_gemm_rs(
         pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )

     A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
     A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
     C_sp = kernel(A_sparse, E, B)

@@ -376,10 +366,10 @@ def matmul_sr(
     @T.prim_func
     def main(
-            A_sparse: T.Tensor(A_sparse_shape, in_dtype),
-            E: T.Tensor((M, K // E_factor), metadata_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
+        E: T.Tensor((M, K // E_factor), metadata_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)

@@ -387,11 +377,13 @@ def matmul_sr(
             E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
             B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
             C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.annotate_layout({
-                B_shared: tilelang.layout.make_swizzled_layout(B_shared),
-                E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
-                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
-            })
+            T.annotate_layout(
+                {
+                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
+                }
+            )
             T.clear(C_frag)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 T.copy(E[by * block_M, k * block_K // E_factor], E_shared)

@@ -425,7 +417,7 @@ def run_gemm_sr(
     num_stages=3,
     num_threads=128,
 ):
-    metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
+    metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
     program = matmul_sr(
         M,
         N,

@@ -450,7 +442,8 @@ def run_gemm_sr(
         pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )

     A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
     A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
     C_sp = kernel(A_sparse, E, B)

@@ -531,10 +524,10 @@ def matmul_rr(
     @T.prim_func
     def main(
-            A_sparse: T.Tensor(A_sparse_shape, in_dtype),
-            E: T.Tensor((M, K // E_factor), metadata_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
+        E: T.Tensor((M, K // E_factor), metadata_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)

@@ -543,12 +536,14 @@ def matmul_rr(
             A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
             B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
             C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.annotate_layout({
-                A_shared: tilelang.layout.make_swizzled_layout(A_shared),
-                B_shared: tilelang.layout.make_swizzled_layout(B_shared),
-                E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
-                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
-            })
+            T.annotate_layout(
+                {
+                    A_shared: tilelang.layout.make_swizzled_layout(A_shared),
+                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
+                }
+            )
             T.clear(C_frag)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 T.copy(E[by * block_M, k * block_K // E_factor], E_shared)

@@ -583,7 +578,7 @@ def run_gemm_rr(
     num_stages=3,
     num_threads=128,
 ):
-    metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
+    metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
     program = matmul_rr(
         M,
         N,

@@ -608,7 +603,8 @@ def run_gemm_rr(
         pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )

     A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
     A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
     C_sp = kernel(A_sparse, E, B)
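Several of the transform tests below lose exactly one line: a blank line that sat between a def signature and the first statement of its body. Black-compatible formatters delete empty lines in that position, which is the entire diff for those hunks. A sketch (mine, not from the repo):

def test_case():
    # yapf tolerated a blank line right above this comment; ruff format removes it.
    value = 1
    return value


print(test_case())  # 1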
testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py

@@ -11,22 +11,14 @@ def _check(original, transformed):
     mod = tl.transform.Simplify()(mod)
     mod = tl.transform.LowerOpaqueBlock()(mod)
     mod = tl.transform.Simplify()(mod)
-    tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
-                                   True)
+    tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True)


 def test_trival_pipeline():
     @T.prim_func
     def before(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")):
         for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
-            for i in T.serial(0, 1, annotations={
-                    "software_pipeline_stage": [0, 1],
-                    "software_pipeline_order": [0, 1]
-            }):
+            for i in T.serial(0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}):
                 with T.block():
                     T.reads(A[tx, i])
                     T.writes(C[tx, i])
testing/python/transform/test_tilelang_transform_cluster_planning.py

@@ -21,10 +21,8 @@ def _check(original, transformed):
 def test_cluster_planning():
     @T.prim_func
-    def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor(
-        (1024, 1024), "float16")):
+    def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")):
         with T.Kernel(8, 8, threads=128) as (bx, by):
             A_shared = T.alloc_shared((128, 32), "float16")
             B_shared = T.alloc_shared((32, 128), "float16")

@@ -41,8 +39,7 @@ def test_cluster_planning():
             T.copy(C_local, C[by * 128, bx * 128])

     @T.prim_func
-    def after(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor(
-        (1024, 1024), "float16")):
+    def after(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")):
         T.func_attr({"clusterIdx.y": T.int32(2)})
         with T.Kernel(8, 8, threads=128) as (bx, by):
             A_shared = T.alloc_shared((128, 32), "float16")
testing/python/transform/test_tilelang_transform_config_index_bitwidth.py

@@ -9,7 +9,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
     block_N = 64
     num_stages = 0
     threads = 128
-    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     batch = T.int32(batch)
     heads = T.int32(heads)

@@ -24,7 +24,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
     block_mask_dtype = "bool"

     def kernel_func(block_M, block_N, num_stages, threads):
-
         @T.macro
         def MMA0(
             K: T.Tensor(shape, dtype),

@@ -36,37 +35,36 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
             by: T.int32,
             bz: T.int32,
         ):
-            T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
+            T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
             if is_causal:
                 for i, j in T.Parallel(block_M, block_N):
-                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-                                                 -T.infinity(acc_s.dtype))
+                    acc_s[i, j] = T.if_then_else(
+                        bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)
+                    )
             else:
                 T.clear(acc_s)
             T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

         @T.macro
         def MMA1(
-                V: T.Tensor(shape, dtype),
-                V_shared: T.Tensor([block_M, dim], dtype),
-                acc_s_cast: T.Tensor([block_M, block_N], dtype),
-                acc_o: T.Tensor([block_M, dim], accum_dtype),
-                k: T.int32,
-                by: T.int32,
-                bz: T.int32,
+            V: T.Tensor(shape, dtype),
+            V_shared: T.Tensor([block_M, dim], dtype),
+            acc_s_cast: T.Tensor([block_M, block_N], dtype),
+            acc_o: T.Tensor([block_M, dim], accum_dtype),
+            k: T.int32,
+            by: T.int32,
+            bz: T.int32,
         ):
-            T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
+            T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
             T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

         @T.macro
         def Softmax(
-                acc_s: T.Tensor([block_M, block_N], accum_dtype),
-                acc_s_cast: T.Tensor([block_M, block_N], dtype),
-                scores_max: T.Tensor([block_M], accum_dtype),
-                scores_max_prev: T.Tensor([block_M], accum_dtype),
-                scores_scale: T.Tensor([block_M], accum_dtype),
-                scores_sum: T.Tensor([block_M], accum_dtype),
-                logsum: T.Tensor([block_M], accum_dtype),
+            acc_s: T.Tensor([block_M, block_N], accum_dtype),
+            acc_s_cast: T.Tensor([block_M, block_N], dtype),
+            scores_max: T.Tensor([block_M], accum_dtype),
+            scores_max_prev: T.Tensor([block_M], accum_dtype),
+            scores_scale: T.Tensor([block_M], accum_dtype),
+            scores_sum: T.Tensor([block_M], accum_dtype),
+            logsum: T.Tensor([block_M], accum_dtype),
         ):
             T.copy(scores_max, scores_max_prev)
             T.fill(scores_max, -T.infinity(accum_dtype))

@@ -92,22 +90,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
         @T.macro
         def Rescale(
-                acc_o: T.Tensor([block_M, dim], accum_dtype),
-                scores_scale: T.Tensor([block_M], accum_dtype),
+            acc_o: T.Tensor([block_M, dim], accum_dtype),
+            scores_scale: T.Tensor([block_M], accum_dtype),
         ):
             for i, j in T.Parallel(block_M, dim):
                 acc_o[i, j] *= scores_scale[i]

         @T.prim_func
         def main(
-                Q: T.Tensor(shape, dtype),
-                K: T.Tensor(shape, dtype),
-                V: T.Tensor(shape, dtype),
-                BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
-                Output: T.Tensor(shape, dtype),
+            Q: T.Tensor(shape, dtype),
+            K: T.Tensor(shape, dtype),
+            V: T.Tensor(shape, dtype),
+            BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
+            Output: T.Tensor(shape, dtype),
         ):
-            with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch,
-                          threads=threads) as (bx, by, bz):
+            with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_M, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
                 V_shared = T.alloc_shared([block_N, dim], dtype)

@@ -122,7 +119,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                 logsum = T.alloc_fragment([block_M], accum_dtype)
                 block_mask = T.alloc_local([downsample_len], block_mask_dtype)

-                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
+                T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))

@@ -131,19 +128,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                     block_mask[vj] = BlockSparseMask[bz, by, bx, vj]

                 loop_range = (
-                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
-                        (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
+                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
+                    if is_causal
+                    else T.ceildiv(seq_len, block_N)
+                )

                 for k in T.Pipelined(loop_range, num_stages=num_stages):
                     MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
-                    Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
-                            scores_sum, logsum)
+                    Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                     Rescale(acc_o, scores_scale)
                     MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                 for i, j in T.Parallel(block_M, dim):
                     acc_o[i, j] /= logsum[i]
                 T.copy(acc_o, O_shared)
-                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
+                T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])

     return main
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py

@@ -22,7 +22,6 @@ def _check(original, transformed):
 def test_lower_fence_proxy():
-
     @T.prim_func
     def before():
         with T.Kernel(8):

@@ -30,12 +29,15 @@ def test_lower_fence_proxy():
             B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn")
             C_local = T.decl_buffer((32,), scope="local")
             for i in T.unroll(16):
-                C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
-            T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"),
-                          "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
-                          T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
-                          T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
-                          T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
+                C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2)
+            T.call_intrin(
+                "handle",
+                tir.op.Op.get("tl.tl_gemm"),
+                "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
+                T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
+                T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
+                T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
+            )

     @T.prim_func
     def after():

@@ -44,19 +46,21 @@ def test_lower_fence_proxy():
             B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn")
             C_local = T.decl_buffer((32,), scope="local")
             for i in T.unroll(16):
-                C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
+                C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2)
             T.fence_proxy_async()
-            T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"),
-                          "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
-                          T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
-                          T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
-                          T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
+            T.call_intrin(
+                "handle",
+                tir.op.Op.get("tl.tl_gemm"),
+                "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
+                T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
+                T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
+                T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
+            )

     _check(before, after)


 def test_async_to_generic_no_double_fence():
     @T.prim_func
     def before():
         with T.Kernel(8):

@@ -90,7 +94,6 @@ def test_async_to_generic_no_double_fence():
 def test_proxy_hint_override():
-
     @T.prim_func
     def before():
         with T.Kernel(8):

@@ -123,7 +126,6 @@ def test_proxy_hint_override():
 def test_tma_store_sync_injection():
-
     @T.prim_func
     def before():
         with T.Kernel(8):

@@ -154,7 +156,6 @@ def test_tma_store_sync_injection():
 def test_wgmma_marked_async():
-
     @T.prim_func
     def before():
         with T.Kernel(1):

@@ -164,9 +165,24 @@ def test_wgmma_marked_async():
             C_local = T.decl_buffer((32,), "float16", scope="local")
             A_shared[0] = T.float16(0)
             T.warpgroup_arrive()
-            T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16",
-                           "fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data,
-                           T.int32(0), T.bool(True), 1, 1)
+            T.ptx_wgmma_ss(
+                "float16",
+                "m64n64k16",
+                T.bool(True),
+                T.bool(True),
+                "fp16",
+                "fp16",
+                "fp16",
+                desc_a.data,
+                T.int32(0),
+                desc_b.data,
+                T.int32(0),
+                C_local.data,
+                T.int32(0),
+                T.bool(True),
+                1,
+                1,
+            )

     mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
     mod = tvm.tir.transform.BindTarget(auto_target)(mod)
testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py

@@ -35,26 +35,25 @@ def test_inject_set_max_nreg():
                     T.mbarrier_wait_parity(T.get_mbarrier(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1))
                     if v - 128 == 0:
                         T.tma_load(
-                            T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1,
-                                                    0, 2, 2, 0), T.get_mbarrier(k % 3),
-                            T.tvm_access_ptr(
-                                T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
-                            k * 32, by * 64)
+                            T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
+                            T.get_mbarrier(k % 3),
+                            T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
+                            k * 32,
+                            by * 64,
+                        )
                         T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
                 else:
                     # Consumer branch - should have set_max_nreg(240, 1)
                     for k in range(16):
                         T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2)
                         T.call_extern(
-                            "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
-                            T.tvm_access_ptr(
-                                T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
-                            T.tvm_access_ptr(
-                                T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
-                            T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
+                            "handle",
+                            "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
+                            T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
+                            T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
+                            T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
+                        )
                         T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))

     # Apply the InjectSetMaxNReg pass
     func = before

@@ -67,15 +66,18 @@ def test_inject_set_max_nreg():
     set_max_nreg_calls = []

     def collect_set_max_nreg(stmt):
-        if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and
-                hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"):
+        if (
+            isinstance(stmt, tvm.tir.Evaluate)
+            and hasattr(stmt.value, "op")
+            and hasattr(stmt.value.op, "name")
+            and stmt.value.op.name == "tl.set_max_nreg"
+        ):
             set_max_nreg_calls.append(stmt.value)

     tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg)

     # We should have at least 2 set_max_nreg calls (one for producer, one for consumer)
     assert len(set_max_nreg_calls) >= 2, f"Expected at least 2 set_max_nreg calls, got {len(set_max_nreg_calls)}"
     print("InjectSetMaxNReg test passed!")

@@ -116,16 +118,18 @@ def test_inject_set_max_nreg_no_set_max_nreg():
     set_max_nreg_calls = []

     def collect_set_max_nreg(stmt):
-        if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and
-                hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"):
+        if (
+            isinstance(stmt, tvm.tir.Evaluate)
+            and hasattr(stmt.value, "op")
+            and hasattr(stmt.value.op, "name")
+            and stmt.value.op.name == "tl.set_max_nreg"
+        ):
             set_max_nreg_calls.append(stmt.value)

     tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg)

     # Should have no set_max_nreg calls when no_set_max_nreg is present
     assert len(set_max_nreg_calls) == 0, f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got {len(set_max_nreg_calls)}"
     print("InjectSetMaxNReg with no_set_max_nreg test passed!")
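The visitor rewrap above is the Black convention for long boolean conditions: parenthesize the whole test and break it one operand per line, with the `and` leading each continuation. A standalone sketch (mine; Op and Node are stand-ins for the TIR statement types):

class Op:
    name = "tl.set_max_nreg"


class Node:
    value = Op()


stmt = Node()
if (
    isinstance(stmt, Node)
    and hasattr(stmt, "value")
    and hasattr(stmt.value, "name")
    and stmt.value.name == "tl.set_max_nreg"
):
    print("matched")  # -> matched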
testing/python/transform/test_tilelang_transform_layout_inference.py

@@ -8,17 +8,21 @@ import pytest
 auto_target = tvm.target.Target(determine_target("auto"))


-@pytest.mark.parametrize("block_M, block_N, block_K, threads, vec_load_b, dtype", [
-    (64, 64, 32, 128, 8, "float16"),
-])
+@pytest.mark.parametrize(
+    "block_M, block_N, block_K, threads, vec_load_b, dtype",
+    [
+        (64, 64, 32, 128, 8, "float16"),
+    ],
+)
 def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
     N = tvm.te.var("n")
     K = tvm.te.var("k")

     def before():
         @T.prim_func
-        def main(B: T.Tensor((K, N), dtype),):
+        def main(
+            B: T.Tensor((K, N), dtype),
+        ):
             with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
                 B_shared = T.alloc_shared((block_K, block_N), dtype)
                 thread_bindings = T.thread_binding(0, threads, "threadIdx.x")

@@ -26,58 +30,62 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
                 t = thread_bindings
                 for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)):
                     for vec in T.Parallel(vec_load_b):
-                        B_shared[i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
-                                 t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec] = T.if_then_else(
-                                     k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
-                                     and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
-                                     B[k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
-                                       bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec],
-                                     T.float16(0))
+                        B_shared[
+                            i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
+                            t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
+                        ] = T.if_then_else(
+                            k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
+                            and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
+                            B[
+                                k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
+                                bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
+                            ],
+                            T.float16(0),
+                        )

-        return tvm.IRModule({'main': main})
+        return tvm.IRModule({"main": main})

     def after():
         @T.prim_func
-        def main(B: T.Tensor((K, N), dtype),):
+        def main(
+            B: T.Tensor((K, N), dtype),
+        ):
             with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
                 B_shared = T.alloc_shared((block_K, block_N), dtype)
                 thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
                 for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                     t = thread_bindings
                     for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)):
                         if (k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b)) * N % vec_load_b == 0:
                             for vec in T.vectorized(vec_load_b):
                                 B_shared[
                                     i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
                                     t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
                                 ] = T.if_then_else(
                                     k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
                                     and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
                                     B[
                                         k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
                                         bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
                                     ],
                                     T.float16(0),
                                 )
                         else:
                             for vec in T.serial(vec_load_b):
                                 B_shared[
                                     i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
                                     t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
                                 ] = T.if_then_else(
                                     k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
                                     and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
                                     B[
                                         k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
                                         bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
                                     ],
                                     T.float16(0),
                                 )

-        return tvm.IRModule({'main': main})
+        return tvm.IRModule({"main": main})

     with tvm.target.Target(auto_target):
         mod = tvm.tir.transform.BindTarget(auto_target)(before())
testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py

@@ -8,7 +8,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
     dtype = "float32"

     @T.prim_func
-    def main(A: T.Tensor((M, N), dtype=dtype),):
+    def main(
+        A: T.Tensor((M, N), dtype=dtype),
+    ):
         with T.Kernel(1, 1, threads=M) as (bx, by):
             A_shared = T.alloc_shared((M, N), dtype=dtype)
             tid = T.get_thread_binding()

@@ -16,17 +18,18 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
                 A_shared[tid, j] = A[tid + M_offset, j + N_offset]

     @T.prim_func
-    def expected(A: T.Tensor((M, N), dtype=dtype),):
+    def expected(
+        A: T.Tensor((M, N), dtype=dtype),
+    ):
         with T.Kernel(1, 1, threads=M) as (bx, by):
             A_shared = T.alloc_shared((M, N), dtype=dtype)
             tid = T.get_thread_binding()
-            T.reads(A[tid + M_offset, N_offset:N + N_offset])
+            T.reads(A[tid + M_offset, N_offset : N + N_offset])
             for j in T.serial(N):
-                A_shared[tid, j] = T.if_then_else(
-                    j + N_offset < N,
-                    T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset],
-                                   T.float32(0)), T.float32(0))
+                A_shared[tid, j] = T.if_then_else(
+                    j + N_offset < N,
+                    T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], T.float32(0)),
+                    T.float32(0),
+                )

     return main, expected

@@ -41,13 +44,13 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
 def issue_1013_buggy_kernel():
     # NOTE: This kernel is mainly to test some corner cases in boundary check
-    num_tokens = T.dynamic('num_tokens')
+    num_tokens = T.dynamic("num_tokens")
     num_threads = 128

     @T.prim_func
     def main(x: T.Tensor((num_tokens,), dtype="int64")):
         with T.Kernel(1, threads=num_threads) as _:
-            count = T.alloc_var('int')
+            count = T.alloc_var("int")
             thread_idx = T.get_thread_binding()
             for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
                 idx = thread_idx + i * num_threads

@@ -59,24 +62,22 @@ def issue_1013_buggy_kernel():
     @T.prim_func
     def expected(x: T.Tensor((num_tokens,), dtype="int64")):
         with T.Kernel(1, threads=num_threads) as _:
-            count = T.alloc_var('int')
+            count = T.alloc_var("int")
             thread_idx = T.get_thread_binding()
             for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
                 idx = thread_idx + i * num_threads
-                count += T.Cast("int32",
-                                T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2))
+                count += T.Cast("int32", T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2))

     return main, expected


-def vectorize_access_with_atmoic_add_legalize(M: int = 64,
-                                              N: int = 64,
-                                              M_offset: int = 2,
-                                              N_offset: int = 2):
+def vectorize_access_with_atmoic_add_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2):
     dtype = "float32"

     @T.prim_func
-    def main(A: T.Tensor((M, N), dtype=dtype),):
+    def main(
+        A: T.Tensor((M, N), dtype=dtype),
+    ):
         with T.Kernel(1, 1, threads=M) as (bx, by):
             A_shared = T.alloc_shared((M, N), dtype=dtype)
             tid = T.get_thread_binding()

@@ -85,17 +86,18 @@ def vectorize_access_with_atmoic_add_legalize(M: int = 64,
                 T.atomic_add(A[tid + M_offset, j + N_offset], 1)

     @T.prim_func
-    def expected(A: T.Tensor((M, N), dtype=dtype),):
+    def expected(
+        A: T.Tensor((M, N), dtype=dtype),
+    ):
         with T.Kernel(1, 1, threads=M) as (bx, by):
             A_shared = T.alloc_shared((M, N), dtype=dtype)
             tid = T.get_thread_binding()
-            T.reads(A[tid + M_offset, N_offset:N + N_offset])
+            T.reads(A[tid + M_offset, N_offset : N + N_offset])
             for j in T.serial(N):
-                A_shared[tid, j] = T.if_then_else(
-                    j + N_offset < N,
-                    T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset],
-                                   T.float32(0)), T.float32(0))
+                A_shared[tid, j] = T.if_then_else(
+                    j + N_offset < N,
+                    T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], T.float32(0)),
+                    T.float32(0),
+                )
                 # Nest if-then-else is expected, do not flatten it to pass structural equal check
                 if j + N_offset < N:  # noqa: SIM102
                     if tid + M_offset < M:

@@ -115,17 +117,21 @@ def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: in
     dtype = "float32"

     @T.prim_func
-    def main(A: T.Tensor((M, N), dtype=dtype),):
+    def main(
+        A: T.Tensor((M, N), dtype=dtype),
+    ):
         with T.Kernel(1, 1, threads=M) as (bx, by):
             tid = T.get_thread_binding()
             for j in T.serial(N):
                 A[tid + M_offset, j + N_offset] = 1

     @T.prim_func
-    def expected(A: T.Tensor((M, N), dtype=dtype),):
+    def expected(
+        A: T.Tensor((M, N), dtype=dtype),
+    ):
         with T.Kernel(1, 1, threads=M) as (bx, by):
             tid = T.get_thread_binding()
-            T.writes(A[tid + M_offset, N_offset:N + N_offset])
+            T.writes(A[tid + M_offset, N_offset : N + N_offset])
             for j in T.serial(N):
                 if j + N_offset < N:  # noqa: SIM102
                     if tid + M_offset < M:
testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py

@@ -9,7 +9,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
     vec_len = 8

     @T.prim_func
-    def main(A: T.Tensor((M, N, vec_len), dtype="float32"),):
+    def main(
+        A: T.Tensor((M, N, vec_len), dtype="float32"),
+    ):
         with T.Kernel(1, 1, threads=M) as (bx, by):
             A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
             tid = T.get_thread_binding()

@@ -18,7 +20,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
                 A_shared[tid, j, v] = A[tid, j, v]

     @T.prim_func
-    def expected(A: T.Tensor((M, N, vec_len), dtype="float32"),):
+    def expected(
+        A: T.Tensor((M, N, vec_len), dtype="float32"),
+    ):
         with T.Kernel(1, 1, threads=M) as (bx, by):
             A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
             tid = T.get_thread_binding()
testing/python/transform/test_tilelang_transform_let_inline.py

@@ -8,12 +8,10 @@ def _check(original, transformed):
     func = original
     mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
     mod = tl.transform.LetInline()(mod)
-    tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
-                                   True)
+    tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True)


 def test_let_binding():
-
     @T.prim_func
     def before(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float32")):
         for i in range(128):

@@ -34,7 +32,6 @@ def test_let_binding():
 def test_parallel_scope():
-
     @T.prim_func
     def before(A: T.Tensor((128,), "float32")):
         for i in T.Parallel(128):
testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py
View file @
29051439
...
...
@@ -24,7 +24,6 @@ def _check(original, transformed):

 def test_lower_hopper_intrin_barrier():
     @T.prim_func
     def before():
         with T.Kernel(8):
...
@@ -37,18 +36,10 @@ def test_lower_hopper_intrin_barrier():
         v_1 = T.launch_thread("threadIdx.x", 128)
         T.evaluate(tir.Call("handle", "tir.create_barriers", [4]))
         with T.If(v_1 == 0), T.Then():
-            T.evaluate(
-                tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(0), 128]))
-            T.evaluate(
-                tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(1), 128]))
-            T.evaluate(
-                tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(2), 128]))
-            T.evaluate(
-                tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(3), 128]))
+            T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(0), 128]))
+            T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(1), 128]))
+            T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(2), 128]))
+            T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(3), 128]))
         T.evaluate(tir.Call("handle", "tir.tvm_storage_sync", ["shared"]))

     _check(before, after)
...
...
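Note: the barrier hunk shows the opposite effect of the trailing-comma expansion: calls that yapf wrapped only to satisfy the column limit are joined back onto one line when they fit under ruff's configured line length (the exact limit comes from the repo's ruff settings, which are not part of this diff). A reduced sketch of that collapse:

    # yapf: wrapped after the open paren to fit the old column limit
    T.evaluate(
        tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(0), 128]))

    # ruff format: no trailing comma and the call fits, so it stays on one line
    T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(0), 128]))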
testing/python/transform/test_tilelang_transform_lower_tile_op.py
View file @ 29051439
...
...
@@ -8,63 +8,69 @@ import pytest
 auto_target = tvm.target.Target(determine_target("auto"))


-@pytest.mark.parametrize("block_M, block_N, block_K, threads, vec_load_b, dtype", [
-    (64, 64, 32, 128, 8, "float16"),
-])
+@pytest.mark.parametrize(
+    "block_M, block_N, block_K, threads, vec_load_b, dtype",
+    [
+        (64, 64, 32, 128, 8, "float16"),
+    ],
+)
 def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
     N = tvm.te.var("n")
     K = tvm.te.var("k")

     def before():
         @T.prim_func
-        def main(B: T.Tensor((K, N), dtype),):
+        def main(
+            B: T.Tensor((K, N), dtype),
+        ):
             with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
                 B_shared = T.alloc_shared((block_K, block_N), dtype)
                 for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                     T.copy(B[k * block_K, bx * block_N], B_shared)

-        return tvm.IRModule({'main': main})
+        return tvm.IRModule({"main": main})

     def after():
         @T.prim_func
-        def main(B: T.Tensor((K, N), dtype),):
+        def main(
+            B: T.Tensor((K, N), dtype),
+        ):
             with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
                 B_shared = T.alloc_shared((block_K, block_N), dtype)
                 thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
                 for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                     t = thread_bindings
                     for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)):
-                        if (k * block_K + i * (threads * vec_load_b // block_N) + t //
-                                (block_N // vec_load_b)) * N % vec_load_b == 0:
+                        if (
+                            k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b)
+                        ) * N % vec_load_b == 0:
                             for vec in T.vectorized(vec_load_b):
-                                B_shared[i * (threads * vec_load_b // block_N) + t //
-                                         (block_N // vec_load_b),
-                                         t % (block_N // vec_load_b) * (block_N // vec_load_b) +
-                                         vec] = T.if_then_else(
-                                             k * block_K + i * (threads * vec_load_b // block_N) +
-                                             t // (block_N // vec_load_b) < K and bx * block_N +
-                                             t % (block_N // vec_load_b) *
-                                             (block_N // vec_load_b) < N,
-                                             B[k * block_K + i * (threads * vec_load_b // block_N)
-                                               + t // (block_N // vec_load_b), bx * block_N +
-                                               t % (block_N // vec_load_b) *
-                                               (block_N // vec_load_b) + vec], T.float16(0))
+                                B_shared[
+                                    i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
+                                    t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
+                                ] = T.if_then_else(
+                                    k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
+                                    and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
+                                    B[
+                                        k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
+                                        bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
+                                    ],
+                                    T.float16(0),
+                                )
                         else:
                             for vec in T.serial(vec_load_b):
-                                B_shared[i * (threads * vec_load_b // block_N) + t //
-                                         (block_N // vec_load_b),
-                                         t % (block_N // vec_load_b) * (block_N // vec_load_b) +
-                                         vec] = T.if_then_else(
-                                             k * block_K + i * (threads * vec_load_b // block_N) +
-                                             t // (block_N // vec_load_b) < K and bx * block_N +
-                                             t % (block_N // vec_load_b) *
-                                             (block_N // vec_load_b) < N,
-                                             B[k * block_K + i * (threads * vec_load_b // block_N)
-                                               + t // (block_N // vec_load_b), bx * block_N +
-                                               t % (block_N // vec_load_b) *
-                                               (block_N // vec_load_b) + vec], T.float16(0))
+                                B_shared[
+                                    i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
+                                    t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
+                                ] = T.if_then_else(
+                                    k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
+                                    and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
+                                    B[
+                                        k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
+                                        bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
+                                    ],
+                                    T.float16(0),
+                                )

-        return tvm.IRModule({'main': main})
+        return tvm.IRModule({"main": main})

     with tvm.transform.PassContext():
         mod = tvm.tir.transform.BindTarget(auto_target)(before())
...
...
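Note: alongside the re-wrapping, ruff format normalizes string literals to double quotes by default, which is why the module-dict keys in this file flip from single to double quotes. The change is purely lexical:

    # before: single-quoted key (yapf does not touch quote style)
    return tvm.IRModule({'main': main})

    # after: ruff format rewrites the literal to double quotes
    return tvm.IRModule({"main": main})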
testing/python/transform/test_tilelang_transform_make_packed_api.py
View file @ 29051439
...
...
@@ -80,7 +80,6 @@ def test_target_host_removed():

     @I.ir_module
     class before:
         @T.prim_func
         def main(A: T.Buffer(1, "float32")):
             T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)})
...
@@ -102,7 +101,6 @@ def test_internal_subroutine_call():

     @I.ir_module
     class before:
         @T.prim_func
         def main(A: T.Buffer(1, "float32")):
             T.func_attr({"target": T.target("llvm", host="llvm")})
...
@@ -121,7 +119,8 @@ def test_internal_subroutine_call():
     subroutine_call_op = compute_scope.body.value.op
     assert isinstance(subroutine_call_op, tvm.ir.GlobalVar), (
         f"The main function's CallNode should use the subroutine's GLobalVar as the operation, "
-        f"but instead has an operation of type {subroutine_call_op}")
+        f"but instead has an operation of type {subroutine_call_op}"
+    )


 def test_subroutine_call_to_externally_visible_subroutine():
...
@@ -135,7 +134,6 @@ def test_subroutine_call_to_externally_visible_subroutine():

     @I.ir_module
     class before:
         @T.prim_func
         def main(A: T.Buffer(1, "float32")):
             T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
...
@@ -154,11 +152,10 @@ def test_subroutine_call_to_externally_visible_subroutine():
     assert subroutine_compute_scope is not None

     subroutine_call_op = main_compute_scope.body.value.op
-    assert (isinstance(subroutine_call_op, tvm.ir.Op) and
-            subroutine_call_op.name == "tir.tvm_call_cpacked"), (
-                f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', "
-                f"but instead has an operation of type {subroutine_call_op}")
+    assert isinstance(subroutine_call_op, tvm.ir.Op) and subroutine_call_op.name == "tir.tvm_call_cpacked", (
+        f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', "
+        f"but instead has an operation of type {subroutine_call_op}"
+    )


 @tilelang.testing.requires_llvm
...
@@ -167,10 +164,10 @@ def test_function_call_with_wrong_argument_count():
     @T.prim_func
     def func(
-            A: T.Buffer([16, 16], "int32"),
-            B: T.Buffer([16, 16], "int32"),
-            C: T.Buffer([16, 16], "int32"),
-            D: T.Buffer([16, 16], "int32"),
+        A: T.Buffer([16, 16], "int32"),
+        B: T.Buffer([16, 16], "int32"),
+        C: T.Buffer([16, 16], "int32"),
+        D: T.Buffer([16, 16], "int32"),
     ):
         pass
...
...
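Note: two assert rewrites recur in this file: a condition that yapf wrapped inside parentheses is flattened once it fits on one line, and a parenthesized f-string message keeps its closing paren on its own line. A condensed sketch (shortened message and variable name for illustration, not the literal repo code):

    # yapf: condition parenthesized and wrapped, message paren closed inline
    assert (isinstance(op, tvm.ir.Op) and
            op.name == "tir.tvm_call_cpacked"), (
                f"unexpected operation {op}")

    # ruff format: flat condition, message indented one level, paren dedented
    assert isinstance(op, tvm.ir.Op) and op.name == "tir.tvm_call_cpacked", (
        f"unexpected operation {op}"
    )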
testing/python/transform/test_tilelang_transform_multi_version_buffer.py
View file @ 29051439
...
...
@@ -31,7 +31,6 @@ block_K = 32

 def test_multi_version_buffer():
     @T.prim_func
     def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
         bx = T.launch_thread("blockIdx.x", 8)
...
@@ -49,21 +48,27 @@ def test_multi_version_buffer():
             for k in T.serial(16, annotations={"num_stages": T.int32(3)}):
                 if v == 0:
                     T.tma_load(
-                        T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0,
-                                                2, 2, 0), 0,
+                        T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
+                        0,
                         T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 2),
-                        k * 32, by * 64)
+                        k * 32,
+                        by * 64,
+                    )
                 if v == 0:
                     T.tma_load(
-                        T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0,
-                                                3, 2, 0), 0,
+                        T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
+                        0,
                         T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 2),
-                        bx * 64, k * 32)
+                        bx * 64,
+                        k * 32,
+                    )
                 T.call_extern(
-                    "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
+                    "handle",
+                    "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
                     T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
                     T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
-                    T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
+                    T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
+                )

     @T.prim_func
     def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
...
@@ -82,31 +87,32 @@ def test_multi_version_buffer():
             for k in T.serial(16, annotations={"num_stages": T.int32(3)}):
                 if v == 0:
                     T.tma_load(
-                        T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0,
-                                                2, 2, 0), 0,
-                        T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data,
-                                         k % 3 * 2048, 2048, 2), k * 32, by * 64)
+                        T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
+                        0,
+                        T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
+                        k * 32,
+                        by * 64,
+                    )
                 if v == 0:
                     T.tma_load(
-                        T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0,
-                                                3, 2, 0), 0,
-                        T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data,
-                                         k % 3 * 2048, 2048, 2), bx * 64, k * 32)
+                        T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
+                        0,
+                        T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
+                        bx * 64,
+                        k * 32,
+                    )
                 T.call_extern(
-                    "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
-                    T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048,
-                                     2048, 1),
-                    T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048,
-                                     2048, 1),
-                    T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
+                    "handle",
+                    "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
+                    T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
+                    T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
+                    T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
+                )

     _check(before, after)


 def test_multi_version_buffer_with_let():
     @T.prim_func
     def before(scales: T.Tensor((4,), "float32")):
         with T.block("root"):
...
...
testing/python/transform/test_tilelang_transform_pipeline_planning.py
View file @ 29051439
...
...
@@ -19,10 +19,8 @@ def _check(original, transformed):

 def test_simple_pipeline():
     @T.prim_func
-    def before(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor(
-            (1024, 1024), "float32")):
+    def before(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")):
         with T.Kernel(8, 8, threads=128) as (bx, by):
             A_shared = T.alloc_shared((128, 32), "float32")
             B_shared = T.alloc_shared((32, 128), "float32")
...
@@ -39,8 +37,7 @@ def test_simple_pipeline():
             T.copy(C_local, C[by * 128, bx * 128])

     @T.prim_func
-    def after(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor(
-            (1024, 1024), "float32")):
+    def after(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")):
         with T.Kernel(8, 8, threads=128) as (bx, by):
             A_shared = T.alloc_shared((128, 32), "float32")
             B_shared = T.alloc_shared((32, 128), "float32")
...
@@ -49,14 +46,13 @@ def test_simple_pipeline():
             T.clear(C_local)
             for ko in T.serial(
-                    32,
-                    annotations={
-                        "software_pipeline_async_stages": [T.int32(0)],
-                        "software_pipeline_order": [T.int32(0), T.int32(1), T.int32(2)],
-                        "software_pipeline_stage": [T.int32(3), T.int32(3), T.int32(3)]
-                    }):
+                32,
+                annotations={
+                    "software_pipeline_async_stages": [T.int32(0)],
+                    "software_pipeline_order": [T.int32(0), T.int32(1), T.int32(2)],
+                    "software_pipeline_stage": [T.int32(3), T.int32(3), T.int32(3)],
+                },
+            ):
                 T.copy(A[by * 128, ko * 32], A_shared)
                 T.copy(B[ko * 32, bx * 128], B_shared)
                 T.gemm(A_shared, B_shared, C_local)
...
...
testing/python/transform/test_tilelang_transform_simplify.py
View file @ 29051439
...
...
@@ -8,14 +8,13 @@ def modify(
     with_B: bool = False,
     with_bias: bool = False,
 ):
     @T.prim_func
     def main(
-            A: T.Tensor((64, 64)),
-            B: T.Tensor((64, 64)),
-            C: T.Tensor((64, 64)),
-            D: T.Tensor((64, 64)),
-            bias: T.Tensor((64, 64)),
+        A: T.Tensor((64, 64)),
+        B: T.Tensor((64, 64)),
+        C: T.Tensor((64, 64)),
+        D: T.Tensor((64, 64)),
+        bias: T.Tensor((64, 64)),
     ):
         if with_B:
             if with_bias:
...
@@ -42,7 +41,6 @@ def test_modify(with_B=False, with_bias=False):

 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     @T.prim_func
     def main(
             a: T.handle,
...
@@ -76,6 +74,7 @@ def test_matmul():
     kernel = tl.compile(mod["main"], out_idx=[2])

     import torch
+
     a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half()
     b = torch.randn(1024, 1024, dtype=torch.float16).cuda().half()
     c = kernel(a, b)
...
...
testing/python/transform/test_tilelang_transform_thread_sync.py
View file @ 29051439
...
...
@@ -11,11 +11,7 @@ def run_passes(func: tvm.tir.PrimFunc):
     cuda_target = tvm.target.Target("cuda", host="llvm")

-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr({
-        "global_symbol": "test",
-        "target": cuda_target
-    }))(
-        mod)
+    mod = tvm.tir.transform.Apply(lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}))(mod)

     mod = tvm.tir.transform.AnnotateDeviceRegions()(mod)
     mod = tvm.tir.transform.SplitHostDevice()(mod)
...
@@ -24,7 +20,6 @@ def run_passes(func: tvm.tir.PrimFunc):

 @tilelang.testing.requires_cuda
 def test_sync_if_with_same_index():
     @T.prim_func(check_well_formed=False)
     def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None:
         threadIdx_x = T.env_thread("threadIdx.x")
...
@@ -47,7 +42,6 @@ def test_sync_if_with_same_index():

 @tilelang.testing.requires_cuda
 def test_sync_read_thread_id_independent_location():
     @T.prim_func
     def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None:
         threadIdx_x = T.env_thread("threadIdx.x")
...
@@ -71,7 +65,6 @@ def test_sync_read_thread_id_independent_location():

 @tilelang.testing.requires_cuda
 def test_sync_shared():
     @T.prim_func(private=True)
     def func(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")):
         blockIdx_x = T.launch_thread("blockIdx.x", 1)
...
@@ -113,7 +106,6 @@ def test_sync_shared():

 @tvm.testing.requires_cuda
 def test_sync_let_stmt():
     @T.prim_func(private=True)
     def func(A: T.Buffer((16 * 512), "float32")):
         blockIdx_x = T.launch_thread("blockIdx.x", 16)
...
@@ -136,9 +128,9 @@ def test_sync_let_stmt():
             in_thread_A_temp_1[0] = A_temp
         cross_thread_A_temp_1 = T.Buffer((1,), data=cross_thread_A_temp, scope="local")
         with T.attr(
-                T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope",
-                T.reinterpret("handle", T.uint64(0)),
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
         ):
             T.tvm_thread_allreduce(
                 T.uint32(1),
...
@@ -190,16 +182,19 @@ def test_sync_let_stmt():

 @tilelang.testing.requires_cuda
 def test_sync_shared_dyn_stmatrix_loop_hoist():
     @T.prim_func
     def func():
         buf_dyn_shmem = T.alloc_buffer((98304,), "uint8", scope="shared.dyn")
         tx = T.launch_thread("threadIdx.x", 384)
         for i in T.unroll(8):
-            off = (i // 4 * 8192 + tx // 32 * 1024 + tx % 16 * 64 +
-                   (tx % 8 // 4 + i % 4 // 2) % 2 * 32 + (tx % 4 // 2 + i % 2) % 2 * 16 +
-                   (tx % 32 // 16 + tx % 2) % 2 * 8)
+            off = (
+                i // 4 * 8192
+                + tx // 32 * 1024
+                + tx % 16 * 64
+                + (tx % 8 // 4 + i % 4 // 2) % 2 * 32
+                + (tx % 4 // 2 + i % 2) % 2 * 16
+                + (tx % 32 // 16 + tx % 2) % 2 * 8
+            )
             T.evaluate(
                 T.call_intrin(
                     "handle",
...
@@ -214,7 +209,8 @@ def test_sync_shared_dyn_stmatrix_loop_hoist():
                         2,
                     ),
                     T.int32(2),
-                ))
+                )
+            )

     mod = tvm.IRModule({"main": func})
     mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
...
...
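Note: the `off = ...` hunk above is the clearest example of ruff's expression-splitting style: long arithmetic is wrapped in parentheses and broken before each binary operator, one term per line, instead of yapf's fill-to-the-column wrapping with trailing operators. A reduced sketch with fewer terms than the real expression:

    # yapf: filled lines, operators at line ends
    off = (i // 4 * 8192 + tx // 32 * 1024 +
           tx % 16 * 64)

    # ruff format: parenthesized, one term per line, leading operators
    off = (
        i // 4 * 8192
        + tx // 32 * 1024
        + tx % 16 * 64
    )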