Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
467
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
434 additions
and
441 deletions
+434
-441
testing/python/profiler/test_tilelang_profiler.py
testing/python/profiler/test_tilelang_profiler.py
+3
-4
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
+36
-26
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
...g/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
+34
-57
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
...ython/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
+61
-65
testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py
...sform/test_tilelang_transform_Inject_software_pipeline.py
+2
-10
testing/python/transform/test_tilelang_transform_cluster_planning.py
...hon/transform/test_tilelang_transform_cluster_planning.py
+2
-5
testing/python/transform/test_tilelang_transform_config_index_bitwidth.py
...ransform/test_tilelang_transform_config_index_bitwidth.py
+31
-35
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py
...n/transform/test_tilelang_transform_inject_fence_proxy.py
+36
-20
testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py
.../transform/test_tilelang_transform_inject_set_max_nreg.py
+28
-24
testing/python/transform/test_tilelang_transform_layout_inference.py
...hon/transform/test_tilelang_transform_layout_inference.py
+51
-43
testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py
...rm/test_tilelang_transform_legalize_safe_memory_access.py
+30
-24
testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py
...sform/test_tilelang_transform_legalize_vectorized_loop.py
+6
-2
testing/python/transform/test_tilelang_transform_let_inline.py
...ng/python/transform/test_tilelang_transform_let_inline.py
+1
-4
testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py
.../transform/test_tilelang_transform_lower_hopper_intrin.py
+4
-13
testing/python/transform/test_tilelang_transform_lower_tile_op.py
...python/transform/test_tilelang_transform_lower_tile_op.py
+39
-33
testing/python/transform/test_tilelang_transform_make_packed_api.py
...thon/transform/test_tilelang_transform_make_packed_api.py
+10
-13
testing/python/transform/test_tilelang_transform_multi_version_buffer.py
...transform/test_tilelang_transform_multi_version_buffer.py
+32
-26
testing/python/transform/test_tilelang_transform_pipeline_planning.py
...on/transform/test_tilelang_transform_pipeline_planning.py
+9
-13
testing/python/transform/test_tilelang_transform_simplify.py
testing/python/transform/test_tilelang_transform_simplify.py
+6
-7
testing/python/transform/test_tilelang_transform_thread_sync.py
...g/python/transform/test_tilelang_transform_thread_sync.py
+13
-17
No files found.
testing/python/profiler/test_tilelang_profiler.py
View file @
29051439
...
...
@@ -4,7 +4,6 @@ import tilelang.language as T
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
gemm
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
...
...
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
View file @
29051439
...
...
@@ -89,7 +89,8 @@ def run_gemm_ss(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
...
...
@@ -169,9 +170,11 @@ def matmul_rs(
A_frag
=
T
.
alloc_fragment
(
A_frag_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
clear
(
C_local
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
})
}
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
if
trans_A
:
T
.
copy
(
A
[
k
*
block_K
,
by
*
block_M
],
A_shared
)
...
...
@@ -225,7 +228,8 @@ def run_gemm_rs(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
def
ref_program
(
A
,
B
):
...
...
@@ -304,9 +308,11 @@ def matmul_sr(
B_frag
=
T
.
alloc_fragment
(
B_frag_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
clear
(
C_local
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
B_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
B_shared
),
})
}
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
if
trans_A
:
T
.
copy
(
A
[
k
*
block_K
,
by
*
block_M
],
A_shared
)
...
...
@@ -360,7 +366,8 @@ def run_gemm_sr(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
def
ref_program
(
A
,
B
):
...
...
@@ -441,10 +448,12 @@ def matmul_rr(
B_frag
=
T
.
alloc_fragment
(
B_frag_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
clear
(
C_local
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
B_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
B_shared
),
})
}
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
if
trans_A
:
T
.
copy
(
A
[
k
*
block_K
,
by
*
block_M
],
A_shared
)
...
...
@@ -499,7 +508,8 @@ def run_gemm_rr(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
def
ref_program
(
A
,
B
):
...
...
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
View file @
29051439
...
...
@@ -20,27 +20,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
low
,
high
=
(
0
,
4
)
if
is_unsigned
else
(
-
2
,
2
)
else
:
low
,
high
=
(
0
,
128
)
if
is_unsigned
else
(
-
64
,
64
)
A
=
randint_semi_sparse
(
M
,
K
,
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
'cuda'
,
transposed
=
trans_A
)
B
=
torch
.
randint
(
size
=
(
N
,
K
)
if
trans_B
else
(
K
,
N
),
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
'cuda'
)
A
=
randint_semi_sparse
(
M
,
K
,
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
"cuda"
,
transposed
=
trans_A
)
B
=
torch
.
randint
(
size
=
(
N
,
K
)
if
trans_B
else
(
K
,
N
),
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
"cuda"
)
else
:
A
=
randn_semi_sparse
(
M
,
K
,
dtype
=
torch
.
float32
,
device
=
'cuda'
,
transposed
=
trans_A
).
to
(
map_torch_type
(
in_dtype
))
B
=
torch
.
randn
(
(
N
,
K
)
if
trans_B
else
(
K
,
N
),
device
=
'cuda'
,
dtype
=
torch
.
float32
).
to
(
map_torch_type
(
in_dtype
))
A
=
randn_semi_sparse
(
M
,
K
,
dtype
=
torch
.
float32
,
device
=
"cuda"
,
transposed
=
trans_A
).
to
(
map_torch_type
(
in_dtype
))
B
=
torch
.
randn
((
N
,
K
)
if
trans_B
else
(
K
,
N
),
device
=
"cuda"
,
dtype
=
torch
.
float32
).
to
(
map_torch_type
(
in_dtype
))
return
A
,
B
...
...
@@ -70,23 +54,21 @@ def matmul_sp_sm90(
@
T
.
prim_func
def
main
(
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
E_factor
),
'
uint8
'
),
E
:
T
.
Tensor
((
M
,
K
//
E_factor
),
"
uint8
"
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
'
uint8
'
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
"
uint8
"
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
annotate_layout
({
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"9.0"
,
block_k
=
block_K
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"9.0"
,
block_k
=
block_K
),
})
T
.
annotate_layout
(
{
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"9.0"
,
block_k
=
block_K
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"9.0"
,
block_k
=
block_K
),
}
)
T
.
disable_warp_group_reg_alloc
()
T
.
clear
(
C_frag
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
...
...
@@ -121,7 +103,7 @@ def matmul_sp_sm80(
trans_B
,
):
is_8_bit
=
"8"
in
in_dtype
metadata_dtype
=
'
int32
'
if
is_8_bit
else
'
int16
'
metadata_dtype
=
"
int32
"
if
is_8_bit
else
"
int16
"
E_factor
=
SparseTensorCoreIntrinEmitter
.
E_FACTOR_MAP
[
in_dtype
][
metadata_dtype
]
A_sparse_shape
=
(
M
,
K
//
2
)
if
not
trans_A
else
(
K
//
2
,
M
)
B_shape
=
(
K
,
N
)
if
not
trans_B
else
(
N
,
K
)
...
...
@@ -142,10 +124,12 @@ def matmul_sp_sm80(
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
metadata_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
})
}
)
T
.
clear
(
C_frag
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
...
...
@@ -216,7 +200,7 @@ def run_gemm_sp(
C
=
_matmul
(
A
,
B
)
if
'
float8
'
in
in_dtype
:
if
"
float8
"
in
in_dtype
:
diff
=
calc_diff
(
C_sp
,
C
)
assert
diff
<
1e-3
,
f
"
{
diff
=
}
"
else
:
...
...
@@ -332,15 +316,11 @@ def test_gemm_sp_sm90():
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
128
,
256
,
0
,
128
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
128
,
256
,
2
,
128
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
False
,
True
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
True
,
False
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
True
,
True
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
False
,
True
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
True
,
False
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
True
,
True
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float8_e4m3"
,
"float16"
,
"float16"
,
64
,
64
,
64
,
2
,
128
,
False
,
True
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float8_e4m3"
,
"float16"
,
"float16"
,
64
,
64
,
64
,
2
,
128
,
False
,
True
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"int8"
,
"int32"
,
"int32"
,
64
,
64
,
64
,
2
,
128
,
False
,
True
)
...
...
@@ -352,12 +332,9 @@ def test_gemm_sp_sm80():
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
32
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
32
,
32
,
64
,
0
,
32
,
False
,
True
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
32
,
False
,
True
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
False
,
True
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
32
,
32
,
64
,
0
,
32
,
False
,
True
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
32
,
False
,
True
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
False
,
True
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
1
,
128
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
2
,
128
)
...
...
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
View file @
29051439
...
...
@@ -44,10 +44,12 @@ def matmul(
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
metadata_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
})
}
)
T
.
clear
(
C_frag
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
...
...
@@ -80,7 +82,7 @@ def run_gemm_ss(
num_stages
=
3
,
num_threads
=
128
,
):
metadata_dtype
=
'
int32
'
if
(
'8'
in
in_dtype
)
else
'
int16
'
metadata_dtype
=
"
int32
"
if
(
"8"
in
in_dtype
)
else
"
int16
"
program
=
matmul
(
M
,
N
,
...
...
@@ -105,7 +107,8 @@ def run_gemm_ss(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
A
,
B
=
generate_dense_input
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
)
A_sparse
,
E
=
compress
(
A
,
transposed
=
trans_A
,
block_k
=
block_K
,
arch
=
"8.0"
)
...
...
@@ -142,26 +145,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
low
,
high
=
(
0
,
4
)
if
is_unsigned
else
(
-
2
,
2
)
else
:
low
,
high
=
(
0
,
128
)
if
is_unsigned
else
(
-
64
,
64
)
A
=
randint_semi_sparse
(
M
,
K
,
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
'cuda'
,
transposed
=
trans_A
)
B
=
torch
.
randint
(
size
=
(
N
,
K
)
if
trans_B
else
(
K
,
N
),
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
'cuda'
)
A
=
randint_semi_sparse
(
M
,
K
,
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
"cuda"
,
transposed
=
trans_A
)
B
=
torch
.
randint
(
size
=
(
N
,
K
)
if
trans_B
else
(
K
,
N
),
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
"cuda"
)
else
:
A
=
randn_semi_sparse
(
M
,
K
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
'cuda'
,
transposed
=
trans_A
)
B
=
torch
.
randn
(
(
N
,
K
)
if
trans_B
else
(
K
,
N
),
device
=
'cuda'
,
dtype
=
torch
.
float32
).
to
(
map_torch_type
(
in_dtype
))
A
=
randn_semi_sparse
(
M
,
K
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
"cuda"
,
transposed
=
trans_A
)
B
=
torch
.
randn
((
N
,
K
)
if
trans_B
else
(
K
,
N
),
device
=
"cuda"
,
dtype
=
torch
.
float32
).
to
(
map_torch_type
(
in_dtype
))
return
A
,
B
...
...
@@ -184,8 +172,7 @@ def test_gemm_ss():
run_gemm_ss
(
128
,
128
,
128
,
True
,
True
,
"int8"
,
"int8"
,
"int32"
,
128
,
128
,
64
,
2
)
# float8 tests
run_gemm_ss
(
128
,
128
,
128
,
False
,
True
,
"float8_e5m2"
,
"float8_e5m2"
,
"float32"
,
128
,
128
,
64
,
2
)
run_gemm_ss
(
128
,
128
,
128
,
False
,
True
,
"float8_e5m2"
,
"float8_e5m2"
,
"float32"
,
128
,
128
,
64
,
2
)
run_gemm_ss
(
128
,
128
,
128
,
True
,
True
,
"float8_e5m2"
,
"float8_e5m2"
,
"float32"
,
128
,
128
,
64
,
2
)
# tfloat32 test
...
...
@@ -233,11 +220,13 @@ def matmul_rs(
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
metadata_dtype
)
A_frag
=
T
.
alloc_fragment
(
A_frag_shape
,
in_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
})
}
)
T
.
clear
(
C_frag
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
...
...
@@ -271,7 +260,7 @@ def run_gemm_rs(
num_stages
=
3
,
num_threads
=
128
,
):
metadata_dtype
=
'
int32
'
if
(
'8'
in
in_dtype
)
else
'
int16
'
metadata_dtype
=
"
int32
"
if
(
"8"
in
in_dtype
)
else
"
int16
"
program
=
matmul_rs
(
M
,
N
,
...
...
@@ -296,7 +285,8 @@ def run_gemm_rs(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
A
,
B
=
generate_dense_input
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
)
A_sparse
,
E
=
compress
(
A
,
transposed
=
trans_A
,
block_k
=
block_K
,
arch
=
"8.0"
)
C_sp
=
kernel
(
A_sparse
,
E
,
B
)
...
...
@@ -387,11 +377,13 @@ def matmul_sr(
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
metadata_dtype
)
B_frag
=
T
.
alloc_fragment
(
B_frag_shape
,
in_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
B_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
B_shared
),
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
})
}
)
T
.
clear
(
C_frag
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
...
...
@@ -425,7 +417,7 @@ def run_gemm_sr(
num_stages
=
3
,
num_threads
=
128
,
):
metadata_dtype
=
'
int32
'
if
(
'8'
in
in_dtype
)
else
'
int16
'
metadata_dtype
=
"
int32
"
if
(
"8"
in
in_dtype
)
else
"
int16
"
program
=
matmul_sr
(
M
,
N
,
...
...
@@ -450,7 +442,8 @@ def run_gemm_sr(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
A
,
B
=
generate_dense_input
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
)
A_sparse
,
E
=
compress
(
A
,
transposed
=
trans_A
,
block_k
=
block_K
,
arch
=
"8.0"
)
C_sp
=
kernel
(
A_sparse
,
E
,
B
)
...
...
@@ -543,12 +536,14 @@ def matmul_rr(
A_frag
=
T
.
alloc_fragment
(
A_frag_shape
,
in_dtype
)
B_frag
=
T
.
alloc_fragment
(
B_frag_shape
,
in_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
{
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
B_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
B_shared
),
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
})
}
)
T
.
clear
(
C_frag
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
...
...
@@ -583,7 +578,7 @@ def run_gemm_rr(
num_stages
=
3
,
num_threads
=
128
,
):
metadata_dtype
=
'
int32
'
if
(
'8'
in
in_dtype
)
else
'
int16
'
metadata_dtype
=
"
int32
"
if
(
"8"
in
in_dtype
)
else
"
int16
"
program
=
matmul_rr
(
M
,
N
,
...
...
@@ -608,7 +603,8 @@ def run_gemm_rr(
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
A
,
B
=
generate_dense_input
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
)
A_sparse
,
E
=
compress
(
A
,
transposed
=
trans_A
,
block_k
=
block_K
,
arch
=
"8.0"
)
C_sp
=
kernel
(
A_sparse
,
E
,
B
)
...
...
testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py
View file @
29051439
...
...
@@ -11,22 +11,14 @@ def _check(original, transformed):
mod
=
tl
.
transform
.
Simplify
()(
mod
)
mod
=
tl
.
transform
.
LowerOpaqueBlock
()(
mod
)
mod
=
tl
.
transform
.
Simplify
()(
mod
)
tvm
.
ir
.
assert_structural_equal
(
mod
[
"main"
],
transformed
.
with_attr
(
"global_symbol"
,
"main"
),
True
)
tvm
.
ir
.
assert_structural_equal
(
mod
[
"main"
],
transformed
.
with_attr
(
"global_symbol"
,
"main"
),
True
)
def
test_trival_pipeline
():
@
T
.
prim_func
def
before
(
A
:
T
.
Tensor
((
16
,
1
),
"float32"
),
C
:
T
.
Tensor
((
16
,
1
),
"float32"
)):
for
tx
in
T
.
thread_binding
(
0
,
16
,
thread
=
"threadIdx.x"
):
for
i
in
T
.
serial
(
0
,
1
,
annotations
=
{
"software_pipeline_stage"
:
[
0
,
1
],
"software_pipeline_order"
:
[
0
,
1
]
}):
for
i
in
T
.
serial
(
0
,
1
,
annotations
=
{
"software_pipeline_stage"
:
[
0
,
1
],
"software_pipeline_order"
:
[
0
,
1
]}):
with
T
.
block
():
T
.
reads
(
A
[
tx
,
i
])
T
.
writes
(
C
[
tx
,
i
])
...
...
testing/python/transform/test_tilelang_transform_cluster_planning.py
View file @
29051439
...
...
@@ -21,10 +21,8 @@ def _check(original, transformed):
def
test_cluster_planning
():
@
T
.
prim_func
def
before
(
A
:
T
.
Tensor
((
1024
,
32
),
"float16"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float16"
),
C
:
T
.
Tensor
(
(
1024
,
1024
),
"float16"
)):
def
before
(
A
:
T
.
Tensor
((
1024
,
32
),
"float16"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float16"
),
C
:
T
.
Tensor
((
1024
,
1024
),
"float16"
)):
with
T
.
Kernel
(
8
,
8
,
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
128
,
32
),
"float16"
)
B_shared
=
T
.
alloc_shared
((
32
,
128
),
"float16"
)
...
...
@@ -41,8 +39,7 @@ def test_cluster_planning():
T
.
copy
(
C_local
,
C
[
by
*
128
,
bx
*
128
])
@
T
.
prim_func
def
after
(
A
:
T
.
Tensor
((
1024
,
32
),
"float16"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float16"
),
C
:
T
.
Tensor
(
(
1024
,
1024
),
"float16"
)):
def
after
(
A
:
T
.
Tensor
((
1024
,
32
),
"float16"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float16"
),
C
:
T
.
Tensor
((
1024
,
1024
),
"float16"
)):
T
.
func_attr
({
"clusterIdx.y"
:
T
.
int32
(
2
)})
with
T
.
Kernel
(
8
,
8
,
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
128
,
32
),
"float16"
)
...
...
testing/python/transform/test_tilelang_transform_config_index_bitwidth.py
View file @
29051439
...
...
@@ -9,7 +9,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_N
=
64
num_stages
=
0
threads
=
128
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
batch
=
T
.
int32
(
batch
)
heads
=
T
.
int32
(
heads
)
...
...
@@ -24,7 +24,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask_dtype
=
"bool"
def
kernel_func
(
block_M
,
block_N
,
num_stages
,
threads
):
@
T
.
macro
def
MMA0
(
K
:
T
.
Tensor
(
shape
,
dtype
),
...
...
@@ -36,11 +35,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by
:
T
.
int32
,
bz
:
T
.
int32
,
):
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
...
...
@@ -55,7 +53,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by
:
T
.
int32
,
bz
:
T
.
int32
,
):
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
@
T
.
macro
...
...
@@ -106,8 +104,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
BlockSparseMask
:
T
.
Tensor
(
block_mask_shape
,
block_mask_dtype
),
Output
:
T
.
Tensor
(
shape
,
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
...
...
@@ -122,7 +119,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
block_mask
=
T
.
alloc_local
([
downsample_len
],
block_mask_dtype
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -131,19 +128,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask
[
vj
]
=
BlockSparseMask
[
bz
,
by
,
bx
,
vj
]
loop_range
=
(
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
(
(
bx
+
1
)
*
block_M
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
)
)
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
(
(
bx
+
1
)
*
block_M
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
)
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
bx
,
by
,
bz
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
Rescale
(
acc_o
,
scores_scale
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
by
,
bz
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
return
main
...
...
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py
View file @
29051439
...
...
@@ -22,7 +22,6 @@ def _check(original, transformed):
def
test_lower_fence_proxy
():
@
T
.
prim_func
def
before
():
with
T
.
Kernel
(
8
):
...
...
@@ -30,12 +29,15 @@ def test_lower_fence_proxy():
B_shared
=
T
.
decl_buffer
((
1
,
4
,
512
),
"float16"
,
scope
=
"shared.dyn"
)
C_local
=
T
.
decl_buffer
((
32
,),
scope
=
"local"
)
for
i
in
T
.
unroll
(
16
):
C_local
[
i
*
2
:
i
*
2
+
2
]
=
T
.
Broadcast
(
T
.
float32
(
0
),
2
)
T
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tl_gemm"
),
C_local
[
i
*
2
:
i
*
2
+
2
]
=
T
.
Broadcast
(
T
.
float32
(
0
),
2
)
T
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tl_gemm"
),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
0
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
0
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
))
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
),
)
@
T
.
prim_func
def
after
():
...
...
@@ -44,19 +46,21 @@ def test_lower_fence_proxy():
B_shared
=
T
.
decl_buffer
((
1
,
4
,
512
),
"float16"
,
scope
=
"shared.dyn"
)
C_local
=
T
.
decl_buffer
((
32
,),
scope
=
"local"
)
for
i
in
T
.
unroll
(
16
):
C_local
[
i
*
2
:
i
*
2
+
2
]
=
T
.
Broadcast
(
T
.
float32
(
0
),
2
)
C_local
[
i
*
2
:
i
*
2
+
2
]
=
T
.
Broadcast
(
T
.
float32
(
0
),
2
)
T
.
fence_proxy_async
()
T
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tl_gemm"
),
T
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tl_gemm"
),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
0
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
0
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
))
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
),
)
_check
(
before
,
after
)
def
test_async_to_generic_no_double_fence
():
@
T
.
prim_func
def
before
():
with
T
.
Kernel
(
8
):
...
...
@@ -90,7 +94,6 @@ def test_async_to_generic_no_double_fence():
def
test_proxy_hint_override
():
@
T
.
prim_func
def
before
():
with
T
.
Kernel
(
8
):
...
...
@@ -123,7 +126,6 @@ def test_proxy_hint_override():
def
test_tma_store_sync_injection
():
@
T
.
prim_func
def
before
():
with
T
.
Kernel
(
8
):
...
...
@@ -154,7 +156,6 @@ def test_tma_store_sync_injection():
def
test_wgmma_marked_async
():
@
T
.
prim_func
def
before
():
with
T
.
Kernel
(
1
):
...
...
@@ -164,9 +165,24 @@ def test_wgmma_marked_async():
C_local
=
T
.
decl_buffer
((
32
,),
"float16"
,
scope
=
"local"
)
A_shared
[
0
]
=
T
.
float16
(
0
)
T
.
warpgroup_arrive
()
T
.
ptx_wgmma_ss
(
"float16"
,
"m64n64k16"
,
T
.
bool
(
True
),
T
.
bool
(
True
),
"fp16"
,
"fp16"
,
"fp16"
,
desc_a
.
data
,
T
.
int32
(
0
),
desc_b
.
data
,
T
.
int32
(
0
),
C_local
.
data
,
T
.
int32
(
0
),
T
.
bool
(
True
),
1
,
1
)
T
.
ptx_wgmma_ss
(
"float16"
,
"m64n64k16"
,
T
.
bool
(
True
),
T
.
bool
(
True
),
"fp16"
,
"fp16"
,
"fp16"
,
desc_a
.
data
,
T
.
int32
(
0
),
desc_b
.
data
,
T
.
int32
(
0
),
C_local
.
data
,
T
.
int32
(
0
),
T
.
bool
(
True
),
1
,
1
,
)
mod
=
tvm
.
IRModule
.
from_expr
(
before
.
with_attr
(
"global_symbol"
,
"main"
))
mod
=
tvm
.
tir
.
transform
.
BindTarget
(
auto_target
)(
mod
)
...
...
testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py
View file @
29051439
...
...
@@ -35,26 +35,25 @@ def test_inject_set_max_nreg():
T
.
mbarrier_wait_parity
(
T
.
get_mbarrier
(
k
%
3
+
3
),
T
.
bitwise_xor
(
k
//
3
%
2
,
1
))
if
v
-
128
==
0
:
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
T
.
get_mbarrier
(
k
%
3
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
)
,
k
*
32
,
by
*
64
)
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
)]))
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
T
.
get_mbarrier
(
k
%
3
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
k
*
3
2
,
by
*
64
,
)
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
)]))
else
:
# Consumer branch - should have set_max_nreg(240, 1)
for
k
in
range
(
16
):
T
.
mbarrier_wait_parity
(
T
.
get_mbarrier
(
k
%
3
),
k
//
3
%
2
)
T
.
call_extern
(
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
+
3
)]))
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
),
)
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
+
3
)]))
# Apply the InjectSetMaxNReg pass
func
=
before
...
...
@@ -67,15 +66,18 @@ def test_inject_set_max_nreg():
set_max_nreg_calls
=
[]
def
collect_set_max_nreg
(
stmt
):
if
(
isinstance
(
stmt
,
tvm
.
tir
.
Evaluate
)
and
hasattr
(
stmt
.
value
,
'op'
)
and
hasattr
(
stmt
.
value
.
op
,
'name'
)
and
stmt
.
value
.
op
.
name
==
"tl.set_max_nreg"
):
if
(
isinstance
(
stmt
,
tvm
.
tir
.
Evaluate
)
and
hasattr
(
stmt
.
value
,
"op"
)
and
hasattr
(
stmt
.
value
.
op
,
"name"
)
and
stmt
.
value
.
op
.
name
==
"tl.set_max_nreg"
):
set_max_nreg_calls
.
append
(
stmt
.
value
)
tvm
.
tir
.
stmt_functor
.
post_order_visit
(
main_func
.
body
,
collect_set_max_nreg
)
# We should have at least 2 set_max_nreg calls (one for producer, one for consumer)
assert
len
(
set_max_nreg_calls
)
>=
2
,
f
"Expected at least 2 set_max_nreg calls, got
{
len
(
set_max_nreg_calls
)
}
"
assert
len
(
set_max_nreg_calls
)
>=
2
,
f
"Expected at least 2 set_max_nreg calls, got
{
len
(
set_max_nreg_calls
)
}
"
print
(
"InjectSetMaxNReg test passed!"
)
...
...
@@ -116,16 +118,18 @@ def test_inject_set_max_nreg_no_set_max_nreg():
set_max_nreg_calls
=
[]
def
collect_set_max_nreg
(
stmt
):
if
(
isinstance
(
stmt
,
tvm
.
tir
.
Evaluate
)
and
hasattr
(
stmt
.
value
,
'op'
)
and
hasattr
(
stmt
.
value
.
op
,
'name'
)
and
stmt
.
value
.
op
.
name
==
"tl.set_max_nreg"
):
if
(
isinstance
(
stmt
,
tvm
.
tir
.
Evaluate
)
and
hasattr
(
stmt
.
value
,
"op"
)
and
hasattr
(
stmt
.
value
.
op
,
"name"
)
and
stmt
.
value
.
op
.
name
==
"tl.set_max_nreg"
):
set_max_nreg_calls
.
append
(
stmt
.
value
)
tvm
.
tir
.
stmt_functor
.
post_order_visit
(
main_func
.
body
,
collect_set_max_nreg
)
# Should have no set_max_nreg calls when no_set_max_nreg is present
assert
len
(
set_max_nreg_calls
)
==
0
,
f
"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got
{
len
(
set_max_nreg_calls
)
}
"
assert
len
(
set_max_nreg_calls
)
==
0
,
f
"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got
{
len
(
set_max_nreg_calls
)
}
"
print
(
"InjectSetMaxNReg with no_set_max_nreg test passed!"
)
...
...
testing/python/transform/test_tilelang_transform_layout_inference.py
View file @
29051439
...
...
@@ -8,17 +8,21 @@ import pytest
auto_target
=
tvm
.
target
.
Target
(
determine_target
(
"auto"
))
@
pytest
.
mark
.
parametrize
(
"block_M, block_N, block_K, threads, vec_load_b, dtype"
,
[
@
pytest
.
mark
.
parametrize
(
"block_M, block_N, block_K, threads, vec_load_b, dtype"
,
[
(
64
,
64
,
32
,
128
,
8
,
"float16"
),
])
],
)
def
test_loop_tail_split
(
block_M
,
block_N
,
block_K
,
threads
,
vec_load_b
,
dtype
):
N
=
tvm
.
te
.
var
(
"n"
)
K
=
tvm
.
te
.
var
(
"k"
)
def
before
():
@
T
.
prim_func
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),):
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
):
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
thread_bindings
=
T
.
thread_binding
(
0
,
threads
,
"threadIdx.x"
)
...
...
@@ -26,58 +30,62 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
t
=
thread_bindings
for
i
in
T
.
unroll
(
0
,
block_N
*
block_K
//
(
threads
*
vec_load_b
)):
for
vec
in
T
.
Parallel
(
vec_load_b
):
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
]
=
T
.
if_then_else
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
],
T
.
float16
(
0
))
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
]
=
T
.
if_then_else
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
],
T
.
float16
(
0
),
)
return
tvm
.
IRModule
({
'
main
'
:
main
})
return
tvm
.
IRModule
({
"
main
"
:
main
})
def
after
():
@
T
.
prim_func
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),):
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
):
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
thread_bindings
=
T
.
thread_binding
(
0
,
threads
,
"threadIdx.x"
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
3
):
t
=
thread_bindings
for
i
in
T
.
unroll
(
0
,
block_N
*
block_K
//
(
threads
*
vec_load_b
)):
if
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
))
*
N
%
vec_load_b
==
0
:
if
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
))
*
N
%
vec_load_b
==
0
:
for
vec
in
T
.
vectorized
(
vec_load_b
):
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
]
=
T
.
if_then_else
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
],
T
.
float16
(
0
))
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
]
=
T
.
if_then_else
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
],
T
.
float16
(
0
),
)
else
:
for
vec
in
T
.
serial
(
vec_load_b
):
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
]
=
T
.
if_then_else
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
],
T
.
float16
(
0
))
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
]
=
T
.
if_then_else
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
],
T
.
float16
(
0
),
)
return
tvm
.
IRModule
({
'
main
'
:
main
})
return
tvm
.
IRModule
({
"
main
"
:
main
})
with
tvm
.
target
.
Target
(
auto_target
):
mod
=
tvm
.
tir
.
transform
.
BindTarget
(
auto_target
)(
before
())
...
...
testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py
View file @
29051439
...
...
@@ -8,7 +8,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
dtype
=
"float32"
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),):
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
M
,
N
),
dtype
=
dtype
)
tid
=
T
.
get_thread_binding
()
...
...
@@ -16,17 +18,18 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
A_shared
[
tid
,
j
]
=
A
[
tid
+
M_offset
,
j
+
N_offset
]
@
T
.
prim_func
def
expected
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),):
def
expected
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
M
,
N
),
dtype
=
dtype
)
tid
=
T
.
get_thread_binding
()
T
.
reads
(
A
[
tid
+
M_offset
,
N_offset
:
N
+
N_offset
])
T
.
reads
(
A
[
tid
+
M_offset
,
N_offset
:
N
+
N_offset
])
for
j
in
T
.
serial
(
N
):
A_shared
[
tid
,
j
]
=
T
.
if_then_else
(
j
+
N_offset
<
N
,
T
.
if_then_else
(
tid
+
M_offset
<
M
,
A
[
tid
+
M_offset
,
j
+
N_offset
],
T
.
float32
(
0
)),
T
.
float32
(
0
))
j
+
N_offset
<
N
,
T
.
if_then_else
(
tid
+
M_offset
<
M
,
A
[
tid
+
M_offset
,
j
+
N_offset
],
T
.
float32
(
0
)),
T
.
float32
(
0
)
)
return
main
,
expected
...
...
@@ -41,13 +44,13 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
def
issue_1013_buggy_kernel
():
# NOTE: This kernel is mainly to test some corner cases in boundary check
num_tokens
=
T
.
dynamic
(
'
num_tokens
'
)
num_tokens
=
T
.
dynamic
(
"
num_tokens
"
)
num_threads
=
128
@
T
.
prim_func
def
main
(
x
:
T
.
Tensor
((
num_tokens
,),
dtype
=
"int64"
)):
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
count
=
T
.
alloc_var
(
'
int
'
)
count
=
T
.
alloc_var
(
"
int
"
)
thread_idx
=
T
.
get_thread_binding
()
for
i
in
T
.
serial
(
0
,
T
.
ceildiv
(
num_tokens
-
thread_idx
,
num_threads
)):
idx
=
thread_idx
+
i
*
num_threads
...
...
@@ -59,24 +62,22 @@ def issue_1013_buggy_kernel():
@
T
.
prim_func
def
expected
(
x
:
T
.
Tensor
((
num_tokens
,),
dtype
=
"int64"
)):
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
count
=
T
.
alloc_var
(
'
int
'
)
count
=
T
.
alloc_var
(
"
int
"
)
thread_idx
=
T
.
get_thread_binding
()
for
i
in
T
.
serial
(
0
,
T
.
ceildiv
(
num_tokens
-
thread_idx
,
num_threads
)):
idx
=
thread_idx
+
i
*
num_threads
count
+=
T
.
Cast
(
"int32"
,
T
.
if_then_else
(
idx
<
num_tokens
,
x
[
idx
],
T
.
int64
(
0
))
==
T
.
int64
(
2
))
count
+=
T
.
Cast
(
"int32"
,
T
.
if_then_else
(
idx
<
num_tokens
,
x
[
idx
],
T
.
int64
(
0
))
==
T
.
int64
(
2
))
return
main
,
expected
def
vectorize_access_with_atmoic_add_legalize
(
M
:
int
=
64
,
N
:
int
=
64
,
M_offset
:
int
=
2
,
N_offset
:
int
=
2
):
def
vectorize_access_with_atmoic_add_legalize
(
M
:
int
=
64
,
N
:
int
=
64
,
M_offset
:
int
=
2
,
N_offset
:
int
=
2
):
dtype
=
"float32"
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),):
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
M
,
N
),
dtype
=
dtype
)
tid
=
T
.
get_thread_binding
()
...
...
@@ -85,17 +86,18 @@ def vectorize_access_with_atmoic_add_legalize(M: int = 64,
T
.
atomic_add
(
A
[
tid
+
M_offset
,
j
+
N_offset
],
1
)
@
T
.
prim_func
def
expected
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),):
def
expected
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
M
,
N
),
dtype
=
dtype
)
tid
=
T
.
get_thread_binding
()
T
.
reads
(
A
[
tid
+
M_offset
,
N_offset
:
N
+
N_offset
])
T
.
reads
(
A
[
tid
+
M_offset
,
N_offset
:
N
+
N_offset
])
for
j
in
T
.
serial
(
N
):
A_shared
[
tid
,
j
]
=
T
.
if_then_else
(
j
+
N_offset
<
N
,
T
.
if_then_else
(
tid
+
M_offset
<
M
,
A
[
tid
+
M_offset
,
j
+
N_offset
],
T
.
float32
(
0
)),
T
.
float32
(
0
))
j
+
N_offset
<
N
,
T
.
if_then_else
(
tid
+
M_offset
<
M
,
A
[
tid
+
M_offset
,
j
+
N_offset
],
T
.
float32
(
0
)),
T
.
float32
(
0
)
)
# Nest if-then-else is expected, do not flatten it to pass structural equal check
if
j
+
N_offset
<
N
:
# noqa: SIM102
if
tid
+
M_offset
<
M
:
...
...
@@ -115,17 +117,21 @@ def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: in
dtype
=
"float32"
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),):
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
tid
=
T
.
get_thread_binding
()
for
j
in
T
.
serial
(
N
):
A
[
tid
+
M_offset
,
j
+
N_offset
]
=
1
@
T
.
prim_func
def
expected
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),):
def
expected
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
tid
=
T
.
get_thread_binding
()
T
.
writes
(
A
[
tid
+
M_offset
,
N_offset
:
N
+
N_offset
])
T
.
writes
(
A
[
tid
+
M_offset
,
N_offset
:
N
+
N_offset
])
for
j
in
T
.
serial
(
N
):
if
j
+
N_offset
<
N
:
# noqa: SIM102
if
tid
+
M_offset
<
M
:
...
...
testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py
View file @
29051439
...
...
@@ -9,7 +9,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
vec_len
=
8
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
,
vec_len
),
dtype
=
"float32"
),):
def
main
(
A
:
T
.
Tensor
((
M
,
N
,
vec_len
),
dtype
=
"float32"
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
M
,
N
,
vec_len
),
dtype
=
dtype
)
tid
=
T
.
get_thread_binding
()
...
...
@@ -18,7 +20,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
A_shared
[
tid
,
j
,
v
]
=
A
[
tid
,
j
,
v
]
@
T
.
prim_func
def
expected
(
A
:
T
.
Tensor
((
M
,
N
,
vec_len
),
dtype
=
"float32"
),):
def
expected
(
A
:
T
.
Tensor
((
M
,
N
,
vec_len
),
dtype
=
"float32"
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
M
,
N
,
vec_len
),
dtype
=
dtype
)
tid
=
T
.
get_thread_binding
()
...
...
testing/python/transform/test_tilelang_transform_let_inline.py
View file @
29051439
...
...
@@ -8,12 +8,10 @@ def _check(original, transformed):
func
=
original
mod
=
tvm
.
IRModule
.
from_expr
(
func
.
with_attr
(
"global_symbol"
,
"main"
))
mod
=
tl
.
transform
.
LetInline
()(
mod
)
tvm
.
ir
.
assert_structural_equal
(
mod
[
"main"
],
transformed
.
with_attr
(
"global_symbol"
,
"main"
),
True
)
tvm
.
ir
.
assert_structural_equal
(
mod
[
"main"
],
transformed
.
with_attr
(
"global_symbol"
,
"main"
),
True
)
def
test_let_binding
():
@
T
.
prim_func
def
before
(
A
:
T
.
Tensor
((
128
,
128
),
"float32"
),
B
:
T
.
Tensor
((
128
,
128
),
"float32"
)):
for
i
in
range
(
128
):
...
...
@@ -34,7 +32,6 @@ def test_let_binding():
def
test_parallel_scope
():
@
T
.
prim_func
def
before
(
A
:
T
.
Tensor
((
128
,),
"float32"
)):
for
i
in
T
.
Parallel
(
128
):
...
...
testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py
View file @
29051439
...
...
@@ -24,7 +24,6 @@ def _check(original, transformed):
def
test_lower_hopper_intrin_barrier
():
@
T
.
prim_func
def
before
():
with
T
.
Kernel
(
8
):
...
...
@@ -37,18 +36,10 @@ def test_lower_hopper_intrin_barrier():
v_1
=
T
.
launch_thread
(
"threadIdx.x"
,
128
)
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.create_barriers"
,
[
4
]))
with
T
.
If
(
v_1
==
0
),
T
.
Then
():
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
0
),
128
]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
1
),
128
]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
2
),
128
]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
3
),
128
]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
0
),
128
]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
1
),
128
]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
2
),
128
]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
3
),
128
]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.tvm_storage_sync"
,
[
"shared"
]))
_check
(
before
,
after
)
...
...
testing/python/transform/test_tilelang_transform_lower_tile_op.py
View file @
29051439
...
...
@@ -8,63 +8,69 @@ import pytest
auto_target
=
tvm
.
target
.
Target
(
determine_target
(
"auto"
))
@
pytest
.
mark
.
parametrize
(
"block_M, block_N, block_K, threads, vec_load_b, dtype"
,
[
@
pytest
.
mark
.
parametrize
(
"block_M, block_N, block_K, threads, vec_load_b, dtype"
,
[
(
64
,
64
,
32
,
128
,
8
,
"float16"
),
])
],
)
def
test_loop_tail_split
(
block_M
,
block_N
,
block_K
,
threads
,
vec_load_b
,
dtype
):
N
=
tvm
.
te
.
var
(
"n"
)
K
=
tvm
.
te
.
var
(
"k"
)
def
before
():
@
T
.
prim_func
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),):
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
):
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
3
):
T
.
copy
(
B
[
k
*
block_K
,
bx
*
block_N
],
B_shared
)
return
tvm
.
IRModule
({
'
main
'
:
main
})
return
tvm
.
IRModule
({
"
main
"
:
main
})
def
after
():
@
T
.
prim_func
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),):
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
):
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
thread_bindings
=
T
.
thread_binding
(
0
,
threads
,
"threadIdx.x"
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
3
):
t
=
thread_bindings
for
i
in
T
.
unroll
(
0
,
block_N
*
block_K
//
(
threads
*
vec_load_b
)):
if
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
))
*
N
%
vec_load_b
==
0
:
if
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
))
*
N
%
vec_load_b
==
0
:
for
vec
in
T
.
vectorized
(
vec_load_b
):
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
]
=
T
.
if_then_else
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
],
T
.
float16
(
0
))
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
]
=
T
.
if_then_else
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
],
T
.
float16
(
0
),
)
else
:
for
vec
in
T
.
serial
(
vec_load_b
):
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
]
=
T
.
if_then_else
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
],
T
.
float16
(
0
))
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
]
=
T
.
if_then_else
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
],
T
.
float16
(
0
),
)
return
tvm
.
IRModule
({
'
main
'
:
main
})
return
tvm
.
IRModule
({
"
main
"
:
main
})
with
tvm
.
transform
.
PassContext
():
mod
=
tvm
.
tir
.
transform
.
BindTarget
(
auto_target
)(
before
())
...
...
testing/python/transform/test_tilelang_transform_make_packed_api.py
View file @
29051439
...
...
@@ -80,7 +80,6 @@ def test_target_host_removed():
@
I
.
ir_module
class
before
:
@
T
.
prim_func
def
main
(
A
:
T
.
Buffer
(
1
,
"float32"
)):
T
.
func_attr
({
"global_symbol"
:
"main"
,
"target"
:
T
.
target
(
"cuda"
,
host
=
host
)})
...
...
@@ -102,7 +101,6 @@ def test_internal_subroutine_call():
@
I
.
ir_module
class
before
:
@
T
.
prim_func
def
main
(
A
:
T
.
Buffer
(
1
,
"float32"
)):
T
.
func_attr
({
"target"
:
T
.
target
(
"llvm"
,
host
=
"llvm"
)})
...
...
@@ -121,7 +119,8 @@ def test_internal_subroutine_call():
subroutine_call_op
=
compute_scope
.
body
.
value
.
op
assert
isinstance
(
subroutine_call_op
,
tvm
.
ir
.
GlobalVar
),
(
f
"The main function's CallNode should use the subroutine's GLobalVar as the operation, "
f
"but instead has an operation of type
{
subroutine_call_op
}
"
)
f
"but instead has an operation of type
{
subroutine_call_op
}
"
)
def
test_subroutine_call_to_externally_visible_subroutine
():
...
...
@@ -135,7 +134,6 @@ def test_subroutine_call_to_externally_visible_subroutine():
@
I
.
ir_module
class
before
:
@
T
.
prim_func
def
main
(
A
:
T
.
Buffer
(
1
,
"float32"
)):
T
.
func_attr
({
"global_symbol"
:
"main"
,
"target"
:
T
.
target
(
"llvm"
,
host
=
"llvm"
)})
...
...
@@ -154,11 +152,10 @@ def test_subroutine_call_to_externally_visible_subroutine():
assert
subroutine_compute_scope
is
not
None
subroutine_call_op
=
main_compute_scope
.
body
.
value
.
op
assert
(
isinstance
(
subroutine_call_op
,
tvm
.
ir
.
Op
)
and
subroutine_call_op
.
name
==
"tir.tvm_call_cpacked"
),
(
f
"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', "
f
"but instead has an operation of type
{
subroutine_call_op
}
"
)
assert
isinstance
(
subroutine_call_op
,
tvm
.
ir
.
Op
)
and
subroutine_call_op
.
name
==
"tir.tvm_call_cpacked"
,
(
f
"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', "
f
"but instead has an operation of type
{
subroutine_call_op
}
"
)
@
tilelang
.
testing
.
requires_llvm
...
...
testing/python/transform/test_tilelang_transform_multi_version_buffer.py
View file @
29051439
...
...
@@ -31,7 +31,6 @@ block_K = 32
def
test_multi_version_buffer
():
@
T
.
prim_func
def
before
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
)):
bx
=
T
.
launch_thread
(
"blockIdx.x"
,
8
)
...
...
@@ -49,21 +48,27 @@ def test_multi_version_buffer():
for
k
in
T
.
serial
(
16
,
annotations
=
{
"num_stages"
:
T
.
int32
(
3
)}):
if
v
==
0
:
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
0
,
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
0
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
0
,
2048
,
2
),
k
*
32
,
by
*
64
)
k
*
32
,
by
*
64
,
)
if
v
==
0
:
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
2
,
0
),
0
,
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
2
,
0
),
0
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
0
,
2048
,
2
),
bx
*
64
,
k
*
32
)
bx
*
64
,
k
*
32
,
)
T
.
call_extern
(
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
0
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
0
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
))
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
),
)
@
T
.
prim_func
def
after
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
)):
...
...
@@ -82,31 +87,32 @@ def test_multi_version_buffer():
for
k
in
T
.
serial
(
16
,
annotations
=
{
"num_stages"
:
T
.
int32
(
3
)}):
if
v
==
0
:
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
0
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
k
*
32
,
by
*
64
)
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
0
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
k
*
32
,
by
*
64
,
)
if
v
==
0
:
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
2
,
0
),
0
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
bx
*
64
,
k
*
32
)
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
2
,
0
),
0
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
bx
*
64
,
k
*
32
,
)
T
.
call_extern
(
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float
16
"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
)
)
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float
32
"
),
C_local
.
data
,
0
,
32
,
3
),
)
_check
(
before
,
after
)
def
test_multi_version_buffer_with_let
():
@
T
.
prim_func
def
before
(
scales
:
T
.
Tensor
((
4
,),
"float32"
)):
with
T
.
block
(
"root"
):
...
...
testing/python/transform/test_tilelang_transform_pipeline_planning.py
View file @
29051439
...
...
@@ -19,10 +19,8 @@ def _check(original, transformed):
def
test_simple_pipeline
():
@
T
.
prim_func
def
before
(
A
:
T
.
Tensor
((
1024
,
32
),
"float32"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float32"
),
C
:
T
.
Tensor
(
(
1024
,
1024
),
"float32"
)):
def
before
(
A
:
T
.
Tensor
((
1024
,
32
),
"float32"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float32"
),
C
:
T
.
Tensor
((
1024
,
1024
),
"float32"
)):
with
T
.
Kernel
(
8
,
8
,
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
128
,
32
),
"float32"
)
B_shared
=
T
.
alloc_shared
((
32
,
128
),
"float32"
)
...
...
@@ -39,8 +37,7 @@ def test_simple_pipeline():
T
.
copy
(
C_local
,
C
[
by
*
128
,
bx
*
128
])
@
T
.
prim_func
def
after
(
A
:
T
.
Tensor
((
1024
,
32
),
"float32"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float32"
),
C
:
T
.
Tensor
(
(
1024
,
1024
),
"float32"
)):
def
after
(
A
:
T
.
Tensor
((
1024
,
32
),
"float32"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float32"
),
C
:
T
.
Tensor
((
1024
,
1024
),
"float32"
)):
with
T
.
Kernel
(
8
,
8
,
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
128
,
32
),
"float32"
)
B_shared
=
T
.
alloc_shared
((
32
,
128
),
"float32"
)
...
...
@@ -52,11 +49,10 @@ def test_simple_pipeline():
32
,
annotations
=
{
"software_pipeline_async_stages"
:
[
T
.
int32
(
0
)],
"software_pipeline_order"
:
[
T
.
int32
(
0
),
T
.
int32
(
1
),
T
.
int32
(
2
)],
"software_pipeline_stage"
:
[
T
.
int32
(
3
),
T
.
int32
(
3
),
T
.
int32
(
3
)]
}):
"software_pipeline_order"
:
[
T
.
int32
(
0
),
T
.
int32
(
1
),
T
.
int32
(
2
)],
"software_pipeline_stage"
:
[
T
.
int32
(
3
),
T
.
int32
(
3
),
T
.
int32
(
3
)],
},
):
T
.
copy
(
A
[
by
*
128
,
ko
*
32
],
A_shared
)
T
.
copy
(
B
[
ko
*
32
,
bx
*
128
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
)
...
...
testing/python/transform/test_tilelang_transform_simplify.py
View file @
29051439
...
...
@@ -8,7 +8,6 @@ def modify(
with_B
:
bool
=
False
,
with_bias
:
bool
=
False
,
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
64
,
64
)),
...
...
@@ -42,7 +41,6 @@ def test_modify(with_B=False, with_bias=False):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
main
(
a
:
T
.
handle
,
...
...
@@ -76,6 +74,7 @@ def test_matmul():
kernel
=
tl
.
compile
(
mod
[
"main"
],
out_idx
=
[
2
])
import
torch
a
=
torch
.
randn
(
1024
,
1024
,
dtype
=
torch
.
float16
).
cuda
().
half
()
b
=
torch
.
randn
(
1024
,
1024
,
dtype
=
torch
.
float16
).
cuda
().
half
()
c
=
kernel
(
a
,
b
)
...
...
testing/python/transform/test_tilelang_transform_thread_sync.py
View file @
29051439
...
...
@@ -11,11 +11,7 @@ def run_passes(func: tvm.tir.PrimFunc):
cuda_target
=
tvm
.
target
.
Target
(
"cuda"
,
host
=
"llvm"
)
mod
=
tvm
.
tir
.
transform
.
Apply
(
lambda
f
:
f
.
with_attr
({
"global_symbol"
:
"test"
,
"target"
:
cuda_target
}))(
mod
)
mod
=
tvm
.
tir
.
transform
.
Apply
(
lambda
f
:
f
.
with_attr
({
"global_symbol"
:
"test"
,
"target"
:
cuda_target
}))(
mod
)
mod
=
tvm
.
tir
.
transform
.
AnnotateDeviceRegions
()(
mod
)
mod
=
tvm
.
tir
.
transform
.
SplitHostDevice
()(
mod
)
...
...
@@ -24,7 +20,6 @@ def run_passes(func: tvm.tir.PrimFunc):
@
tilelang
.
testing
.
requires_cuda
def
test_sync_if_with_same_index
():
@
T
.
prim_func
(
check_well_formed
=
False
)
def
func
(
p0_arg
:
T
.
Buffer
((
1
,
2
,
1
,
1
),
"float32"
),
p1
:
T
.
Buffer
(
2
,
"float32"
))
->
None
:
threadIdx_x
=
T
.
env_thread
(
"threadIdx.x"
)
...
...
@@ -47,7 +42,6 @@ def test_sync_if_with_same_index():
@
tilelang
.
testing
.
requires_cuda
def
test_sync_read_thread_id_independent_location
():
@
T
.
prim_func
def
func
(
p0_arg
:
T
.
Buffer
((
1
,
2
,
1
,
1
),
"float32"
),
p1
:
T
.
Buffer
(
2
,
"float32"
))
->
None
:
threadIdx_x
=
T
.
env_thread
(
"threadIdx.x"
)
...
...
@@ -71,7 +65,6 @@ def test_sync_read_thread_id_independent_location():
@
tilelang
.
testing
.
requires_cuda
def
test_sync_shared
():
@
T
.
prim_func
(
private
=
True
)
def
func
(
A
:
T
.
Buffer
((
4
,
4
),
"float32"
),
E
:
T
.
Buffer
((
4
,
4
),
"float32"
)):
blockIdx_x
=
T
.
launch_thread
(
"blockIdx.x"
,
1
)
...
...
@@ -113,7 +106,6 @@ def test_sync_shared():
@
tvm
.
testing
.
requires_cuda
def
test_sync_let_stmt
():
@
T
.
prim_func
(
private
=
True
)
def
func
(
A
:
T
.
Buffer
((
16
*
512
),
"float32"
)):
blockIdx_x
=
T
.
launch_thread
(
"blockIdx.x"
,
16
)
...
...
@@ -190,16 +182,19 @@ def test_sync_let_stmt():
@
tilelang
.
testing
.
requires_cuda
def
test_sync_shared_dyn_stmatrix_loop_hoist
():
@
T
.
prim_func
def
func
():
buf_dyn_shmem
=
T
.
alloc_buffer
((
98304
,),
"uint8"
,
scope
=
"shared.dyn"
)
tx
=
T
.
launch_thread
(
"threadIdx.x"
,
384
)
for
i
in
T
.
unroll
(
8
):
off
=
(
i
//
4
*
8192
+
tx
//
32
*
1024
+
tx
%
16
*
64
+
(
tx
%
8
//
4
+
i
%
4
//
2
)
%
2
*
32
+
(
tx
%
4
//
2
+
i
%
2
)
%
2
*
16
+
(
tx
%
32
//
16
+
tx
%
2
)
%
2
*
8
)
i
//
4
*
8192
+
tx
//
32
*
1024
+
tx
%
16
*
64
+
(
tx
%
8
//
4
+
i
%
4
//
2
)
%
2
*
32
+
(
tx
%
4
//
2
+
i
%
2
)
%
2
*
16
+
(
tx
%
32
//
16
+
tx
%
2
)
%
2
*
8
)
T
.
evaluate
(
T
.
call_intrin
(
"handle"
,
...
...
@@ -214,7 +209,8 @@ def test_sync_shared_dyn_stmatrix_loop_hoist():
2
,
),
T
.
int32
(
2
),
))
)
)
mod
=
tvm
.
IRModule
({
"main"
:
func
})
mod
=
tilelang
.
transform
.
ThreadSync
(
"shared.dyn"
)(
mod
)
...
...
Prev
1
…
12
13
14
15
16
17
18
19
20
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment