Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
434 additions
and
441 deletions
+434
-441
testing/python/profiler/test_tilelang_profiler.py
testing/python/profiler/test_tilelang_profiler.py
+3
-4
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
+36
-26
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
...g/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
+34
-57
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
...ython/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
+61
-65
testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py
...sform/test_tilelang_transform_Inject_software_pipeline.py
+2
-10
testing/python/transform/test_tilelang_transform_cluster_planning.py
...hon/transform/test_tilelang_transform_cluster_planning.py
+2
-5
testing/python/transform/test_tilelang_transform_config_index_bitwidth.py
...ransform/test_tilelang_transform_config_index_bitwidth.py
+31
-35
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py
...n/transform/test_tilelang_transform_inject_fence_proxy.py
+36
-20
testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py
.../transform/test_tilelang_transform_inject_set_max_nreg.py
+28
-24
testing/python/transform/test_tilelang_transform_layout_inference.py
...hon/transform/test_tilelang_transform_layout_inference.py
+51
-43
testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py
...rm/test_tilelang_transform_legalize_safe_memory_access.py
+30
-24
testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py
...sform/test_tilelang_transform_legalize_vectorized_loop.py
+6
-2
testing/python/transform/test_tilelang_transform_let_inline.py
...ng/python/transform/test_tilelang_transform_let_inline.py
+1
-4
testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py
.../transform/test_tilelang_transform_lower_hopper_intrin.py
+4
-13
testing/python/transform/test_tilelang_transform_lower_tile_op.py
...python/transform/test_tilelang_transform_lower_tile_op.py
+39
-33
testing/python/transform/test_tilelang_transform_make_packed_api.py
...thon/transform/test_tilelang_transform_make_packed_api.py
+10
-13
testing/python/transform/test_tilelang_transform_multi_version_buffer.py
...transform/test_tilelang_transform_multi_version_buffer.py
+32
-26
testing/python/transform/test_tilelang_transform_pipeline_planning.py
...on/transform/test_tilelang_transform_pipeline_planning.py
+9
-13
testing/python/transform/test_tilelang_transform_simplify.py
testing/python/transform/test_tilelang_transform_simplify.py
+6
-7
testing/python/transform/test_tilelang_transform_thread_sync.py
...g/python/transform/test_tilelang_transform_thread_sync.py
+13
-17
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
testing/python/profiler/test_tilelang_profiler.py
View file @
29051439
...
@@ -4,12 +4,11 @@ import tilelang.language as T
...
@@ -4,12 +4,11 @@ import tilelang.language as T
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
gemm
(
def
gemm
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
...
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
View file @
29051439
...
@@ -27,9 +27,9 @@ def matmul(
...
@@ -27,9 +27,9 @@ def matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -89,7 +89,8 @@ def run_gemm_ss(
...
@@ -89,7 +89,8 @@ def run_gemm_ss(
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
...
@@ -159,9 +160,9 @@ def matmul_rs(
...
@@ -159,9 +160,9 @@ def matmul_rs(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -169,9 +170,11 @@ def matmul_rs(
...
@@ -169,9 +170,11 @@ def matmul_rs(
A_frag
=
T
.
alloc_fragment
(
A_frag_shape
,
in_dtype
)
A_frag
=
T
.
alloc_fragment
(
A_frag_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
T
.
annotate_layout
({
T
.
annotate_layout
(
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
{
})
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
}
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
if
trans_A
:
if
trans_A
:
T
.
copy
(
A
[
k
*
block_K
,
by
*
block_M
],
A_shared
)
T
.
copy
(
A
[
k
*
block_K
,
by
*
block_M
],
A_shared
)
...
@@ -225,7 +228,8 @@ def run_gemm_rs(
...
@@ -225,7 +228,8 @@ def run_gemm_rs(
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
def
ref_program
(
A
,
B
):
def
ref_program
(
A
,
B
):
...
@@ -294,9 +298,9 @@ def matmul_sr(
...
@@ -294,9 +298,9 @@ def matmul_sr(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -304,9 +308,11 @@ def matmul_sr(
...
@@ -304,9 +308,11 @@ def matmul_sr(
B_frag
=
T
.
alloc_fragment
(
B_frag_shape
,
in_dtype
)
B_frag
=
T
.
alloc_fragment
(
B_frag_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
T
.
annotate_layout
({
T
.
annotate_layout
(
B_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
B_shared
),
{
})
B_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
B_shared
),
}
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
if
trans_A
:
if
trans_A
:
T
.
copy
(
A
[
k
*
block_K
,
by
*
block_M
],
A_shared
)
T
.
copy
(
A
[
k
*
block_K
,
by
*
block_M
],
A_shared
)
...
@@ -360,7 +366,8 @@ def run_gemm_sr(
...
@@ -360,7 +366,8 @@ def run_gemm_sr(
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
def
ref_program
(
A
,
B
):
def
ref_program
(
A
,
B
):
...
@@ -430,9 +437,9 @@ def matmul_rr(
...
@@ -430,9 +437,9 @@ def matmul_rr(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -441,10 +448,12 @@ def matmul_rr(
...
@@ -441,10 +448,12 @@ def matmul_rr(
B_frag
=
T
.
alloc_fragment
(
B_frag_shape
,
in_dtype
)
B_frag
=
T
.
alloc_fragment
(
B_frag_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
T
.
annotate_layout
({
T
.
annotate_layout
(
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
{
B_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
B_shared
),
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
})
B_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
B_shared
),
}
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
if
trans_A
:
if
trans_A
:
T
.
copy
(
A
[
k
*
block_K
,
by
*
block_M
],
A_shared
)
T
.
copy
(
A
[
k
*
block_K
,
by
*
block_M
],
A_shared
)
...
@@ -499,7 +508,8 @@ def run_gemm_rr(
...
@@ -499,7 +508,8 @@ def run_gemm_rr(
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
profiler
=
kernel
.
get_profiler
(
tensor_supply_type
=
tilelang
.
TensorSupplyType
.
Normal
)
def
ref_program
(
A
,
B
):
def
ref_program
(
A
,
B
):
...
...
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
View file @
29051439
...
@@ -20,27 +20,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
...
@@ -20,27 +20,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
low
,
high
=
(
0
,
4
)
if
is_unsigned
else
(
-
2
,
2
)
low
,
high
=
(
0
,
4
)
if
is_unsigned
else
(
-
2
,
2
)
else
:
else
:
low
,
high
=
(
0
,
128
)
if
is_unsigned
else
(
-
64
,
64
)
low
,
high
=
(
0
,
128
)
if
is_unsigned
else
(
-
64
,
64
)
A
=
randint_semi_sparse
(
A
=
randint_semi_sparse
(
M
,
K
,
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
"cuda"
,
transposed
=
trans_A
)
M
,
B
=
torch
.
randint
(
size
=
(
N
,
K
)
if
trans_B
else
(
K
,
N
),
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
"cuda"
)
K
,
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
'cuda'
,
transposed
=
trans_A
)
B
=
torch
.
randint
(
size
=
(
N
,
K
)
if
trans_B
else
(
K
,
N
),
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
'cuda'
)
else
:
else
:
A
=
randn_semi_sparse
(
A
=
randn_semi_sparse
(
M
,
K
,
dtype
=
torch
.
float32
,
device
=
"cuda"
,
transposed
=
trans_A
).
to
(
map_torch_type
(
in_dtype
))
M
,
K
,
dtype
=
torch
.
float32
,
device
=
'cuda'
,
B
=
torch
.
randn
((
N
,
K
)
if
trans_B
else
(
K
,
N
),
device
=
"cuda"
,
dtype
=
torch
.
float32
).
to
(
map_torch_type
(
in_dtype
))
transposed
=
trans_A
).
to
(
map_torch_type
(
in_dtype
))
B
=
torch
.
randn
(
(
N
,
K
)
if
trans_B
else
(
K
,
N
),
device
=
'cuda'
,
dtype
=
torch
.
float32
).
to
(
map_torch_type
(
in_dtype
))
return
A
,
B
return
A
,
B
...
@@ -69,24 +53,22 @@ def matmul_sp_sm90(
...
@@ -69,24 +53,22 @@ def matmul_sp_sm90(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
E_factor
),
'
uint8
'
),
E
:
T
.
Tensor
((
M
,
K
//
E_factor
),
"
uint8
"
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
'
uint8
'
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
"
uint8
"
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
E
:
{
make_cutlass_metadata_layout
(
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"9.0"
,
block_k
=
block_K
),
E
,
mma_dtype
=
in_dtype
,
arch
=
"9.0"
,
block_k
=
block_K
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"9.0"
,
block_k
=
block_K
),
E_shared
:
}
make_cutlass_metadata_layout
(
)
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"9.0"
,
block_k
=
block_K
),
})
T
.
disable_warp_group_reg_alloc
()
T
.
disable_warp_group_reg_alloc
()
T
.
clear
(
C_frag
)
T
.
clear
(
C_frag
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
...
@@ -121,7 +103,7 @@ def matmul_sp_sm80(
...
@@ -121,7 +103,7 @@ def matmul_sp_sm80(
trans_B
,
trans_B
,
):
):
is_8_bit
=
"8"
in
in_dtype
is_8_bit
=
"8"
in
in_dtype
metadata_dtype
=
'
int32
'
if
is_8_bit
else
'
int16
'
metadata_dtype
=
"
int32
"
if
is_8_bit
else
"
int16
"
E_factor
=
SparseTensorCoreIntrinEmitter
.
E_FACTOR_MAP
[
in_dtype
][
metadata_dtype
]
E_factor
=
SparseTensorCoreIntrinEmitter
.
E_FACTOR_MAP
[
in_dtype
][
metadata_dtype
]
A_sparse_shape
=
(
M
,
K
//
2
)
if
not
trans_A
else
(
K
//
2
,
M
)
A_sparse_shape
=
(
M
,
K
//
2
)
if
not
trans_A
else
(
K
//
2
,
M
)
B_shape
=
(
K
,
N
)
if
not
trans_B
else
(
N
,
K
)
B_shape
=
(
K
,
N
)
if
not
trans_B
else
(
N
,
K
)
...
@@ -132,20 +114,22 @@ def matmul_sp_sm80(
...
@@ -132,20 +114,22 @@ def matmul_sp_sm80(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
E_factor
),
metadata_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
E_factor
),
metadata_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
metadata_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
metadata_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
{
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
})
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
}
)
T
.
clear
(
C_frag
)
T
.
clear
(
C_frag
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
...
@@ -216,7 +200,7 @@ def run_gemm_sp(
...
@@ -216,7 +200,7 @@ def run_gemm_sp(
C
=
_matmul
(
A
,
B
)
C
=
_matmul
(
A
,
B
)
if
'
float8
'
in
in_dtype
:
if
"
float8
"
in
in_dtype
:
diff
=
calc_diff
(
C_sp
,
C
)
diff
=
calc_diff
(
C_sp
,
C
)
assert
diff
<
1e-3
,
f
"
{
diff
=
}
"
assert
diff
<
1e-3
,
f
"
{
diff
=
}
"
else
:
else
:
...
@@ -332,15 +316,11 @@ def test_gemm_sp_sm90():
...
@@ -332,15 +316,11 @@ def test_gemm_sp_sm90():
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
128
,
256
,
0
,
128
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
128
,
256
,
0
,
128
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
128
,
256
,
2
,
128
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
128
,
256
,
2
,
128
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
False
,
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
False
,
True
)
True
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
True
,
False
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
True
,
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
True
,
True
)
False
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
True
,
True
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float8_e4m3"
,
"float16"
,
"float16"
,
64
,
64
,
64
,
2
,
128
,
False
,
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"float8_e4m3"
,
"float16"
,
"float16"
,
64
,
64
,
64
,
2
,
128
,
False
,
True
)
True
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"int8"
,
"int32"
,
"int32"
,
64
,
64
,
64
,
2
,
128
,
False
,
True
)
run_gemm_sp_sm90
(
512
,
1024
,
768
,
"int8"
,
"int32"
,
"int32"
,
64
,
64
,
64
,
2
,
128
,
False
,
True
)
...
@@ -352,12 +332,9 @@ def test_gemm_sp_sm80():
...
@@ -352,12 +332,9 @@ def test_gemm_sp_sm80():
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
32
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
32
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
32
,
32
,
64
,
0
,
32
,
False
,
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
32
,
32
,
64
,
0
,
32
,
False
,
True
)
True
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
32
,
False
,
True
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
32
,
False
,
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
False
,
True
)
True
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
0
,
128
,
False
,
True
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
1
,
128
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
1
,
128
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
2
,
128
)
run_gemm_sp_sm80
(
512
,
1024
,
768
,
"float16"
,
"float32"
,
"float32"
,
64
,
64
,
64
,
2
,
128
)
...
...
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
View file @
29051439
...
@@ -34,20 +34,22 @@ def matmul(
...
@@ -34,20 +34,22 @@ def matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
E_factor
),
metadata_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
E_factor
),
metadata_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
metadata_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
metadata_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
{
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
})
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
}
)
T
.
clear
(
C_frag
)
T
.
clear
(
C_frag
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
...
@@ -80,7 +82,7 @@ def run_gemm_ss(
...
@@ -80,7 +82,7 @@ def run_gemm_ss(
num_stages
=
3
,
num_stages
=
3
,
num_threads
=
128
,
num_threads
=
128
,
):
):
metadata_dtype
=
'
int32
'
if
(
'8'
in
in_dtype
)
else
'
int16
'
metadata_dtype
=
"
int32
"
if
(
"8"
in
in_dtype
)
else
"
int16
"
program
=
matmul
(
program
=
matmul
(
M
,
M
,
N
,
N
,
...
@@ -105,7 +107,8 @@ def run_gemm_ss(
...
@@ -105,7 +107,8 @@ def run_gemm_ss(
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
A
,
B
=
generate_dense_input
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
)
A
,
B
=
generate_dense_input
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
)
A_sparse
,
E
=
compress
(
A
,
transposed
=
trans_A
,
block_k
=
block_K
,
arch
=
"8.0"
)
A_sparse
,
E
=
compress
(
A
,
transposed
=
trans_A
,
block_k
=
block_K
,
arch
=
"8.0"
)
...
@@ -142,26 +145,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
...
@@ -142,26 +145,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
low
,
high
=
(
0
,
4
)
if
is_unsigned
else
(
-
2
,
2
)
low
,
high
=
(
0
,
4
)
if
is_unsigned
else
(
-
2
,
2
)
else
:
else
:
low
,
high
=
(
0
,
128
)
if
is_unsigned
else
(
-
64
,
64
)
low
,
high
=
(
0
,
128
)
if
is_unsigned
else
(
-
64
,
64
)
A
=
randint_semi_sparse
(
A
=
randint_semi_sparse
(
M
,
K
,
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
"cuda"
,
transposed
=
trans_A
)
M
,
B
=
torch
.
randint
(
size
=
(
N
,
K
)
if
trans_B
else
(
K
,
N
),
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
"cuda"
)
K
,
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
'cuda'
,
transposed
=
trans_A
)
B
=
torch
.
randint
(
size
=
(
N
,
K
)
if
trans_B
else
(
K
,
N
),
low
=
low
,
high
=
high
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
'cuda'
)
else
:
else
:
A
=
randn_semi_sparse
(
A
=
randn_semi_sparse
(
M
,
K
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
"cuda"
,
transposed
=
trans_A
)
M
,
K
,
dtype
=
map_torch_type
(
in_dtype
),
device
=
'cuda'
,
transposed
=
trans_A
)
B
=
torch
.
randn
((
N
,
K
)
if
trans_B
else
(
K
,
N
),
device
=
"cuda"
,
dtype
=
torch
.
float32
).
to
(
map_torch_type
(
in_dtype
))
B
=
torch
.
randn
(
(
N
,
K
)
if
trans_B
else
(
K
,
N
),
device
=
'cuda'
,
dtype
=
torch
.
float32
).
to
(
map_torch_type
(
in_dtype
))
return
A
,
B
return
A
,
B
...
@@ -184,8 +172,7 @@ def test_gemm_ss():
...
@@ -184,8 +172,7 @@ def test_gemm_ss():
run_gemm_ss
(
128
,
128
,
128
,
True
,
True
,
"int8"
,
"int8"
,
"int32"
,
128
,
128
,
64
,
2
)
run_gemm_ss
(
128
,
128
,
128
,
True
,
True
,
"int8"
,
"int8"
,
"int32"
,
128
,
128
,
64
,
2
)
# float8 tests
# float8 tests
run_gemm_ss
(
128
,
128
,
128
,
False
,
True
,
"float8_e5m2"
,
"float8_e5m2"
,
"float32"
,
128
,
128
,
64
,
run_gemm_ss
(
128
,
128
,
128
,
False
,
True
,
"float8_e5m2"
,
"float8_e5m2"
,
"float32"
,
128
,
128
,
64
,
2
)
2
)
run_gemm_ss
(
128
,
128
,
128
,
True
,
True
,
"float8_e5m2"
,
"float8_e5m2"
,
"float32"
,
128
,
128
,
64
,
2
)
run_gemm_ss
(
128
,
128
,
128
,
True
,
True
,
"float8_e5m2"
,
"float8_e5m2"
,
"float32"
,
128
,
128
,
64
,
2
)
# tfloat32 test
# tfloat32 test
...
@@ -222,10 +209,10 @@ def matmul_rs(
...
@@ -222,10 +209,10 @@ def matmul_rs(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
E_factor
),
metadata_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
E_factor
),
metadata_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -233,11 +220,13 @@ def matmul_rs(
...
@@ -233,11 +220,13 @@ def matmul_rs(
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
metadata_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
metadata_dtype
)
A_frag
=
T
.
alloc_fragment
(
A_frag_shape
,
in_dtype
)
A_frag
=
T
.
alloc_fragment
(
A_frag_shape
,
in_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
{
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
})
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
}
)
T
.
clear
(
C_frag
)
T
.
clear
(
C_frag
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
...
@@ -271,7 +260,7 @@ def run_gemm_rs(
...
@@ -271,7 +260,7 @@ def run_gemm_rs(
num_stages
=
3
,
num_stages
=
3
,
num_threads
=
128
,
num_threads
=
128
,
):
):
metadata_dtype
=
'
int32
'
if
(
'8'
in
in_dtype
)
else
'
int16
'
metadata_dtype
=
"
int32
"
if
(
"8"
in
in_dtype
)
else
"
int16
"
program
=
matmul_rs
(
program
=
matmul_rs
(
M
,
M
,
N
,
N
,
...
@@ -296,7 +285,8 @@ def run_gemm_rs(
...
@@ -296,7 +285,8 @@ def run_gemm_rs(
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
A
,
B
=
generate_dense_input
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
)
A
,
B
=
generate_dense_input
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
)
A_sparse
,
E
=
compress
(
A
,
transposed
=
trans_A
,
block_k
=
block_K
,
arch
=
"8.0"
)
A_sparse
,
E
=
compress
(
A
,
transposed
=
trans_A
,
block_k
=
block_K
,
arch
=
"8.0"
)
C_sp
=
kernel
(
A_sparse
,
E
,
B
)
C_sp
=
kernel
(
A_sparse
,
E
,
B
)
...
@@ -376,10 +366,10 @@ def matmul_sr(
...
@@ -376,10 +366,10 @@ def matmul_sr(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
E_factor
),
metadata_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
E_factor
),
metadata_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -387,11 +377,13 @@ def matmul_sr(
...
@@ -387,11 +377,13 @@ def matmul_sr(
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
metadata_dtype
)
E_shared
=
T
.
alloc_shared
((
block_M
,
block_K
//
E_factor
),
metadata_dtype
)
B_frag
=
T
.
alloc_fragment
(
B_frag_shape
,
in_dtype
)
B_frag
=
T
.
alloc_fragment
(
B_frag_shape
,
in_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
B_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
B_shared
),
{
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
B_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
B_shared
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
})
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
}
)
T
.
clear
(
C_frag
)
T
.
clear
(
C_frag
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
...
@@ -425,7 +417,7 @@ def run_gemm_sr(
...
@@ -425,7 +417,7 @@ def run_gemm_sr(
num_stages
=
3
,
num_stages
=
3
,
num_threads
=
128
,
num_threads
=
128
,
):
):
metadata_dtype
=
'
int32
'
if
(
'8'
in
in_dtype
)
else
'
int16
'
metadata_dtype
=
"
int32
"
if
(
"8"
in
in_dtype
)
else
"
int16
"
program
=
matmul_sr
(
program
=
matmul_sr
(
M
,
M
,
N
,
N
,
...
@@ -450,7 +442,8 @@ def run_gemm_sr(
...
@@ -450,7 +442,8 @@ def run_gemm_sr(
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
A
,
B
=
generate_dense_input
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
)
A
,
B
=
generate_dense_input
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
)
A_sparse
,
E
=
compress
(
A
,
transposed
=
trans_A
,
block_k
=
block_K
,
arch
=
"8.0"
)
A_sparse
,
E
=
compress
(
A
,
transposed
=
trans_A
,
block_k
=
block_K
,
arch
=
"8.0"
)
C_sp
=
kernel
(
A_sparse
,
E
,
B
)
C_sp
=
kernel
(
A_sparse
,
E
,
B
)
...
@@ -531,10 +524,10 @@ def matmul_rr(
...
@@ -531,10 +524,10 @@ def matmul_rr(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
A_sparse
:
T
.
Tensor
(
A_sparse_shape
,
in_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
E_factor
),
metadata_dtype
),
E
:
T
.
Tensor
((
M
,
K
//
E_factor
),
metadata_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
...
@@ -543,12 +536,14 @@ def matmul_rr(
...
@@ -543,12 +536,14 @@ def matmul_rr(
A_frag
=
T
.
alloc_fragment
(
A_frag_shape
,
in_dtype
)
A_frag
=
T
.
alloc_fragment
(
A_frag_shape
,
in_dtype
)
B_frag
=
T
.
alloc_fragment
(
B_frag_shape
,
in_dtype
)
B_frag
=
T
.
alloc_fragment
(
B_frag_shape
,
in_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_frag
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
annotate_layout
({
T
.
annotate_layout
(
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
{
B_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
B_shared
),
A_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
A_shared
),
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
B_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
B_shared
),
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
E
:
make_cutlass_metadata_layout
(
E
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
})
E_shared
:
make_cutlass_metadata_layout
(
E_shared
,
mma_dtype
=
in_dtype
,
arch
=
"8.0"
),
}
)
T
.
clear
(
C_frag
)
T
.
clear
(
C_frag
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
T
.
copy
(
E
[
by
*
block_M
,
k
*
block_K
//
E_factor
],
E_shared
)
...
@@ -583,7 +578,7 @@ def run_gemm_rr(
...
@@ -583,7 +578,7 @@ def run_gemm_rr(
num_stages
=
3
,
num_stages
=
3
,
num_threads
=
128
,
num_threads
=
128
,
):
):
metadata_dtype
=
'
int32
'
if
(
'8'
in
in_dtype
)
else
'
int16
'
metadata_dtype
=
"
int32
"
if
(
"8"
in
in_dtype
)
else
"
int16
"
program
=
matmul_rr
(
program
=
matmul_rr
(
M
,
M
,
N
,
N
,
...
@@ -608,7 +603,8 @@ def run_gemm_rr(
...
@@ -608,7 +603,8 @@ def run_gemm_rr(
pass_configs
=
{
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_TMA_LOWER
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
tilelang
.
PassConfigKey
.
TL_DISABLE_WARP_SPECIALIZED
:
True
,
})
},
)
A
,
B
=
generate_dense_input
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
)
A
,
B
=
generate_dense_input
(
M
,
N
,
K
,
trans_A
,
trans_B
,
in_dtype
)
A_sparse
,
E
=
compress
(
A
,
transposed
=
trans_A
,
block_k
=
block_K
,
arch
=
"8.0"
)
A_sparse
,
E
=
compress
(
A
,
transposed
=
trans_A
,
block_k
=
block_K
,
arch
=
"8.0"
)
C_sp
=
kernel
(
A_sparse
,
E
,
B
)
C_sp
=
kernel
(
A_sparse
,
E
,
B
)
...
...
testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py
View file @
29051439
...
@@ -11,22 +11,14 @@ def _check(original, transformed):
...
@@ -11,22 +11,14 @@ def _check(original, transformed):
mod
=
tl
.
transform
.
Simplify
()(
mod
)
mod
=
tl
.
transform
.
Simplify
()(
mod
)
mod
=
tl
.
transform
.
LowerOpaqueBlock
()(
mod
)
mod
=
tl
.
transform
.
LowerOpaqueBlock
()(
mod
)
mod
=
tl
.
transform
.
Simplify
()(
mod
)
mod
=
tl
.
transform
.
Simplify
()(
mod
)
tvm
.
ir
.
assert_structural_equal
(
mod
[
"main"
],
transformed
.
with_attr
(
"global_symbol"
,
"main"
),
tvm
.
ir
.
assert_structural_equal
(
mod
[
"main"
],
transformed
.
with_attr
(
"global_symbol"
,
"main"
),
True
)
True
)
def
test_trival_pipeline
():
def
test_trival_pipeline
():
@
T
.
prim_func
@
T
.
prim_func
def
before
(
A
:
T
.
Tensor
((
16
,
1
),
"float32"
),
C
:
T
.
Tensor
((
16
,
1
),
"float32"
)):
def
before
(
A
:
T
.
Tensor
((
16
,
1
),
"float32"
),
C
:
T
.
Tensor
((
16
,
1
),
"float32"
)):
for
tx
in
T
.
thread_binding
(
0
,
16
,
thread
=
"threadIdx.x"
):
for
tx
in
T
.
thread_binding
(
0
,
16
,
thread
=
"threadIdx.x"
):
for
i
in
T
.
serial
(
for
i
in
T
.
serial
(
0
,
1
,
annotations
=
{
"software_pipeline_stage"
:
[
0
,
1
],
"software_pipeline_order"
:
[
0
,
1
]}):
0
,
1
,
annotations
=
{
"software_pipeline_stage"
:
[
0
,
1
],
"software_pipeline_order"
:
[
0
,
1
]
}):
with
T
.
block
():
with
T
.
block
():
T
.
reads
(
A
[
tx
,
i
])
T
.
reads
(
A
[
tx
,
i
])
T
.
writes
(
C
[
tx
,
i
])
T
.
writes
(
C
[
tx
,
i
])
...
...
testing/python/transform/test_tilelang_transform_cluster_planning.py
View file @
29051439
...
@@ -21,10 +21,8 @@ def _check(original, transformed):
...
@@ -21,10 +21,8 @@ def _check(original, transformed):
def
test_cluster_planning
():
def
test_cluster_planning
():
@
T
.
prim_func
@
T
.
prim_func
def
before
(
A
:
T
.
Tensor
((
1024
,
32
),
"float16"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float16"
),
C
:
T
.
Tensor
(
def
before
(
A
:
T
.
Tensor
((
1024
,
32
),
"float16"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float16"
),
C
:
T
.
Tensor
((
1024
,
1024
),
"float16"
)):
(
1024
,
1024
),
"float16"
)):
with
T
.
Kernel
(
8
,
8
,
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
8
,
8
,
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
128
,
32
),
"float16"
)
A_shared
=
T
.
alloc_shared
((
128
,
32
),
"float16"
)
B_shared
=
T
.
alloc_shared
((
32
,
128
),
"float16"
)
B_shared
=
T
.
alloc_shared
((
32
,
128
),
"float16"
)
...
@@ -41,8 +39,7 @@ def test_cluster_planning():
...
@@ -41,8 +39,7 @@ def test_cluster_planning():
T
.
copy
(
C_local
,
C
[
by
*
128
,
bx
*
128
])
T
.
copy
(
C_local
,
C
[
by
*
128
,
bx
*
128
])
@
T
.
prim_func
@
T
.
prim_func
def
after
(
A
:
T
.
Tensor
((
1024
,
32
),
"float16"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float16"
),
C
:
T
.
Tensor
(
def
after
(
A
:
T
.
Tensor
((
1024
,
32
),
"float16"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float16"
),
C
:
T
.
Tensor
((
1024
,
1024
),
"float16"
)):
(
1024
,
1024
),
"float16"
)):
T
.
func_attr
({
"clusterIdx.y"
:
T
.
int32
(
2
)})
T
.
func_attr
({
"clusterIdx.y"
:
T
.
int32
(
2
)})
with
T
.
Kernel
(
8
,
8
,
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
8
,
8
,
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
128
,
32
),
"float16"
)
A_shared
=
T
.
alloc_shared
((
128
,
32
),
"float16"
)
...
...
testing/python/transform/test_tilelang_transform_config_index_bitwidth.py
View file @
29051439
...
@@ -9,7 +9,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
...
@@ -9,7 +9,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_N
=
64
block_N
=
64
num_stages
=
0
num_stages
=
0
threads
=
128
threads
=
128
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
batch
=
T
.
int32
(
batch
)
batch
=
T
.
int32
(
batch
)
heads
=
T
.
int32
(
heads
)
heads
=
T
.
int32
(
heads
)
...
@@ -24,7 +24,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
...
@@ -24,7 +24,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask_dtype
=
"bool"
block_mask_dtype
=
"bool"
def
kernel_func
(
block_M
,
block_N
,
num_stages
,
threads
):
def
kernel_func
(
block_M
,
block_N
,
num_stages
,
threads
):
@
T
.
macro
@
T
.
macro
def
MMA0
(
def
MMA0
(
K
:
T
.
Tensor
(
shape
,
dtype
),
K
:
T
.
Tensor
(
shape
,
dtype
),
...
@@ -36,37 +35,36 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
...
@@ -36,37 +35,36 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by
:
T
.
int32
,
by
:
T
.
int32
,
bz
:
T
.
int32
,
bz
:
T
.
int32
,
):
):
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
if
is_causal
:
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
else
:
T
.
clear
(
acc_s
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
@
T
.
macro
@
T
.
macro
def
MMA1
(
def
MMA1
(
V
:
T
.
Tensor
(
shape
,
dtype
),
V
:
T
.
Tensor
(
shape
,
dtype
),
V_shared
:
T
.
Tensor
([
block_M
,
dim
],
dtype
),
V_shared
:
T
.
Tensor
([
block_M
,
dim
],
dtype
),
acc_s_cast
:
T
.
Tensor
([
block_M
,
block_N
],
dtype
),
acc_s_cast
:
T
.
Tensor
([
block_M
,
block_N
],
dtype
),
acc_o
:
T
.
Tensor
([
block_M
,
dim
],
accum_dtype
),
acc_o
:
T
.
Tensor
([
block_M
,
dim
],
accum_dtype
),
k
:
T
.
int32
,
k
:
T
.
int32
,
by
:
T
.
int32
,
by
:
T
.
int32
,
bz
:
T
.
int32
,
bz
:
T
.
int32
,
):
):
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
@
T
.
macro
@
T
.
macro
def
Softmax
(
def
Softmax
(
acc_s
:
T
.
Tensor
([
block_M
,
block_N
],
accum_dtype
),
acc_s
:
T
.
Tensor
([
block_M
,
block_N
],
accum_dtype
),
acc_s_cast
:
T
.
Tensor
([
block_M
,
block_N
],
dtype
),
acc_s_cast
:
T
.
Tensor
([
block_M
,
block_N
],
dtype
),
scores_max
:
T
.
Tensor
([
block_M
],
accum_dtype
),
scores_max
:
T
.
Tensor
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
Tensor
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
Tensor
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
Tensor
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
Tensor
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
Tensor
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
Tensor
([
block_M
],
accum_dtype
),
logsum
:
T
.
Tensor
([
block_M
],
accum_dtype
),
logsum
:
T
.
Tensor
([
block_M
],
accum_dtype
),
):
):
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -92,22 +90,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
...
@@ -92,22 +90,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@
T
.
macro
@
T
.
macro
def
Rescale
(
def
Rescale
(
acc_o
:
T
.
Tensor
([
block_M
,
dim
],
accum_dtype
),
acc_o
:
T
.
Tensor
([
block_M
,
dim
],
accum_dtype
),
scores_scale
:
T
.
Tensor
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
Tensor
([
block_M
],
accum_dtype
),
):
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
Q
:
T
.
Tensor
(
shape
,
dtype
),
Q
:
T
.
Tensor
(
shape
,
dtype
),
K
:
T
.
Tensor
(
shape
,
dtype
),
K
:
T
.
Tensor
(
shape
,
dtype
),
V
:
T
.
Tensor
(
shape
,
dtype
),
V
:
T
.
Tensor
(
shape
,
dtype
),
BlockSparseMask
:
T
.
Tensor
(
block_mask_shape
,
block_mask_dtype
),
BlockSparseMask
:
T
.
Tensor
(
block_mask_shape
,
block_mask_dtype
),
Output
:
T
.
Tensor
(
shape
,
dtype
),
Output
:
T
.
Tensor
(
shape
,
dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
T
.
ceildiv
(
seq_len
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
...
@@ -122,7 +119,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
...
@@ -122,7 +119,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
block_mask
=
T
.
alloc_local
([
downsample_len
],
block_mask_dtype
)
block_mask
=
T
.
alloc_local
([
downsample_len
],
block_mask_dtype
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
@@ -131,19 +128,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
...
@@ -131,19 +128,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask
[
vj
]
=
BlockSparseMask
[
bz
,
by
,
bx
,
vj
]
block_mask
[
vj
]
=
BlockSparseMask
[
bz
,
by
,
bx
,
vj
]
loop_range
=
(
loop_range
=
(
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
(
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
(
(
bx
+
1
)
*
block_M
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
)
(
bx
+
1
)
*
block_M
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
)
)
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
bx
,
by
,
bz
)
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
bx
,
by
,
bz
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
scores_sum
,
logsum
)
Rescale
(
acc_o
,
scores_scale
)
Rescale
(
acc_o
,
scores_scale
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
by
,
bz
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
by
,
bz
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
return
main
return
main
...
...
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py
View file @
29051439
...
@@ -22,7 +22,6 @@ def _check(original, transformed):
...
@@ -22,7 +22,6 @@ def _check(original, transformed):
def
test_lower_fence_proxy
():
def
test_lower_fence_proxy
():
@
T
.
prim_func
@
T
.
prim_func
def
before
():
def
before
():
with
T
.
Kernel
(
8
):
with
T
.
Kernel
(
8
):
...
@@ -30,12 +29,15 @@ def test_lower_fence_proxy():
...
@@ -30,12 +29,15 @@ def test_lower_fence_proxy():
B_shared
=
T
.
decl_buffer
((
1
,
4
,
512
),
"float16"
,
scope
=
"shared.dyn"
)
B_shared
=
T
.
decl_buffer
((
1
,
4
,
512
),
"float16"
,
scope
=
"shared.dyn"
)
C_local
=
T
.
decl_buffer
((
32
,),
scope
=
"local"
)
C_local
=
T
.
decl_buffer
((
32
,),
scope
=
"local"
)
for
i
in
T
.
unroll
(
16
):
for
i
in
T
.
unroll
(
16
):
C_local
[
i
*
2
:
i
*
2
+
2
]
=
T
.
Broadcast
(
T
.
float32
(
0
),
2
)
C_local
[
i
*
2
:
i
*
2
+
2
]
=
T
.
Broadcast
(
T
.
float32
(
0
),
2
)
T
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tl_gemm"
),
T
.
call_intrin
(
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>"
,
"handle"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
0
,
2048
,
1
),
tir
.
op
.
Op
.
get
(
"tl.tl_gemm"
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
0
,
2048
,
1
),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
))
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
0
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
0
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
),
)
@
T
.
prim_func
@
T
.
prim_func
def
after
():
def
after
():
...
@@ -44,19 +46,21 @@ def test_lower_fence_proxy():
...
@@ -44,19 +46,21 @@ def test_lower_fence_proxy():
B_shared
=
T
.
decl_buffer
((
1
,
4
,
512
),
"float16"
,
scope
=
"shared.dyn"
)
B_shared
=
T
.
decl_buffer
((
1
,
4
,
512
),
"float16"
,
scope
=
"shared.dyn"
)
C_local
=
T
.
decl_buffer
((
32
,),
scope
=
"local"
)
C_local
=
T
.
decl_buffer
((
32
,),
scope
=
"local"
)
for
i
in
T
.
unroll
(
16
):
for
i
in
T
.
unroll
(
16
):
C_local
[
i
*
2
:
i
*
2
+
2
]
=
T
.
Broadcast
(
T
.
float32
(
0
),
2
)
C_local
[
i
*
2
:
i
*
2
+
2
]
=
T
.
Broadcast
(
T
.
float32
(
0
),
2
)
T
.
fence_proxy_async
()
T
.
fence_proxy_async
()
T
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tl_gemm"
),
T
.
call_intrin
(
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>"
,
"handle"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
0
,
2048
,
1
),
tir
.
op
.
Op
.
get
(
"tl.tl_gemm"
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
0
,
2048
,
1
),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
))
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
0
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
0
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
),
)
_check
(
before
,
after
)
_check
(
before
,
after
)
def
test_async_to_generic_no_double_fence
():
def
test_async_to_generic_no_double_fence
():
@
T
.
prim_func
@
T
.
prim_func
def
before
():
def
before
():
with
T
.
Kernel
(
8
):
with
T
.
Kernel
(
8
):
...
@@ -90,7 +94,6 @@ def test_async_to_generic_no_double_fence():
...
@@ -90,7 +94,6 @@ def test_async_to_generic_no_double_fence():
def
test_proxy_hint_override
():
def
test_proxy_hint_override
():
@
T
.
prim_func
@
T
.
prim_func
def
before
():
def
before
():
with
T
.
Kernel
(
8
):
with
T
.
Kernel
(
8
):
...
@@ -123,7 +126,6 @@ def test_proxy_hint_override():
...
@@ -123,7 +126,6 @@ def test_proxy_hint_override():
def
test_tma_store_sync_injection
():
def
test_tma_store_sync_injection
():
@
T
.
prim_func
@
T
.
prim_func
def
before
():
def
before
():
with
T
.
Kernel
(
8
):
with
T
.
Kernel
(
8
):
...
@@ -154,7 +156,6 @@ def test_tma_store_sync_injection():
...
@@ -154,7 +156,6 @@ def test_tma_store_sync_injection():
def
test_wgmma_marked_async
():
def
test_wgmma_marked_async
():
@
T
.
prim_func
@
T
.
prim_func
def
before
():
def
before
():
with
T
.
Kernel
(
1
):
with
T
.
Kernel
(
1
):
...
@@ -164,9 +165,24 @@ def test_wgmma_marked_async():
...
@@ -164,9 +165,24 @@ def test_wgmma_marked_async():
C_local
=
T
.
decl_buffer
((
32
,),
"float16"
,
scope
=
"local"
)
C_local
=
T
.
decl_buffer
((
32
,),
"float16"
,
scope
=
"local"
)
A_shared
[
0
]
=
T
.
float16
(
0
)
A_shared
[
0
]
=
T
.
float16
(
0
)
T
.
warpgroup_arrive
()
T
.
warpgroup_arrive
()
T
.
ptx_wgmma_ss
(
"float16"
,
"m64n64k16"
,
T
.
bool
(
True
),
T
.
bool
(
True
),
"fp16"
,
"fp16"
,
T
.
ptx_wgmma_ss
(
"fp16"
,
desc_a
.
data
,
T
.
int32
(
0
),
desc_b
.
data
,
T
.
int32
(
0
),
C_local
.
data
,
"float16"
,
T
.
int32
(
0
),
T
.
bool
(
True
),
1
,
1
)
"m64n64k16"
,
T
.
bool
(
True
),
T
.
bool
(
True
),
"fp16"
,
"fp16"
,
"fp16"
,
desc_a
.
data
,
T
.
int32
(
0
),
desc_b
.
data
,
T
.
int32
(
0
),
C_local
.
data
,
T
.
int32
(
0
),
T
.
bool
(
True
),
1
,
1
,
)
mod
=
tvm
.
IRModule
.
from_expr
(
before
.
with_attr
(
"global_symbol"
,
"main"
))
mod
=
tvm
.
IRModule
.
from_expr
(
before
.
with_attr
(
"global_symbol"
,
"main"
))
mod
=
tvm
.
tir
.
transform
.
BindTarget
(
auto_target
)(
mod
)
mod
=
tvm
.
tir
.
transform
.
BindTarget
(
auto_target
)(
mod
)
...
...
testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py
View file @
29051439
...
@@ -35,26 +35,25 @@ def test_inject_set_max_nreg():
...
@@ -35,26 +35,25 @@ def test_inject_set_max_nreg():
T
.
mbarrier_wait_parity
(
T
.
get_mbarrier
(
k
%
3
+
3
),
T
.
bitwise_xor
(
k
//
3
%
2
,
1
))
T
.
mbarrier_wait_parity
(
T
.
get_mbarrier
(
k
%
3
+
3
),
T
.
bitwise_xor
(
k
//
3
%
2
,
1
))
if
v
-
128
==
0
:
if
v
-
128
==
0
:
T
.
tma_load
(
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
0
,
2
,
2
,
0
),
T
.
get_mbarrier
(
k
%
3
),
T
.
get_mbarrier
(
k
%
3
),
T
.
tvm_access_ptr
(
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
)
,
k
*
3
2
,
k
*
32
,
by
*
64
)
by
*
64
,
T
.
evaluate
(
)
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
)]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
)]))
else
:
else
:
# Consumer branch - should have set_max_nreg(240, 1)
# Consumer branch - should have set_max_nreg(240, 1)
for
k
in
range
(
16
):
for
k
in
range
(
16
):
T
.
mbarrier_wait_parity
(
T
.
get_mbarrier
(
k
%
3
),
k
//
3
%
2
)
T
.
mbarrier_wait_parity
(
T
.
get_mbarrier
(
k
%
3
),
k
//
3
%
2
)
T
.
call_extern
(
T
.
call_extern
(
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
"handle"
,
T
.
tvm_access_ptr
(
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
))
)
T
.
evaluate
(
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
+
3
)]))
tir
.
Call
(
"handle"
,
"tir.ptx_arrive_barrier"
,
[
T
.
get_mbarrier
(
k
%
3
+
3
)]))
# Apply the InjectSetMaxNReg pass
# Apply the InjectSetMaxNReg pass
func
=
before
func
=
before
...
@@ -67,15 +66,18 @@ def test_inject_set_max_nreg():
...
@@ -67,15 +66,18 @@ def test_inject_set_max_nreg():
set_max_nreg_calls
=
[]
set_max_nreg_calls
=
[]
def
collect_set_max_nreg
(
stmt
):
def
collect_set_max_nreg
(
stmt
):
if
(
isinstance
(
stmt
,
tvm
.
tir
.
Evaluate
)
and
hasattr
(
stmt
.
value
,
'op'
)
and
if
(
hasattr
(
stmt
.
value
.
op
,
'name'
)
and
stmt
.
value
.
op
.
name
==
"tl.set_max_nreg"
):
isinstance
(
stmt
,
tvm
.
tir
.
Evaluate
)
and
hasattr
(
stmt
.
value
,
"op"
)
and
hasattr
(
stmt
.
value
.
op
,
"name"
)
and
stmt
.
value
.
op
.
name
==
"tl.set_max_nreg"
):
set_max_nreg_calls
.
append
(
stmt
.
value
)
set_max_nreg_calls
.
append
(
stmt
.
value
)
tvm
.
tir
.
stmt_functor
.
post_order_visit
(
main_func
.
body
,
collect_set_max_nreg
)
tvm
.
tir
.
stmt_functor
.
post_order_visit
(
main_func
.
body
,
collect_set_max_nreg
)
# We should have at least 2 set_max_nreg calls (one for producer, one for consumer)
# We should have at least 2 set_max_nreg calls (one for producer, one for consumer)
assert
len
(
set_max_nreg_calls
assert
len
(
set_max_nreg_calls
)
>=
2
,
f
"Expected at least 2 set_max_nreg calls, got
{
len
(
set_max_nreg_calls
)
}
"
)
>=
2
,
f
"Expected at least 2 set_max_nreg calls, got
{
len
(
set_max_nreg_calls
)
}
"
print
(
"InjectSetMaxNReg test passed!"
)
print
(
"InjectSetMaxNReg test passed!"
)
...
@@ -116,16 +118,18 @@ def test_inject_set_max_nreg_no_set_max_nreg():
...
@@ -116,16 +118,18 @@ def test_inject_set_max_nreg_no_set_max_nreg():
set_max_nreg_calls
=
[]
set_max_nreg_calls
=
[]
def
collect_set_max_nreg
(
stmt
):
def
collect_set_max_nreg
(
stmt
):
if
(
isinstance
(
stmt
,
tvm
.
tir
.
Evaluate
)
and
hasattr
(
stmt
.
value
,
'op'
)
and
if
(
hasattr
(
stmt
.
value
.
op
,
'name'
)
and
stmt
.
value
.
op
.
name
==
"tl.set_max_nreg"
):
isinstance
(
stmt
,
tvm
.
tir
.
Evaluate
)
and
hasattr
(
stmt
.
value
,
"op"
)
and
hasattr
(
stmt
.
value
.
op
,
"name"
)
and
stmt
.
value
.
op
.
name
==
"tl.set_max_nreg"
):
set_max_nreg_calls
.
append
(
stmt
.
value
)
set_max_nreg_calls
.
append
(
stmt
.
value
)
tvm
.
tir
.
stmt_functor
.
post_order_visit
(
main_func
.
body
,
collect_set_max_nreg
)
tvm
.
tir
.
stmt_functor
.
post_order_visit
(
main_func
.
body
,
collect_set_max_nreg
)
# Should have no set_max_nreg calls when no_set_max_nreg is present
# Should have no set_max_nreg calls when no_set_max_nreg is present
assert
len
(
assert
len
(
set_max_nreg_calls
)
==
0
,
f
"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got
{
len
(
set_max_nreg_calls
)
}
"
set_max_nreg_calls
)
==
0
,
f
"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got
{
len
(
set_max_nreg_calls
)
}
"
print
(
"InjectSetMaxNReg with no_set_max_nreg test passed!"
)
print
(
"InjectSetMaxNReg with no_set_max_nreg test passed!"
)
...
...
testing/python/transform/test_tilelang_transform_layout_inference.py
View file @
29051439
...
@@ -8,17 +8,21 @@ import pytest
...
@@ -8,17 +8,21 @@ import pytest
auto_target
=
tvm
.
target
.
Target
(
determine_target
(
"auto"
))
auto_target
=
tvm
.
target
.
Target
(
determine_target
(
"auto"
))
@
pytest
.
mark
.
parametrize
(
"block_M, block_N, block_K, threads, vec_load_b, dtype"
,
[
@
pytest
.
mark
.
parametrize
(
(
64
,
64
,
32
,
128
,
8
,
"float16"
),
"block_M, block_N, block_K, threads, vec_load_b, dtype"
,
])
[
(
64
,
64
,
32
,
128
,
8
,
"float16"
),
],
)
def
test_loop_tail_split
(
block_M
,
block_N
,
block_K
,
threads
,
vec_load_b
,
dtype
):
def
test_loop_tail_split
(
block_M
,
block_N
,
block_K
,
threads
,
vec_load_b
,
dtype
):
N
=
tvm
.
te
.
var
(
"n"
)
N
=
tvm
.
te
.
var
(
"n"
)
K
=
tvm
.
te
.
var
(
"k"
)
K
=
tvm
.
te
.
var
(
"k"
)
def
before
():
def
before
():
@
T
.
prim_func
@
T
.
prim_func
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),):
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
):
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
thread_bindings
=
T
.
thread_binding
(
0
,
threads
,
"threadIdx.x"
)
thread_bindings
=
T
.
thread_binding
(
0
,
threads
,
"threadIdx.x"
)
...
@@ -26,58 +30,62 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
...
@@ -26,58 +30,62 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
t
=
thread_bindings
t
=
thread_bindings
for
i
in
T
.
unroll
(
0
,
block_N
*
block_K
//
(
threads
*
vec_load_b
)):
for
i
in
T
.
unroll
(
0
,
block_N
*
block_K
//
(
threads
*
vec_load_b
)):
for
vec
in
T
.
Parallel
(
vec_load_b
):
for
vec
in
T
.
Parallel
(
vec_load_b
):
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
B_shared
[
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
(
block_N
//
vec_load_b
)
+
vec
]
=
T
.
if_then_else
(
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
]
=
T
.
if_then_else
(
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
B
[
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
],
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
T
.
float16
(
0
))
],
T
.
float16
(
0
),
)
return
tvm
.
IRModule
({
'
main
'
:
main
})
return
tvm
.
IRModule
({
"
main
"
:
main
})
def
after
():
def
after
():
@
T
.
prim_func
@
T
.
prim_func
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),):
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
):
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
thread_bindings
=
T
.
thread_binding
(
0
,
threads
,
"threadIdx.x"
)
thread_bindings
=
T
.
thread_binding
(
0
,
threads
,
"threadIdx.x"
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
3
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
3
):
t
=
thread_bindings
t
=
thread_bindings
for
i
in
T
.
unroll
(
0
,
block_N
*
block_K
//
(
threads
*
vec_load_b
)):
for
i
in
T
.
unroll
(
0
,
block_N
*
block_K
//
(
threads
*
vec_load_b
)):
if
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
if
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
))
*
N
%
vec_load_b
==
0
:
(
block_N
//
vec_load_b
))
*
N
%
vec_load_b
==
0
:
for
vec
in
T
.
vectorized
(
vec_load_b
):
for
vec
in
T
.
vectorized
(
vec_load_b
):
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
B_shared
[
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
(
block_N
//
vec_load_b
)
+
vec
]
=
T
.
if_then_else
(
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
k
*
block_K
+
i
*
]
=
T
.
if_then_else
(
(
threads
*
vec_load_b
//
block_N
)
+
t
//
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
],
(
block_N
//
vec_load_b
)
+
vec
],
T
.
float16
(
0
))
T
.
float16
(
0
),
)
else
:
else
:
for
vec
in
T
.
serial
(
vec_load_b
):
for
vec
in
T
.
serial
(
vec_load_b
):
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
B_shared
[
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
(
block_N
//
vec_load_b
)
+
vec
]
=
T
.
if_then_else
(
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
k
*
block_K
+
i
*
]
=
T
.
if_then_else
(
(
threads
*
vec_load_b
//
block_N
)
+
t
//
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
],
(
block_N
//
vec_load_b
)
+
vec
],
T
.
float16
(
0
))
T
.
float16
(
0
),
)
return
tvm
.
IRModule
({
'
main
'
:
main
})
return
tvm
.
IRModule
({
"
main
"
:
main
})
with
tvm
.
target
.
Target
(
auto_target
):
with
tvm
.
target
.
Target
(
auto_target
):
mod
=
tvm
.
tir
.
transform
.
BindTarget
(
auto_target
)(
before
())
mod
=
tvm
.
tir
.
transform
.
BindTarget
(
auto_target
)(
before
())
...
...
testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py
View file @
29051439
...
@@ -8,7 +8,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
...
@@ -8,7 +8,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
dtype
=
"float32"
dtype
=
"float32"
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),):
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
M
,
N
),
dtype
=
dtype
)
A_shared
=
T
.
alloc_shared
((
M
,
N
),
dtype
=
dtype
)
tid
=
T
.
get_thread_binding
()
tid
=
T
.
get_thread_binding
()
...
@@ -16,17 +18,18 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
...
@@ -16,17 +18,18 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
A_shared
[
tid
,
j
]
=
A
[
tid
+
M_offset
,
j
+
N_offset
]
A_shared
[
tid
,
j
]
=
A
[
tid
+
M_offset
,
j
+
N_offset
]
@
T
.
prim_func
@
T
.
prim_func
def
expected
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),):
def
expected
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
M
,
N
),
dtype
=
dtype
)
A_shared
=
T
.
alloc_shared
((
M
,
N
),
dtype
=
dtype
)
tid
=
T
.
get_thread_binding
()
tid
=
T
.
get_thread_binding
()
T
.
reads
(
A
[
tid
+
M_offset
,
N_offset
:
N
+
N_offset
])
T
.
reads
(
A
[
tid
+
M_offset
,
N_offset
:
N
+
N_offset
])
for
j
in
T
.
serial
(
N
):
for
j
in
T
.
serial
(
N
):
A_shared
[
tid
,
j
]
=
T
.
if_then_else
(
A_shared
[
tid
,
j
]
=
T
.
if_then_else
(
j
+
N_offset
<
N
,
j
+
N_offset
<
N
,
T
.
if_then_else
(
tid
+
M_offset
<
M
,
A
[
tid
+
M_offset
,
j
+
N_offset
],
T
.
float32
(
0
)),
T
.
float32
(
0
)
T
.
if_then_else
(
tid
+
M_offset
<
M
,
A
[
tid
+
M_offset
,
j
+
N_offset
],
)
T
.
float32
(
0
)),
T
.
float32
(
0
))
return
main
,
expected
return
main
,
expected
...
@@ -41,13 +44,13 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
...
@@ -41,13 +44,13 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
def
issue_1013_buggy_kernel
():
def
issue_1013_buggy_kernel
():
# NOTE: This kernel is mainly to test some corner cases in boundary check
# NOTE: This kernel is mainly to test some corner cases in boundary check
num_tokens
=
T
.
dynamic
(
'
num_tokens
'
)
num_tokens
=
T
.
dynamic
(
"
num_tokens
"
)
num_threads
=
128
num_threads
=
128
@
T
.
prim_func
@
T
.
prim_func
def
main
(
x
:
T
.
Tensor
((
num_tokens
,),
dtype
=
"int64"
)):
def
main
(
x
:
T
.
Tensor
((
num_tokens
,),
dtype
=
"int64"
)):
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
count
=
T
.
alloc_var
(
'
int
'
)
count
=
T
.
alloc_var
(
"
int
"
)
thread_idx
=
T
.
get_thread_binding
()
thread_idx
=
T
.
get_thread_binding
()
for
i
in
T
.
serial
(
0
,
T
.
ceildiv
(
num_tokens
-
thread_idx
,
num_threads
)):
for
i
in
T
.
serial
(
0
,
T
.
ceildiv
(
num_tokens
-
thread_idx
,
num_threads
)):
idx
=
thread_idx
+
i
*
num_threads
idx
=
thread_idx
+
i
*
num_threads
...
@@ -59,24 +62,22 @@ def issue_1013_buggy_kernel():
...
@@ -59,24 +62,22 @@ def issue_1013_buggy_kernel():
@
T
.
prim_func
@
T
.
prim_func
def
expected
(
x
:
T
.
Tensor
((
num_tokens
,),
dtype
=
"int64"
)):
def
expected
(
x
:
T
.
Tensor
((
num_tokens
,),
dtype
=
"int64"
)):
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
with
T
.
Kernel
(
1
,
threads
=
num_threads
)
as
_
:
count
=
T
.
alloc_var
(
'
int
'
)
count
=
T
.
alloc_var
(
"
int
"
)
thread_idx
=
T
.
get_thread_binding
()
thread_idx
=
T
.
get_thread_binding
()
for
i
in
T
.
serial
(
0
,
T
.
ceildiv
(
num_tokens
-
thread_idx
,
num_threads
)):
for
i
in
T
.
serial
(
0
,
T
.
ceildiv
(
num_tokens
-
thread_idx
,
num_threads
)):
idx
=
thread_idx
+
i
*
num_threads
idx
=
thread_idx
+
i
*
num_threads
count
+=
T
.
Cast
(
"int32"
,
count
+=
T
.
Cast
(
"int32"
,
T
.
if_then_else
(
idx
<
num_tokens
,
x
[
idx
],
T
.
int64
(
0
))
==
T
.
int64
(
2
))
T
.
if_then_else
(
idx
<
num_tokens
,
x
[
idx
],
T
.
int64
(
0
))
==
T
.
int64
(
2
))
return
main
,
expected
return
main
,
expected
def
vectorize_access_with_atmoic_add_legalize
(
M
:
int
=
64
,
def
vectorize_access_with_atmoic_add_legalize
(
M
:
int
=
64
,
N
:
int
=
64
,
M_offset
:
int
=
2
,
N_offset
:
int
=
2
):
N
:
int
=
64
,
M_offset
:
int
=
2
,
N_offset
:
int
=
2
):
dtype
=
"float32"
dtype
=
"float32"
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),):
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
M
,
N
),
dtype
=
dtype
)
A_shared
=
T
.
alloc_shared
((
M
,
N
),
dtype
=
dtype
)
tid
=
T
.
get_thread_binding
()
tid
=
T
.
get_thread_binding
()
...
@@ -85,17 +86,18 @@ def vectorize_access_with_atmoic_add_legalize(M: int = 64,
...
@@ -85,17 +86,18 @@ def vectorize_access_with_atmoic_add_legalize(M: int = 64,
T
.
atomic_add
(
A
[
tid
+
M_offset
,
j
+
N_offset
],
1
)
T
.
atomic_add
(
A
[
tid
+
M_offset
,
j
+
N_offset
],
1
)
@
T
.
prim_func
@
T
.
prim_func
def
expected
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),):
def
expected
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
M
,
N
),
dtype
=
dtype
)
A_shared
=
T
.
alloc_shared
((
M
,
N
),
dtype
=
dtype
)
tid
=
T
.
get_thread_binding
()
tid
=
T
.
get_thread_binding
()
T
.
reads
(
A
[
tid
+
M_offset
,
N_offset
:
N
+
N_offset
])
T
.
reads
(
A
[
tid
+
M_offset
,
N_offset
:
N
+
N_offset
])
for
j
in
T
.
serial
(
N
):
for
j
in
T
.
serial
(
N
):
A_shared
[
tid
,
j
]
=
T
.
if_then_else
(
A_shared
[
tid
,
j
]
=
T
.
if_then_else
(
j
+
N_offset
<
N
,
j
+
N_offset
<
N
,
T
.
if_then_else
(
tid
+
M_offset
<
M
,
A
[
tid
+
M_offset
,
j
+
N_offset
],
T
.
float32
(
0
)),
T
.
float32
(
0
)
T
.
if_then_else
(
tid
+
M_offset
<
M
,
A
[
tid
+
M_offset
,
j
+
N_offset
],
)
T
.
float32
(
0
)),
T
.
float32
(
0
))
# Nest if-then-else is expected, do not flatten it to pass structural equal check
# Nest if-then-else is expected, do not flatten it to pass structural equal check
if
j
+
N_offset
<
N
:
# noqa: SIM102
if
j
+
N_offset
<
N
:
# noqa: SIM102
if
tid
+
M_offset
<
M
:
if
tid
+
M_offset
<
M
:
...
@@ -115,17 +117,21 @@ def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: in
...
@@ -115,17 +117,21 @@ def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: in
dtype
=
"float32"
dtype
=
"float32"
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),):
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
tid
=
T
.
get_thread_binding
()
tid
=
T
.
get_thread_binding
()
for
j
in
T
.
serial
(
N
):
for
j
in
T
.
serial
(
N
):
A
[
tid
+
M_offset
,
j
+
N_offset
]
=
1
A
[
tid
+
M_offset
,
j
+
N_offset
]
=
1
@
T
.
prim_func
@
T
.
prim_func
def
expected
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),):
def
expected
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
=
dtype
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
tid
=
T
.
get_thread_binding
()
tid
=
T
.
get_thread_binding
()
T
.
writes
(
A
[
tid
+
M_offset
,
N_offset
:
N
+
N_offset
])
T
.
writes
(
A
[
tid
+
M_offset
,
N_offset
:
N
+
N_offset
])
for
j
in
T
.
serial
(
N
):
for
j
in
T
.
serial
(
N
):
if
j
+
N_offset
<
N
:
# noqa: SIM102
if
j
+
N_offset
<
N
:
# noqa: SIM102
if
tid
+
M_offset
<
M
:
if
tid
+
M_offset
<
M
:
...
...
testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py
View file @
29051439
...
@@ -9,7 +9,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
...
@@ -9,7 +9,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
vec_len
=
8
vec_len
=
8
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
,
vec_len
),
dtype
=
"float32"
),):
def
main
(
A
:
T
.
Tensor
((
M
,
N
,
vec_len
),
dtype
=
"float32"
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
M
,
N
,
vec_len
),
dtype
=
dtype
)
A_shared
=
T
.
alloc_shared
((
M
,
N
,
vec_len
),
dtype
=
dtype
)
tid
=
T
.
get_thread_binding
()
tid
=
T
.
get_thread_binding
()
...
@@ -18,7 +20,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
...
@@ -18,7 +20,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
A_shared
[
tid
,
j
,
v
]
=
A
[
tid
,
j
,
v
]
A_shared
[
tid
,
j
,
v
]
=
A
[
tid
,
j
,
v
]
@
T
.
prim_func
@
T
.
prim_func
def
expected
(
A
:
T
.
Tensor
((
M
,
N
,
vec_len
),
dtype
=
"float32"
),):
def
expected
(
A
:
T
.
Tensor
((
M
,
N
,
vec_len
),
dtype
=
"float32"
),
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
with
T
.
Kernel
(
1
,
1
,
threads
=
M
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
M
,
N
,
vec_len
),
dtype
=
dtype
)
A_shared
=
T
.
alloc_shared
((
M
,
N
,
vec_len
),
dtype
=
dtype
)
tid
=
T
.
get_thread_binding
()
tid
=
T
.
get_thread_binding
()
...
...
testing/python/transform/test_tilelang_transform_let_inline.py
View file @
29051439
...
@@ -8,12 +8,10 @@ def _check(original, transformed):
...
@@ -8,12 +8,10 @@ def _check(original, transformed):
func
=
original
func
=
original
mod
=
tvm
.
IRModule
.
from_expr
(
func
.
with_attr
(
"global_symbol"
,
"main"
))
mod
=
tvm
.
IRModule
.
from_expr
(
func
.
with_attr
(
"global_symbol"
,
"main"
))
mod
=
tl
.
transform
.
LetInline
()(
mod
)
mod
=
tl
.
transform
.
LetInline
()(
mod
)
tvm
.
ir
.
assert_structural_equal
(
mod
[
"main"
],
transformed
.
with_attr
(
"global_symbol"
,
"main"
),
tvm
.
ir
.
assert_structural_equal
(
mod
[
"main"
],
transformed
.
with_attr
(
"global_symbol"
,
"main"
),
True
)
True
)
def
test_let_binding
():
def
test_let_binding
():
@
T
.
prim_func
@
T
.
prim_func
def
before
(
A
:
T
.
Tensor
((
128
,
128
),
"float32"
),
B
:
T
.
Tensor
((
128
,
128
),
"float32"
)):
def
before
(
A
:
T
.
Tensor
((
128
,
128
),
"float32"
),
B
:
T
.
Tensor
((
128
,
128
),
"float32"
)):
for
i
in
range
(
128
):
for
i
in
range
(
128
):
...
@@ -34,7 +32,6 @@ def test_let_binding():
...
@@ -34,7 +32,6 @@ def test_let_binding():
def
test_parallel_scope
():
def
test_parallel_scope
():
@
T
.
prim_func
@
T
.
prim_func
def
before
(
A
:
T
.
Tensor
((
128
,),
"float32"
)):
def
before
(
A
:
T
.
Tensor
((
128
,),
"float32"
)):
for
i
in
T
.
Parallel
(
128
):
for
i
in
T
.
Parallel
(
128
):
...
...
testing/python/transform/test_tilelang_transform_lower_hopper_intrin.py
View file @
29051439
...
@@ -24,7 +24,6 @@ def _check(original, transformed):
...
@@ -24,7 +24,6 @@ def _check(original, transformed):
def
test_lower_hopper_intrin_barrier
():
def
test_lower_hopper_intrin_barrier
():
@
T
.
prim_func
@
T
.
prim_func
def
before
():
def
before
():
with
T
.
Kernel
(
8
):
with
T
.
Kernel
(
8
):
...
@@ -37,18 +36,10 @@ def test_lower_hopper_intrin_barrier():
...
@@ -37,18 +36,10 @@ def test_lower_hopper_intrin_barrier():
v_1
=
T
.
launch_thread
(
"threadIdx.x"
,
128
)
v_1
=
T
.
launch_thread
(
"threadIdx.x"
,
128
)
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.create_barriers"
,
[
4
]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.create_barriers"
,
[
4
]))
with
T
.
If
(
v_1
==
0
),
T
.
Then
():
with
T
.
If
(
v_1
==
0
),
T
.
Then
():
T
.
evaluate
(
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
0
),
128
]))
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
1
),
128
]))
[
T
.
get_mbarrier
(
0
),
128
]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
2
),
128
]))
T
.
evaluate
(
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
3
),
128
]))
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
1
),
128
]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
2
),
128
]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.ptx_init_barrier_thread_count"
,
[
T
.
get_mbarrier
(
3
),
128
]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.tvm_storage_sync"
,
[
"shared"
]))
T
.
evaluate
(
tir
.
Call
(
"handle"
,
"tir.tvm_storage_sync"
,
[
"shared"
]))
_check
(
before
,
after
)
_check
(
before
,
after
)
...
...
testing/python/transform/test_tilelang_transform_lower_tile_op.py
View file @
29051439
...
@@ -8,63 +8,69 @@ import pytest
...
@@ -8,63 +8,69 @@ import pytest
auto_target
=
tvm
.
target
.
Target
(
determine_target
(
"auto"
))
auto_target
=
tvm
.
target
.
Target
(
determine_target
(
"auto"
))
@
pytest
.
mark
.
parametrize
(
"block_M, block_N, block_K, threads, vec_load_b, dtype"
,
[
@
pytest
.
mark
.
parametrize
(
(
64
,
64
,
32
,
128
,
8
,
"float16"
),
"block_M, block_N, block_K, threads, vec_load_b, dtype"
,
])
[
(
64
,
64
,
32
,
128
,
8
,
"float16"
),
],
)
def
test_loop_tail_split
(
block_M
,
block_N
,
block_K
,
threads
,
vec_load_b
,
dtype
):
def
test_loop_tail_split
(
block_M
,
block_N
,
block_K
,
threads
,
vec_load_b
,
dtype
):
N
=
tvm
.
te
.
var
(
"n"
)
N
=
tvm
.
te
.
var
(
"n"
)
K
=
tvm
.
te
.
var
(
"k"
)
K
=
tvm
.
te
.
var
(
"k"
)
def
before
():
def
before
():
@
T
.
prim_func
@
T
.
prim_func
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),):
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
):
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
3
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
3
):
T
.
copy
(
B
[
k
*
block_K
,
bx
*
block_N
],
B_shared
)
T
.
copy
(
B
[
k
*
block_K
,
bx
*
block_N
],
B_shared
)
return
tvm
.
IRModule
({
'
main
'
:
main
})
return
tvm
.
IRModule
({
"
main
"
:
main
})
def
after
():
def
after
():
@
T
.
prim_func
@
T
.
prim_func
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),):
def
main
(
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
threads
)
as
(
bx
):
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
thread_bindings
=
T
.
thread_binding
(
0
,
threads
,
"threadIdx.x"
)
thread_bindings
=
T
.
thread_binding
(
0
,
threads
,
"threadIdx.x"
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
3
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
3
):
t
=
thread_bindings
t
=
thread_bindings
for
i
in
T
.
unroll
(
0
,
block_N
*
block_K
//
(
threads
*
vec_load_b
)):
for
i
in
T
.
unroll
(
0
,
block_N
*
block_K
//
(
threads
*
vec_load_b
)):
if
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
if
(
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
))
*
N
%
vec_load_b
==
0
:
(
block_N
//
vec_load_b
))
*
N
%
vec_load_b
==
0
:
for
vec
in
T
.
vectorized
(
vec_load_b
):
for
vec
in
T
.
vectorized
(
vec_load_b
):
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
B_shared
[
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
(
block_N
//
vec_load_b
)
+
vec
]
=
T
.
if_then_else
(
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
k
*
block_K
+
i
*
]
=
T
.
if_then_else
(
(
threads
*
vec_load_b
//
block_N
)
+
t
//
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
],
(
block_N
//
vec_load_b
)
+
vec
],
T
.
float16
(
0
))
T
.
float16
(
0
),
)
else
:
else
:
for
vec
in
T
.
serial
(
vec_load_b
):
for
vec
in
T
.
serial
(
vec_load_b
):
B_shared
[
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
B_shared
[
(
block_N
//
vec_load_b
),
t
%
(
block_N
//
vec_load_b
)
*
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
(
block_N
//
vec_load_b
)
+
vec
]
=
T
.
if_then_else
(
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
k
*
block_K
+
i
*
]
=
T
.
if_then_else
(
(
threads
*
vec_load_b
//
block_N
)
+
t
//
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
)
<
K
(
block_N
//
vec_load_b
)
<
K
and
bx
*
block_N
+
t
%
and
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
<
N
,
B
[
B
[
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
k
*
block_K
+
i
*
(
threads
*
vec_load_b
//
block_N
)
+
t
//
(
block_N
//
vec_load_b
),
t
//
(
block_N
//
vec_load_b
),
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
(
block_N
//
vec_load_b
)
+
vec
,
bx
*
block_N
+
t
%
(
block_N
//
vec_load_b
)
*
],
(
block_N
//
vec_load_b
)
+
vec
],
T
.
float16
(
0
))
T
.
float16
(
0
),
)
return
tvm
.
IRModule
({
'
main
'
:
main
})
return
tvm
.
IRModule
({
"
main
"
:
main
})
with
tvm
.
transform
.
PassContext
():
with
tvm
.
transform
.
PassContext
():
mod
=
tvm
.
tir
.
transform
.
BindTarget
(
auto_target
)(
before
())
mod
=
tvm
.
tir
.
transform
.
BindTarget
(
auto_target
)(
before
())
...
...
testing/python/transform/test_tilelang_transform_make_packed_api.py
View file @
29051439
...
@@ -80,7 +80,6 @@ def test_target_host_removed():
...
@@ -80,7 +80,6 @@ def test_target_host_removed():
@
I
.
ir_module
@
I
.
ir_module
class
before
:
class
before
:
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Buffer
(
1
,
"float32"
)):
def
main
(
A
:
T
.
Buffer
(
1
,
"float32"
)):
T
.
func_attr
({
"global_symbol"
:
"main"
,
"target"
:
T
.
target
(
"cuda"
,
host
=
host
)})
T
.
func_attr
({
"global_symbol"
:
"main"
,
"target"
:
T
.
target
(
"cuda"
,
host
=
host
)})
...
@@ -102,7 +101,6 @@ def test_internal_subroutine_call():
...
@@ -102,7 +101,6 @@ def test_internal_subroutine_call():
@
I
.
ir_module
@
I
.
ir_module
class
before
:
class
before
:
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Buffer
(
1
,
"float32"
)):
def
main
(
A
:
T
.
Buffer
(
1
,
"float32"
)):
T
.
func_attr
({
"target"
:
T
.
target
(
"llvm"
,
host
=
"llvm"
)})
T
.
func_attr
({
"target"
:
T
.
target
(
"llvm"
,
host
=
"llvm"
)})
...
@@ -121,7 +119,8 @@ def test_internal_subroutine_call():
...
@@ -121,7 +119,8 @@ def test_internal_subroutine_call():
subroutine_call_op
=
compute_scope
.
body
.
value
.
op
subroutine_call_op
=
compute_scope
.
body
.
value
.
op
assert
isinstance
(
subroutine_call_op
,
tvm
.
ir
.
GlobalVar
),
(
assert
isinstance
(
subroutine_call_op
,
tvm
.
ir
.
GlobalVar
),
(
f
"The main function's CallNode should use the subroutine's GLobalVar as the operation, "
f
"The main function's CallNode should use the subroutine's GLobalVar as the operation, "
f
"but instead has an operation of type
{
subroutine_call_op
}
"
)
f
"but instead has an operation of type
{
subroutine_call_op
}
"
)
def
test_subroutine_call_to_externally_visible_subroutine
():
def
test_subroutine_call_to_externally_visible_subroutine
():
...
@@ -135,7 +134,6 @@ def test_subroutine_call_to_externally_visible_subroutine():
...
@@ -135,7 +134,6 @@ def test_subroutine_call_to_externally_visible_subroutine():
@
I
.
ir_module
@
I
.
ir_module
class
before
:
class
before
:
@
T
.
prim_func
@
T
.
prim_func
def
main
(
A
:
T
.
Buffer
(
1
,
"float32"
)):
def
main
(
A
:
T
.
Buffer
(
1
,
"float32"
)):
T
.
func_attr
({
"global_symbol"
:
"main"
,
"target"
:
T
.
target
(
"llvm"
,
host
=
"llvm"
)})
T
.
func_attr
({
"global_symbol"
:
"main"
,
"target"
:
T
.
target
(
"llvm"
,
host
=
"llvm"
)})
...
@@ -154,11 +152,10 @@ def test_subroutine_call_to_externally_visible_subroutine():
...
@@ -154,11 +152,10 @@ def test_subroutine_call_to_externally_visible_subroutine():
assert
subroutine_compute_scope
is
not
None
assert
subroutine_compute_scope
is
not
None
subroutine_call_op
=
main_compute_scope
.
body
.
value
.
op
subroutine_call_op
=
main_compute_scope
.
body
.
value
.
op
assert
(
assert
isinstance
(
subroutine_call_op
,
tvm
.
ir
.
Op
)
and
subroutine_call_op
.
name
==
"tir.tvm_call_cpacked"
,
(
isinstance
(
subroutine_call_op
,
tvm
.
ir
.
Op
)
and
f
"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', "
subroutine_call_op
.
name
==
"tir.tvm_call_cpacked"
f
"but instead has an operation of type
{
subroutine_call_op
}
"
),
(
f
"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', "
)
f
"but instead has an operation of type
{
subroutine_call_op
}
"
)
@
tilelang
.
testing
.
requires_llvm
@
tilelang
.
testing
.
requires_llvm
...
@@ -167,10 +164,10 @@ def test_function_call_with_wrong_argument_count():
...
@@ -167,10 +164,10 @@ def test_function_call_with_wrong_argument_count():
@
T
.
prim_func
@
T
.
prim_func
def
func
(
def
func
(
A
:
T
.
Buffer
([
16
,
16
],
"int32"
),
A
:
T
.
Buffer
([
16
,
16
],
"int32"
),
B
:
T
.
Buffer
([
16
,
16
],
"int32"
),
B
:
T
.
Buffer
([
16
,
16
],
"int32"
),
C
:
T
.
Buffer
([
16
,
16
],
"int32"
),
C
:
T
.
Buffer
([
16
,
16
],
"int32"
),
D
:
T
.
Buffer
([
16
,
16
],
"int32"
),
D
:
T
.
Buffer
([
16
,
16
],
"int32"
),
):
):
pass
pass
...
...
testing/python/transform/test_tilelang_transform_multi_version_buffer.py
View file @
29051439
...
@@ -31,7 +31,6 @@ block_K = 32
...
@@ -31,7 +31,6 @@ block_K = 32
def
test_multi_version_buffer
():
def
test_multi_version_buffer
():
@
T
.
prim_func
@
T
.
prim_func
def
before
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
)):
def
before
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
)):
bx
=
T
.
launch_thread
(
"blockIdx.x"
,
8
)
bx
=
T
.
launch_thread
(
"blockIdx.x"
,
8
)
...
@@ -49,21 +48,27 @@ def test_multi_version_buffer():
...
@@ -49,21 +48,27 @@ def test_multi_version_buffer():
for
k
in
T
.
serial
(
16
,
annotations
=
{
"num_stages"
:
T
.
int32
(
3
)}):
for
k
in
T
.
serial
(
16
,
annotations
=
{
"num_stages"
:
T
.
int32
(
3
)}):
if
v
==
0
:
if
v
==
0
:
T
.
tma_load
(
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
2
,
0
),
0
,
0
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
0
,
2048
,
2
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
0
,
2048
,
2
),
k
*
32
,
by
*
64
)
k
*
32
,
by
*
64
,
)
if
v
==
0
:
if
v
==
0
:
T
.
tma_load
(
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
2
,
0
),
2
,
0
),
0
,
0
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
0
,
2048
,
2
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
0
,
2048
,
2
),
bx
*
64
,
k
*
32
)
bx
*
64
,
k
*
32
,
)
T
.
call_extern
(
T
.
call_extern
(
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
0
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
0
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
0
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
0
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
))
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
),
)
@
T
.
prim_func
@
T
.
prim_func
def
after
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
)):
def
after
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
)):
...
@@ -82,31 +87,32 @@ def test_multi_version_buffer():
...
@@ -82,31 +87,32 @@ def test_multi_version_buffer():
for
k
in
T
.
serial
(
16
,
annotations
=
{
"num_stages"
:
T
.
int32
(
3
)}):
for
k
in
T
.
serial
(
16
,
annotations
=
{
"num_stages"
:
T
.
int32
(
3
)}):
if
v
==
0
:
if
v
==
0
:
T
.
tma_load
(
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
T
.
create_tma_descriptor
(
6
,
2
,
A
.
data
,
512
,
512
,
2
,
1024
,
32
,
64
,
1
,
1
,
0
,
2
,
2
,
0
),
2
,
0
),
0
,
0
,
T
.
tvm_access_ptr
(
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
k
*
32
,
k
*
32
,
by
*
64
)
by
*
64
,
)
if
v
==
0
:
if
v
==
0
:
T
.
tma_load
(
T
.
tma_load
(
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
T
.
create_tma_descriptor
(
6
,
2
,
B
.
data
,
512
,
512
,
2
,
1024
,
64
,
32
,
1
,
1
,
0
,
3
,
2
,
0
),
2
,
0
),
0
,
0
,
T
.
tvm_access_ptr
(
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
2
),
bx
*
64
,
bx
*
64
,
k
*
32
)
k
*
32
,
)
T
.
call_extern
(
T
.
call_extern
(
"handle"
,
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
"handle"
,
T
.
tvm_access_ptr
(
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>"
,
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
A_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float16"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
type_annotation
(
"float
16
"
),
B_shared
.
data
,
k
%
3
*
2048
,
2048
,
1
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float
32
"
),
C_local
.
data
,
0
,
32
,
3
),
T
.
tvm_access_ptr
(
T
.
type_annotation
(
"float32"
),
C_local
.
data
,
0
,
32
,
3
)
)
)
_check
(
before
,
after
)
_check
(
before
,
after
)
def
test_multi_version_buffer_with_let
():
def
test_multi_version_buffer_with_let
():
@
T
.
prim_func
@
T
.
prim_func
def
before
(
scales
:
T
.
Tensor
((
4
,),
"float32"
)):
def
before
(
scales
:
T
.
Tensor
((
4
,),
"float32"
)):
with
T
.
block
(
"root"
):
with
T
.
block
(
"root"
):
...
...
testing/python/transform/test_tilelang_transform_pipeline_planning.py
View file @
29051439
...
@@ -19,10 +19,8 @@ def _check(original, transformed):
...
@@ -19,10 +19,8 @@ def _check(original, transformed):
def
test_simple_pipeline
():
def
test_simple_pipeline
():
@
T
.
prim_func
@
T
.
prim_func
def
before
(
A
:
T
.
Tensor
((
1024
,
32
),
"float32"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float32"
),
C
:
T
.
Tensor
(
def
before
(
A
:
T
.
Tensor
((
1024
,
32
),
"float32"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float32"
),
C
:
T
.
Tensor
((
1024
,
1024
),
"float32"
)):
(
1024
,
1024
),
"float32"
)):
with
T
.
Kernel
(
8
,
8
,
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
8
,
8
,
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
128
,
32
),
"float32"
)
A_shared
=
T
.
alloc_shared
((
128
,
32
),
"float32"
)
B_shared
=
T
.
alloc_shared
((
32
,
128
),
"float32"
)
B_shared
=
T
.
alloc_shared
((
32
,
128
),
"float32"
)
...
@@ -39,8 +37,7 @@ def test_simple_pipeline():
...
@@ -39,8 +37,7 @@ def test_simple_pipeline():
T
.
copy
(
C_local
,
C
[
by
*
128
,
bx
*
128
])
T
.
copy
(
C_local
,
C
[
by
*
128
,
bx
*
128
])
@
T
.
prim_func
@
T
.
prim_func
def
after
(
A
:
T
.
Tensor
((
1024
,
32
),
"float32"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float32"
),
C
:
T
.
Tensor
(
def
after
(
A
:
T
.
Tensor
((
1024
,
32
),
"float32"
),
B
:
T
.
Tensor
((
32
,
1024
),
"float32"
),
C
:
T
.
Tensor
((
1024
,
1024
),
"float32"
)):
(
1024
,
1024
),
"float32"
)):
with
T
.
Kernel
(
8
,
8
,
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
8
,
8
,
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
128
,
32
),
"float32"
)
A_shared
=
T
.
alloc_shared
((
128
,
32
),
"float32"
)
B_shared
=
T
.
alloc_shared
((
32
,
128
),
"float32"
)
B_shared
=
T
.
alloc_shared
((
32
,
128
),
"float32"
)
...
@@ -49,14 +46,13 @@ def test_simple_pipeline():
...
@@ -49,14 +46,13 @@ def test_simple_pipeline():
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
ko
in
T
.
serial
(
for
ko
in
T
.
serial
(
32
,
32
,
annotations
=
{
annotations
=
{
"software_pipeline_async_stages"
:
[
T
.
int32
(
0
)],
"software_pipeline_async_stages"
:
[
T
.
int32
(
0
)],
"software_pipeline_order"
:
[
T
.
int32
(
0
),
T
.
int32
(
1
),
"software_pipeline_order"
:
[
T
.
int32
(
0
),
T
.
int32
(
1
),
T
.
int32
(
2
)],
T
.
int32
(
2
)],
"software_pipeline_stage"
:
[
T
.
int32
(
3
),
T
.
int32
(
3
),
T
.
int32
(
3
)],
"software_pipeline_stage"
:
[
T
.
int32
(
3
),
T
.
int32
(
3
),
},
T
.
int32
(
3
)]
):
}):
T
.
copy
(
A
[
by
*
128
,
ko
*
32
],
A_shared
)
T
.
copy
(
A
[
by
*
128
,
ko
*
32
],
A_shared
)
T
.
copy
(
B
[
ko
*
32
,
bx
*
128
],
B_shared
)
T
.
copy
(
B
[
ko
*
32
,
bx
*
128
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
)
...
...
testing/python/transform/test_tilelang_transform_simplify.py
View file @
29051439
...
@@ -8,14 +8,13 @@ def modify(
...
@@ -8,14 +8,13 @@ def modify(
with_B
:
bool
=
False
,
with_B
:
bool
=
False
,
with_bias
:
bool
=
False
,
with_bias
:
bool
=
False
,
):
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
64
,
64
)),
A
:
T
.
Tensor
((
64
,
64
)),
B
:
T
.
Tensor
((
64
,
64
)),
B
:
T
.
Tensor
((
64
,
64
)),
C
:
T
.
Tensor
((
64
,
64
)),
C
:
T
.
Tensor
((
64
,
64
)),
D
:
T
.
Tensor
((
64
,
64
)),
D
:
T
.
Tensor
((
64
,
64
)),
bias
:
T
.
Tensor
((
64
,
64
)),
bias
:
T
.
Tensor
((
64
,
64
)),
):
):
if
with_B
:
if
with_B
:
if
with_bias
:
if
with_bias
:
...
@@ -42,7 +41,6 @@ def test_modify(with_B=False, with_bias=False):
...
@@ -42,7 +41,6 @@ def test_modify(with_B=False, with_bias=False):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
a
:
T
.
handle
,
a
:
T
.
handle
,
...
@@ -76,6 +74,7 @@ def test_matmul():
...
@@ -76,6 +74,7 @@ def test_matmul():
kernel
=
tl
.
compile
(
mod
[
"main"
],
out_idx
=
[
2
])
kernel
=
tl
.
compile
(
mod
[
"main"
],
out_idx
=
[
2
])
import
torch
import
torch
a
=
torch
.
randn
(
1024
,
1024
,
dtype
=
torch
.
float16
).
cuda
().
half
()
a
=
torch
.
randn
(
1024
,
1024
,
dtype
=
torch
.
float16
).
cuda
().
half
()
b
=
torch
.
randn
(
1024
,
1024
,
dtype
=
torch
.
float16
).
cuda
().
half
()
b
=
torch
.
randn
(
1024
,
1024
,
dtype
=
torch
.
float16
).
cuda
().
half
()
c
=
kernel
(
a
,
b
)
c
=
kernel
(
a
,
b
)
...
...
testing/python/transform/test_tilelang_transform_thread_sync.py
View file @
29051439
...
@@ -11,11 +11,7 @@ def run_passes(func: tvm.tir.PrimFunc):
...
@@ -11,11 +11,7 @@ def run_passes(func: tvm.tir.PrimFunc):
cuda_target
=
tvm
.
target
.
Target
(
"cuda"
,
host
=
"llvm"
)
cuda_target
=
tvm
.
target
.
Target
(
"cuda"
,
host
=
"llvm"
)
mod
=
tvm
.
tir
.
transform
.
Apply
(
lambda
f
:
f
.
with_attr
({
mod
=
tvm
.
tir
.
transform
.
Apply
(
lambda
f
:
f
.
with_attr
({
"global_symbol"
:
"test"
,
"target"
:
cuda_target
}))(
mod
)
"global_symbol"
:
"test"
,
"target"
:
cuda_target
}))(
mod
)
mod
=
tvm
.
tir
.
transform
.
AnnotateDeviceRegions
()(
mod
)
mod
=
tvm
.
tir
.
transform
.
AnnotateDeviceRegions
()(
mod
)
mod
=
tvm
.
tir
.
transform
.
SplitHostDevice
()(
mod
)
mod
=
tvm
.
tir
.
transform
.
SplitHostDevice
()(
mod
)
...
@@ -24,7 +20,6 @@ def run_passes(func: tvm.tir.PrimFunc):
...
@@ -24,7 +20,6 @@ def run_passes(func: tvm.tir.PrimFunc):
@
tilelang
.
testing
.
requires_cuda
@
tilelang
.
testing
.
requires_cuda
def
test_sync_if_with_same_index
():
def
test_sync_if_with_same_index
():
@
T
.
prim_func
(
check_well_formed
=
False
)
@
T
.
prim_func
(
check_well_formed
=
False
)
def
func
(
p0_arg
:
T
.
Buffer
((
1
,
2
,
1
,
1
),
"float32"
),
p1
:
T
.
Buffer
(
2
,
"float32"
))
->
None
:
def
func
(
p0_arg
:
T
.
Buffer
((
1
,
2
,
1
,
1
),
"float32"
),
p1
:
T
.
Buffer
(
2
,
"float32"
))
->
None
:
threadIdx_x
=
T
.
env_thread
(
"threadIdx.x"
)
threadIdx_x
=
T
.
env_thread
(
"threadIdx.x"
)
...
@@ -47,7 +42,6 @@ def test_sync_if_with_same_index():
...
@@ -47,7 +42,6 @@ def test_sync_if_with_same_index():
@
tilelang
.
testing
.
requires_cuda
@
tilelang
.
testing
.
requires_cuda
def
test_sync_read_thread_id_independent_location
():
def
test_sync_read_thread_id_independent_location
():
@
T
.
prim_func
@
T
.
prim_func
def
func
(
p0_arg
:
T
.
Buffer
((
1
,
2
,
1
,
1
),
"float32"
),
p1
:
T
.
Buffer
(
2
,
"float32"
))
->
None
:
def
func
(
p0_arg
:
T
.
Buffer
((
1
,
2
,
1
,
1
),
"float32"
),
p1
:
T
.
Buffer
(
2
,
"float32"
))
->
None
:
threadIdx_x
=
T
.
env_thread
(
"threadIdx.x"
)
threadIdx_x
=
T
.
env_thread
(
"threadIdx.x"
)
...
@@ -71,7 +65,6 @@ def test_sync_read_thread_id_independent_location():
...
@@ -71,7 +65,6 @@ def test_sync_read_thread_id_independent_location():
@
tilelang
.
testing
.
requires_cuda
@
tilelang
.
testing
.
requires_cuda
def
test_sync_shared
():
def
test_sync_shared
():
@
T
.
prim_func
(
private
=
True
)
@
T
.
prim_func
(
private
=
True
)
def
func
(
A
:
T
.
Buffer
((
4
,
4
),
"float32"
),
E
:
T
.
Buffer
((
4
,
4
),
"float32"
)):
def
func
(
A
:
T
.
Buffer
((
4
,
4
),
"float32"
),
E
:
T
.
Buffer
((
4
,
4
),
"float32"
)):
blockIdx_x
=
T
.
launch_thread
(
"blockIdx.x"
,
1
)
blockIdx_x
=
T
.
launch_thread
(
"blockIdx.x"
,
1
)
...
@@ -113,7 +106,6 @@ def test_sync_shared():
...
@@ -113,7 +106,6 @@ def test_sync_shared():
@
tvm
.
testing
.
requires_cuda
@
tvm
.
testing
.
requires_cuda
def
test_sync_let_stmt
():
def
test_sync_let_stmt
():
@
T
.
prim_func
(
private
=
True
)
@
T
.
prim_func
(
private
=
True
)
def
func
(
A
:
T
.
Buffer
((
16
*
512
),
"float32"
)):
def
func
(
A
:
T
.
Buffer
((
16
*
512
),
"float32"
)):
blockIdx_x
=
T
.
launch_thread
(
"blockIdx.x"
,
16
)
blockIdx_x
=
T
.
launch_thread
(
"blockIdx.x"
,
16
)
...
@@ -136,9 +128,9 @@ def test_sync_let_stmt():
...
@@ -136,9 +128,9 @@ def test_sync_let_stmt():
in_thread_A_temp_1
[
0
]
=
A_temp
in_thread_A_temp_1
[
0
]
=
A_temp
cross_thread_A_temp_1
=
T
.
Buffer
((
1
,),
data
=
cross_thread_A_temp
,
scope
=
"local"
)
cross_thread_A_temp_1
=
T
.
Buffer
((
1
,),
data
=
cross_thread_A_temp
,
scope
=
"local"
)
with
T
.
attr
(
with
T
.
attr
(
T
.
comm_reducer
(
lambda
x0
,
y0
:
x0
+
y0
,
[
T
.
float32
(
0
)]),
T
.
comm_reducer
(
lambda
x0
,
y0
:
x0
+
y0
,
[
T
.
float32
(
0
)]),
"reduce_scope"
,
"reduce_scope"
,
T
.
reinterpret
(
"handle"
,
T
.
uint64
(
0
)),
T
.
reinterpret
(
"handle"
,
T
.
uint64
(
0
)),
):
):
T
.
tvm_thread_allreduce
(
T
.
tvm_thread_allreduce
(
T
.
uint32
(
1
),
T
.
uint32
(
1
),
...
@@ -190,16 +182,19 @@ def test_sync_let_stmt():
...
@@ -190,16 +182,19 @@ def test_sync_let_stmt():
@
tilelang
.
testing
.
requires_cuda
@
tilelang
.
testing
.
requires_cuda
def
test_sync_shared_dyn_stmatrix_loop_hoist
():
def
test_sync_shared_dyn_stmatrix_loop_hoist
():
@
T
.
prim_func
@
T
.
prim_func
def
func
():
def
func
():
buf_dyn_shmem
=
T
.
alloc_buffer
((
98304
,),
"uint8"
,
scope
=
"shared.dyn"
)
buf_dyn_shmem
=
T
.
alloc_buffer
((
98304
,),
"uint8"
,
scope
=
"shared.dyn"
)
tx
=
T
.
launch_thread
(
"threadIdx.x"
,
384
)
tx
=
T
.
launch_thread
(
"threadIdx.x"
,
384
)
for
i
in
T
.
unroll
(
8
):
for
i
in
T
.
unroll
(
8
):
off
=
(
off
=
(
i
//
4
*
8192
+
tx
//
32
*
1024
+
tx
%
16
*
64
+
i
//
4
*
8192
(
tx
%
8
//
4
+
i
%
4
//
2
)
%
2
*
32
+
(
tx
%
4
//
2
+
i
%
2
)
%
2
*
16
+
+
tx
//
32
*
1024
(
tx
%
32
//
16
+
tx
%
2
)
%
2
*
8
)
+
tx
%
16
*
64
+
(
tx
%
8
//
4
+
i
%
4
//
2
)
%
2
*
32
+
(
tx
%
4
//
2
+
i
%
2
)
%
2
*
16
+
(
tx
%
32
//
16
+
tx
%
2
)
%
2
*
8
)
T
.
evaluate
(
T
.
evaluate
(
T
.
call_intrin
(
T
.
call_intrin
(
"handle"
,
"handle"
,
...
@@ -214,7 +209,8 @@ def test_sync_shared_dyn_stmatrix_loop_hoist():
...
@@ -214,7 +209,8 @@ def test_sync_shared_dyn_stmatrix_loop_hoist():
2
,
2
,
),
),
T
.
int32
(
2
),
T
.
int32
(
2
),
))
)
)
mod
=
tvm
.
IRModule
({
"main"
:
func
})
mod
=
tvm
.
IRModule
({
"main"
:
func
})
mod
=
tilelang
.
transform
.
ThreadSync
(
"shared.dyn"
)(
mod
)
mod
=
tilelang
.
transform
.
ThreadSync
(
"shared.dyn"
)(
mod
)
...
...
Prev
1
…
12
13
14
15
16
17
18
19
20
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment