OpenDAS / tilelang · Commits

Commit 60567ba3 (unverified), authored Oct 28, 2025 by Jiaxing Ding, committed by GitHub on Oct 28, 2025

[AMD] Support T.gemm_v2 for AMD Backend (#1136)

parent 7d389a43

Showing 5 changed files with 1132 additions and 35 deletions:
    examples/plot_layout/fragment_mfma_load_a.py                      +133  -0
    testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py  +501  -0
    tilelang/intrinsics/mfma_macro_generator.py                       +272  -31
    tilelang/tileop/gemm/__init__.py                                  +11   -4
    tilelang/tileop/gemm/gemm_mfma.py                                 +215  -0
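As orientation before the per-file listings: the new AMD path is driven through T.gemm_v2. Below is a minimal shared/shared kernel sketch distilled from the test file added in this commit; the 1024/128/32 tile sizes and the pass configs are the test defaults, not requirements of the backend.

import tilelang
import tilelang.language as T


# Minimal T.gemm_v2 kernel on the shared/shared path, mirroring the new tests.
@T.prim_func
def gemm_v2_demo(
        A: T.Tensor((1024, 1024), "float16"),
        B: T.Tensor((1024, 1024), "float16"),
        C: T.Tensor((1024, 1024), "float16"),
):
    with T.Kernel(T.ceildiv(1024, 128), T.ceildiv(1024, 128), threads=256) as (bx, by):
        A_shared = T.alloc_shared((128, 32), "float16")
        B_shared = T.alloc_shared((32, 128), "float16")
        C_local = T.alloc_fragment((128, 128), "float32")
        T.clear(C_local)
        for k in T.Pipelined(T.ceildiv(1024, 32), num_stages=3):
            T.copy(A[by * 128, k * 32], A_shared)
            T.copy(B[k * 32, bx * 128], B_shared)
            # Routed through the Matrix Core (MFMA) lowering added in this commit.
            T.gemm_v2(A_shared, B_shared, C_local, False, False)
        T.copy(C_local, C[by * 128, bx * 128])


# The AMD tests compile with TMA lowering and warp specialization disabled:
kernel = tilelang.compile(
    gemm_v2_demo,
    out_idx=[2],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })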
examples/plot_layout/fragment_mfma_load_a.py (new file, mode 100644)
import tilelang.language as T
from typing import Literal, Callable
from tvm.tir import IndexMap
from tilelang.intrinsics.utils import get_mma_micro_size
from tilelang.intrinsics.mfma_layout import (
    shared_16x4_to_local_64x1_layout_A,
    shared_16x16_to_local_64x4_layout_A,
    shared_16x32_to_local_64x8_layout_A,
    shared_16x64_to_local_64x16_layout_A,
)


def make_mfma_load_base_layout(dtype: str = "float16",
                               matrix: Literal["A", "B"] = "A",
                               k_dim: int = 16,
                               transposed: bool = False) -> T.Fragment:
    """
    Create a layout function for loading an MFMA operand into a fragment buffer.

    This layout is used in conjunction with `inverse_mfma_load_layout` to
    map fragment indices to threads and local indices.

    Parameters
    ----------
    dtype : str
        The data type of the matrix.
    matrix : Literal["A", "B"]
        The mfma operand to be loaded.
    k_dim : int
        The k dimension of the mfma.
    transposed : bool
        Whether the matrix is transposed, by default False.

    Returns
    -------
    T.Fragment
        Describes how threads and indices in the fragment are laid out.
    """
    assert matrix in ["A", "B"], "matrix should be either A or B"

    # s represents a spatial axis, r a reduction axis:
    # sr means the two dims are spatial + reduction,
    # rs means the two dims are reduction + spatial.
    transform_func_sr_a: Callable = None
    transform_func_sr_b: Callable = None
    if k_dim == 4:
        transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
        transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
    elif k_dim == 16:
        transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
        transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
    elif k_dim == 32:
        transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
        transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
    elif k_dim == 64:
        transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
        transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
    else:
        raise ValueError("k_dim must be 4, 16, 32 or 64 currently")

    is_sr_conditions = [False]
    is_sr_conditions.append(matrix == "A" and not transposed)
    is_sr_conditions.append(matrix == "B" and transposed)
    is_sr_axis_order = any(is_sr_conditions)

    micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype)

    # The layout of mma.sync is row.col, so the B matrix expects a
    # transposed basic layout.
    transform_func: Callable = None
    if matrix == "A":
        transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
        micro_size_s, micro_size_r = micro_size_x, micro_size_k
    elif matrix == "B":
        transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
        micro_size_s, micro_size_r = micro_size_k, micro_size_y
    else:
        raise ValueError(f"Unsupported matrix {matrix}")

    inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")

    def forward_thread(i: int, j: int) -> int:
        """
        Given the row index `i` and column index `j` in the fragment,
        map them to the lane (thread) that holds the element.
        """
        lane_id, _ = inverse_mfma_load_layout.map_indices([i, j])
        return lane_id

    def forward_index(i: int, j: int) -> int:
        """
        Given the row index `i` and column index `j` in the fragment,
        map them to the local index within a single thread.
        """
        _, local_id = inverse_mfma_load_layout.map_indices([i, j])
        return local_id

    base_fragment = T.Fragment(
        [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
        forward_thread_fn=forward_thread,
        forward_index_fn=forward_index,
    )
    return base_fragment


block_rows = 2
block_cols = 2
warp_rows = 2
warp_cols = 2
chunk = 2

from tilelang.tools import plot_layout

# ldmatrix layout 16x16
base_layout = make_mfma_load_base_layout(dtype="float16", matrix="A", transposed=False)
print(base_layout)
plot_layout(base_layout, name="base_layout")

# warp layout 32x32
warp_layout = base_layout.repeat([warp_rows, warp_cols], repeat_on_thread=False, lower_dim_first=False)
print(warp_layout)
plot_layout(warp_layout, name="warp_layout")

# block layout 64x32
block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True,
                                  lower_dim_first=True).replicate(block_cols)
print(block_layout)
plot_layout(block_layout, name="block_layout")
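The same helper also accepts the B operand; a small follow-on sketch (same APIs as the script above, `base_layout_b` is an illustrative name) plots the transposed-basic-layout case for comparison:

# Sketch: plot the B-operand base layout (16x16 for float16). The rs axis
# order applies here because B is not transposed, mirroring the branch
# logic in make_mfma_load_base_layout.
base_layout_b = make_mfma_load_base_layout(dtype="float16", matrix="B", transposed=False)
plot_layout(base_layout_b, name="base_layout_b")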
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_amd.py (new file, mode 100644)
from tilelang import tvm as tvm
import tilelang.testing


def matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
           out_dtype, accum_dtype, num_stages, threads):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B)
                # T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum,
                block_M, block_N, block_K, num_stages=3, num_threads=256):
    program = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B,
                     in_dtype, out_dtype, dtypeAccum, num_stages, num_threads)
    kernel = tilelang.compile(
        program,
        out_idx=[2],
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

    latency = profiler.do_bench(profiler.func, warmup=100)
    print(f"GEMM SS latency: {latency} ms")


def test_gemm_ss():
    # GEMM tests for float16
    run_gemm_ss(1024, 1024, 1024, False, False, "float16", "float16", "float32", 128, 128, 32)
    run_gemm_ss(1024, 1024, 1024, False, True, "float16", "float16", "float32", 128, 128, 32)
    run_gemm_ss(1024, 1024, 1024, True, False, "float16", "float16", "float32", 128, 128, 32)
    run_gemm_ss(1024, 1024, 1024, True, True, "float16", "float16", "float32", 128, 128, 32)
    # GEMM tests for int8
    run_gemm_ss(1024, 1024, 1024, False, True, "int8", "int8", "int32", 128, 128, 32)
    run_gemm_ss(1024, 1024, 1024, False, False, "int8", "int8", "int32", 128, 128, 32)
    run_gemm_ss(1024, 1024, 1024, True, False, "int8", "int8", "int32", 128, 128, 32)
    run_gemm_ss(1024, 1024, 1024, True, True, "int8", "int8", "int32", 128, 128, 32)


def matmul_rs(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
              out_dtype, accum_dtype, num_stages, threads):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    A_frag_shape = A_shared_shape

    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            T.annotate_layout({
                A_shared: tilelang.layout.make_swizzled_layout(A_shared),
            })
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.copy(A_shared, A_frag)
                T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum,
                block_M, block_N, block_K, num_stages=3, num_threads=256):
    program = matmul_rs(M, N, K, block_M, block_N, block_K, trans_A, trans_B,
                        in_dtype, out_dtype, dtypeAccum, num_stages, num_threads)
    kernel = tilelang.compile(
        program,
        out_idx=[2],
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_gemm_rs():
    # GEMM tests for float16
    run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float16", "float32", 128, 128, 32)
    run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float16", "float32", 128, 128, 32)
    run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float16", "float32", 128, 128, 32)
    run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float16", "float32", 128, 128, 32)
    # GEMM tests for int8
    run_gemm_rs(1024, 1024, 1024, False, True, "int8", "int8", "int32", 128, 128, 32)
    run_gemm_rs(1024, 1024, 1024, False, False, "int8", "int8", "int32", 128, 128, 32)
    run_gemm_rs(1024, 1024, 1024, True, False, "int8", "int8", "int32", 128, 128, 32)
    run_gemm_rs(1024, 1024, 1024, True, True, "int8", "int8", "int32", 128, 128, 32)


def matmul_sr(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
              out_dtype, accum_dtype, num_stages, threads):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    B_frag_shape = B_shared_shape

    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            T.annotate_layout({
                B_shared: tilelang.layout.make_swizzled_layout(B_shared),
            })
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.copy(B_shared, B_frag)
                T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum,
                block_M, block_N, block_K, num_stages=3, num_threads=256):
    program = matmul_sr(M, N, K, block_M, block_N, block_K, trans_A, trans_B,
                        in_dtype, out_dtype, dtypeAccum, num_stages, num_threads)
    kernel = tilelang.compile(
        program,
        out_idx=[2],
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_gemm_sr():
    # GEMM tests for float16
    run_gemm_sr(1024, 1024, 1024, False, False, "float16", "float16", "float32", 128, 128, 32)
    run_gemm_sr(1024, 1024, 1024, False, True, "float16", "float16", "float32", 128, 128, 32)
    run_gemm_sr(1024, 1024, 1024, True, False, "float16", "float16", "float32", 128, 128, 32)
    run_gemm_sr(1024, 1024, 1024, True, True, "float16", "float16", "float32", 128, 128, 32)
    # GEMM tests for int8
    run_gemm_sr(1024, 1024, 1024, False, True, "int8", "int8", "int32", 128, 128, 32)
    run_gemm_sr(1024, 1024, 1024, False, False, "int8", "int8", "int32", 128, 128, 32)
    run_gemm_sr(1024, 1024, 1024, True, False, "int8", "int8", "int32", 128, 128, 32)
    run_gemm_sr(1024, 1024, 1024, True, True, "int8", "int8", "int32", 128, 128, 32)


def matmul_rr(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
              out_dtype, accum_dtype, num_stages, threads):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    A_frag_shape = A_shared_shape
    B_frag_shape = B_shared_shape

    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
            B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            T.annotate_layout({
                A_shared: tilelang.layout.make_swizzled_layout(A_shared),
                B_shared: tilelang.layout.make_swizzled_layout(B_shared),
            })
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.copy(A_shared, A_frag)
                T.copy(B_shared, B_frag)
                T.gemm_v2(A_frag, B_frag, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum,
                block_M, block_N, block_K, num_stages=3, num_threads=256):
    program = matmul_rr(M, N, K, block_M, block_N, block_K, trans_A, trans_B,
                        in_dtype, out_dtype, dtypeAccum, num_stages, num_threads)
    kernel = tilelang.compile(
        program,
        out_idx=[2],
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    print(program)
    print(kernel.get_kernel_source())
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_gemm_rr():
    # GEMM tests for float16
    run_gemm_rr(1024, 1024, 1024, False, False, "float16", "float16", "float32", 128, 128, 32)
    run_gemm_rr(1024, 1024, 1024, False, True, "float16", "float16", "float32", 128, 128, 32)
    run_gemm_rr(1024, 1024, 1024, True, False, "float16", "float16", "float32", 128, 128, 32)
    run_gemm_rr(1024, 1024, 1024, True, True, "float16", "float16", "float32", 128, 128, 32)
    # GEMM tests for int8
    run_gemm_rr(1024, 1024, 1024, False, True, "int8", "int8", "int32", 128, 128, 32)
    run_gemm_rr(1024, 1024, 1024, False, False, "int8", "int8", "int32", 128, 128, 32)
    run_gemm_rr(1024, 1024, 1024, True, False, "int8", "int8", "int32", 128, 128, 32)
    run_gemm_rr(1024, 1024, 1024, True, True, "int8", "int8", "int32", 128, 128, 32)


if __name__ == "__main__":
    tilelang.testing.main()
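For concreteness, with the configuration the suites use above (M = N = K = 1024, block_M = block_N = 128, block_K = 32, 256 threads), the launch resolves to an 8 x 8 block grid with 32 pipelined k-iterations per block; a quick arithmetic check in plain Python, mirroring what T.ceildiv computes inside T.Kernel and T.Pipelined:

# Worked launch arithmetic for the test configuration above.
M = N = K = 1024
block_M = block_N = 128
block_K = 32
grid_x, grid_y = (N + block_N - 1) // block_N, (M + block_M - 1) // block_M
k_iters = (K + block_K - 1) // block_K
print((grid_x, grid_y), k_iters)  # (8, 8) 32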
tilelang/intrinsics/mfma_macro_generator.py (modified)
@@ -2,10 +2,32 @@ from __future__ import annotations
 from tilelang import tvm as tvm
 import tilelang.language as T
 from tvm import DataType
-from tvm.tir import PrimExpr
+from tvm.tir import PrimExpr, IndexMap, Buffer, Var
 from tvm.runtime import convert
 from .utils import (
     mfma_store_index_map,)
+from typing import Literal, Callable
+from tilelang.utils import is_fragment
+from .mfma_layout import (
+    shared_16x4_to_local_64x1_layout_A,
+    shared_4x16_to_local_64x1_layout_B,
+    shared_16x16_to_local_64x4_layout_A,
+    shared_16x16_to_local_64x4_layout_B,
+    shared_16x32_to_local_64x8_layout_A,
+    shared_16x32_to_local_64x8_layout_B,
+    shared_16x64_to_local_64x16_layout_A,
+    shared_16x64_to_local_64x16_layout_B,
+    thread_id_shared_access_64x1_to_16x4_layout_A,
+    thread_id_shared_access_64x1_to_4x16_layout_B,
+    thread_id_shared_access_64x4_to_16x16_layout_A,
+    thread_id_shared_access_64x4_to_16x16_layout_B,
+    thread_id_shared_access_64x8_to_16x32_layout_A,
+    thread_id_shared_access_64x8_to_16x32_layout_B,
+    thread_id_shared_access_64x16_to_16x64_layout_A,
+    thread_id_shared_access_64x16_to_16x64_layout_B,
+)
 
 lift = convert
@@ -53,6 +75,7 @@ class MatrixCoreIntrinEmitter:
             k_pack: int | None = None,
             is_m_first: bool | None = False,
             b_preshuffle: bool | None = False,
+            thread_var: Var | None = None,
     ):
         self.a_dtype = a_dtype
         self.b_dtype = b_dtype
@@ -79,6 +102,7 @@ class MatrixCoreIntrinEmitter:
         self.reduce_k = reduce_k
         self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
         self.num_elems_per_byte = num_elems_per_byte
+        self.thread_var = thread_var
 
     def _initialize_k_dim(self, a_dtype="float16"):
         if isinstance(a_dtype, str):
@@ -147,24 +171,6 @@ class MatrixCoreIntrinEmitter:
         self.b_preshuffle = b_preshuffle
 
     def get_ldmatrix_index_map(self, is_b=False):
-        from .mfma_layout import (
-            shared_16x4_to_local_64x1_layout_A,
-            shared_4x16_to_local_64x1_layout_B,
-            shared_16x16_to_local_64x4_layout_A,
-            shared_16x16_to_local_64x4_layout_B,
-            shared_16x32_to_local_64x8_layout_A,
-            shared_16x32_to_local_64x8_layout_B,
-            shared_16x64_to_local_64x16_layout_A,
-            shared_16x64_to_local_64x16_layout_B,
-            thread_id_shared_access_64x1_to_16x4_layout_A,
-            thread_id_shared_access_64x1_to_4x16_layout_B,
-            thread_id_shared_access_64x4_to_16x16_layout_A,
-            thread_id_shared_access_64x4_to_16x16_layout_B,
-            thread_id_shared_access_64x8_to_16x32_layout_A,
-            thread_id_shared_access_64x8_to_16x32_layout_B,
-            thread_id_shared_access_64x16_to_16x64_layout_A,
-            thread_id_shared_access_64x16_to_16x64_layout_B,
-        )
         k_dim = self.k_dim * self.k_pack
         transposed = self.a_transposed if not is_b else self.b_transposed
@@ -200,6 +206,22 @@ class MatrixCoreIntrinEmitter:
 
         return index_map, reverse_index_map
 
+    def get_store_index_map(self, inverse: bool = False) -> IndexMap:
+        warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
+        index_map = IndexMap.from_func(mfma_store_index_map, index_dtype="int32")
+        if not inverse:
+            return index_map
+        inverse_index_map = index_map.inverse([warp_size, local_size_c])
+        return inverse_index_map
+
+    def get_thread_binding(self):
+        if self.thread_var is None:
+            current_frame = T.KernelLaunchFrame.Current()
+            assert current_frame is not None, "Must be called in a T.Kernel Frame"
+            return current_frame.get_thread_binding()
+        else:
+            return self.thread_var
+
     def extract_thread_binding(self,
                                thread_id,
                                is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
@@ -238,8 +260,7 @@ class MatrixCoreIntrinEmitter:
         local_size_a = self.local_size_a
         k_pack = self.k_pack
         is_transposed = self.a_transposed
-        current_frame = T.KernelLaunchFrame.Current()
-        thread_binding = current_frame.get_thread_binding()
+        thread_binding = self.get_thread_binding()
         _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
 
         @T.macro
@@ -279,8 +300,7 @@ class MatrixCoreIntrinEmitter:
         local_size_b = self.local_size_b
         k_pack = self.k_pack
         is_transposed = self.b_transposed
-        current_frame = T.KernelLaunchFrame.Current()
-        thread_binding = current_frame.get_thread_binding()
+        thread_binding = self.get_thread_binding()
         _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
 
         @T.macro
@@ -316,7 +336,11 @@ class MatrixCoreIntrinEmitter:
 
         return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
 
-    def mfma(self, A_local_buf, B_local_buf, C_local_buf):
+    def mfma(self,
+             A_local_buf: Buffer,
+             B_local_buf: Buffer,
+             C_local_buf: Buffer,
+             k_inner: PrimExpr | None = 0):
         warp_rows = self.warp_rows
         warp_cols = self.warp_cols
         local_size_a = self.local_size_a
@@ -329,8 +353,15 @@ class MatrixCoreIntrinEmitter:
         compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}"
         compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}"
 
+        a_is_fragment = is_fragment(A_local_buf)
+        b_is_fragment = is_fragment(B_local_buf)
+        a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
+        b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
+        print(a_local_stride, b_local_stride)
+
         @T.macro
-        def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
+        def _warp_mfma(A_local_buf, B_local_buf, C_local_buf):
             for kp, i, j in T.grid(k_pack, warp_rows, warp_cols):
                 T.tvm_mfma(
                     mfma_suffix,
@@ -340,15 +371,15 @@ class MatrixCoreIntrinEmitter:
                     compute_b_dtype,
                     compute_out_dtype,
                     B_local_buf.data,
-                    ((j * k_pack + kp) * local_size_b) // local_size_b,
+                    (b_local_stride + (j * k_pack + kp) * local_size_b) // local_size_b,
                     A_local_buf.data,
-                    ((i * k_pack + kp) * local_size_a) // local_size_a,
+                    (a_local_stride + (i * k_pack + kp) * local_size_a) // local_size_a,
                     C_local_buf.data,
                     (i * warp_cols * local_size_out + j * local_size_out) // local_size_out,
                     dtype=compute_out_dtype,
                 )
 
-        return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
+        return _warp_mfma(A_local_buf, B_local_buf, C_local_buf)
 
     def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None):
         block_row_warps = self.block_row_warps
@@ -356,8 +387,7 @@ class MatrixCoreIntrinEmitter:
         warp_rows = self.warp_rows
         warp_cols = self.warp_cols
         local_size_out = self.local_size_out
-        current_frame = T.KernelLaunchFrame.Current()
-        thread_binding = current_frame.get_thread_binding()
+        thread_binding = self.get_thread_binding()
         is_global = pid_m is not None and pid_n is not None
         BLOCK_M = block_row_warps * warp_rows
         BLOCK_N = block_col_warps * warp_cols
@@ -366,7 +396,7 @@ class MatrixCoreIntrinEmitter:
         assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D"
 
         # STS
-        # MMA Store must be in simulated instead of TVM Intrins
+        # MFMA Store must be in simulated instead of TVM Intrins
         # As TVM Intrins is like a hack that the threadIdx.x should be always
         # equal to the warp_size
         @T.macro
@@ -400,6 +430,217 @@ class MatrixCoreIntrinEmitter:
                     thread_binding) if is_global else _warp_stmatrix_shared(
                         C_local_buf, C_buf, thread_binding)
 
+    def make_mfma_load_layout(self,
+                              local_buf: Buffer,
+                              matrix: Literal["A", "B"] = "A") -> T.Fragment:
+        """
+        Create a layout function for loading an MFMA operand into a fragment buffer.
+
+        Parameters
+        ----------
+        local_buf : tir.Buffer
+            The local buffer representing a fragment of a matrix.
+
+        Returns
+        -------
+        T.Fragment
+            A fragment object that describes how threads and indices
+            in `local_buf` are laid out.
+
+        Raises
+        ------
+        AssertionError
+            If `local_buf` is not detected to be a fragment buffer.
+        """
+        from tilelang.utils import is_fragment
+
+        assert matrix in ["A", "B"], "matrix should be either A or B"
+        matrix_is_a: bool = matrix == "A"
+        matrix_is_b: bool = matrix == "B"
+        transposed = self.a_transposed if matrix_is_a else self.b_transposed
+
+        # s represents a spatial axis, r a reduction axis:
+        # sr means the two dims are spatial + reduction,
+        # rs means the two dims are reduction + spatial.
+        # sr also corresponds to a non-transposed basic layout,
+        # rs to a transposed basic layout.
+        transform_func_sr_a: Callable = None
+        transform_func_sr_b: Callable = None
+        k_dim = self.k_dim * self.k_pack
+        if k_dim == 4:
+            transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
+            transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
+        elif k_dim == 16:
+            transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
+            transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
+        elif k_dim == 32:
+            transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
+            transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
+        elif k_dim == 64:
+            transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
+            transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
+        else:
+            raise ValueError("k_dim must be 4, 16, 32 or 64 currently")
+
+        is_sr_conditions = [False]
+        is_sr_conditions.append(matrix_is_a and not transposed)
+        is_sr_conditions.append(matrix_is_b and transposed)
+        is_sr_axis_order = any(is_sr_conditions)
+
+        transform_func: Callable = None
+        if matrix_is_a:
+            transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(j, i)
+        elif matrix_is_b:
+            transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(j, i)
+        else:
+            raise ValueError(f"Unsupported matrix {matrix}")
+
+        assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
+
+        if matrix_is_a:
+            micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
+        else:
+            micro_size_r, micro_size_s = self.micro_size_k, self.micro_size_y
+
+        block_row_warps, block_col_warps = (
+            self.block_row_warps,
+            self.block_col_warps,
+        )
+
+        inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
+
+        def forward_thread(i: int, j: int) -> int:
+            """
+            Given the row index `i` and column index `j` in the fragment,
+            map them to the lane (thread) that holds the element.
+            """
+            lane_id, _ = inverse_mfma_load_layout.map_indices([i, j])
+            return lane_id
+
+        def forward_index(i: int, j: int) -> int:
+            """
+            Given the row index `i` and column index `j` in the fragment,
+            map them to the local index within a single thread.
+            """
+            _, local_id = inverse_mfma_load_layout.map_indices([i, j])
+            return local_id
+
+        base_fragment = T.Fragment(
+            [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
+            forward_thread_fn=forward_thread,
+            forward_index_fn=forward_index,
+        )
+
+        warp_rows, warp_cols = self.warp_rows, self.warp_cols
+        chunk = self.chunk
+        warp_s = warp_rows if matrix_is_a else warp_cols
+        warp_r = chunk // micro_size_r
+        block_s = block_row_warps if matrix_is_a else block_col_warps
+        replicate = block_col_warps if matrix_is_a else block_row_warps
+
+        if is_sr_axis_order:
+            warp_fragment = base_fragment.repeat([warp_s, warp_r],
+                                                 repeat_on_thread=False,
+                                                 lower_dim_first=False)
+            if matrix_is_a:
+                block_fragment = warp_fragment.repeat([block_s, 1],
+                                                      repeat_on_thread=True,
+                                                      lower_dim_first=True).replicate(replicate)
+            elif matrix_is_b:
+                block_fragment = warp_fragment.replicate(replicate).repeat(
+                    [block_s, 1], repeat_on_thread=True, lower_dim_first=True)
+            else:
+                raise ValueError(f"Unsupported matrix type {matrix}")
+        else:
+            warp_fragment = base_fragment.repeat([warp_r, warp_s],
+                                                 repeat_on_thread=False,
+                                                 lower_dim_first=True)
+            if matrix_is_a:
+                block_fragment = warp_fragment.repeat([1, block_s],
+                                                      repeat_on_thread=True,
+                                                      lower_dim_first=True).replicate(replicate)
+            elif matrix_is_b:
+                block_fragment = warp_fragment.replicate(replicate).repeat(
+                    [1, block_s], repeat_on_thread=True, lower_dim_first=True)
+            else:
+                raise ValueError(f"Unsupported matrix type {matrix}")
+
+        return block_fragment
+
+    def make_mfma_store_layout(self, local_buf: Buffer) -> T.Fragment:
+        """
+        Create a layout function for storing MFMA results into a fragment buffer.
+
+        Parameters
+        ----------
+        local_buf : tir.Buffer
+            The local buffer representing a fragment of a matrix.
+
+        Returns
+        -------
+        T.Fragment
+            A fragment object that describes how threads and indices
+            in `local_buf` are laid out.
+
+        Raises
+        ------
+        AssertionError
+            If `local_buf` is not detected to be a fragment buffer.
+        """
+        from tilelang.utils import is_fragment
+
+        shape = local_buf.shape
+        inverse_mfma_store_layout = self.get_store_index_map(inverse=True)
+        assert is_fragment(local_buf), "local_buf must be a fragment"
+        micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_y
+        local_size_out = self.local_size_out
+        block_row_warps, block_col_warps = self.block_row_warps, self.block_col_warps
+        warp_rows, warp_cols = self.warp_rows, self.warp_cols
+        warp_size = self.WARP_SIZE
+        is_m_first = self.is_m_first
+
+        def forward_thread(i: int, j: int) -> int:
+            """
+            Given the row index `i` and column index `j` in the fragment,
+            map them to a thread index according to `inverse_mfma_store_layout`.
+            """
+            # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x
+            # and block_col_warps * warp_cols * micro_size_y
+            # the upper bounds of block_i and block_j are block_row_warps and block_col_warps
+            block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
+            # upper bounds of mfma_i and mfma_j are micro_size_x and micro_size_y
+            mfma_i, mfma_j = i % micro_size_x, j % micro_size_y
+            lane_id, _ = inverse_mfma_store_layout.map_indices([mfma_i, mfma_j])
+            if is_m_first:
+                thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_size + lane_id
+            else:
+                thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
+            return thread_id
+
+        def forward_index(i: int, j: int) -> int:
+            """
+            Given the row index `i` and column index `j` in the fragment,
+            map them to a local index in a single thread according
+            to `inverse_mfma_store_layout`.
+            """
+            # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x
+            # and block_col_warps * warp_cols * micro_size_y
+            # the upper bounds of warp_i and warp_j are warp_rows and warp_cols
+            warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
+            # upper bounds of mfma_i and mfma_j are micro_size_x and micro_size_y
+            mfma_i, mfma_j = i % micro_size_x, j % micro_size_y
+            _, local_id = inverse_mfma_store_layout.map_indices([mfma_i, mfma_j])
+            return warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id
+
+        return T.Fragment(
+            shape,
+            forward_thread_fn=forward_thread,
+            forward_index_fn=forward_index,
+        )
 
 
 class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
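A note on the new k_inner argument to mfma(): when an operand is a fragment that holds the whole block_K chunk (the rs/sr/rr paths in gemm_mfma.py further below), the emitter must offset into the fragment per inner k-step, while freshly loaded per-step locals keep a zero stride. A small sketch of the stride arithmetic under illustrative sizes; the numbers are examples, not fixed by the diff:

# Illustrative stride arithmetic from the mfma() change above:
#   a_local_stride = k_inner * warp_rows * local_size_a   (fragment operand)
#   a_local_stride = 0                                    (non-fragment operand)
warp_rows, local_size_a = 2, 4  # example values
for k_inner in range(3):
    a_local_stride = k_inner * warp_rows * local_size_a
    print(k_inner, a_local_stride)  # 0 -> 0, 1 -> 8, 2 -> 16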
tilelang/tileop/gemm/__init__.py (modified)
@@ -8,6 +8,7 @@ import tvm.ffi
 from tilelang.ir import GemmWarpPolicy
 from .gemm_mma import GemmMMA
 from .gemm_wgmma import GemmWGMMA
+from .gemm_mfma import GemmMFMA
 from tilelang import _ffi_api
@@ -28,14 +29,18 @@ def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var):
 # same definition with src/op/gemm_py.h
 class GemmInst(IntEnum):
     MMA = 0
-    WGMMMA = 1
-    MFMA = 2
+    WGMMA = 1
+    TCGEN5MMA = 2
+    MFMA = 3
 
     def is_mma(self) -> bool:
         return self == GemmInst.MMA
 
     def is_wgmma(self) -> bool:
-        return self == GemmInst.WGMMMA
+        return self == GemmInst.WGMMA
+
+    def is_tcgen5mma(self) -> bool:
+        return self == GemmInst.TCGEN5MMA
 
     def is_mfma(self) -> bool:
         return self == GemmInst.MFMA
@@ -115,6 +120,8 @@ class GemmPy(Node, Scriptable):
         elif gemm_inst.is_wgmma():
             return GemmWGMMA
         elif gemm_inst.is_mfma():
-            raise NotImplementedError("MFMA is not implemented")
+            return GemmMFMA
+        elif gemm_inst.is_tcgen5mma():
+            raise NotImplementedError("TCGEN5MMA is not implemented")
         else:
             raise ValueError(f"Unsupported GEMM instruction: {gemm_inst}")
tilelang/tileop/gemm/gemm_mfma.py (new file, mode 100644)
from .gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mfma_macro_generator import (
    MatrixCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_fragment
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify


class GemmMFMA(GemmBase):

    def infer_layout(self, target: Target, thread_nums: int):
        m_warp, n_warp = self.policy.compute_warp_partition(
            self.M, self.N, thread_nums, target, False)
        warp_row_tiles = int(self.M // m_warp)
        warp_col_tiles = int(self.N // n_warp)
        mfma_emitter = MatrixCoreIntrinEmitter(
            a_dtype=self.in_dtype,
            b_dtype=self.in_dtype,
            accum_dtype=self.accum_dtype,
            a_transposed=self.trans_A,
            b_transposed=self.trans_B,
            block_row_warps=m_warp,
            block_col_warps=n_warp,
            warp_row_tiles=warp_row_tiles,
            warp_col_tiles=warp_col_tiles,
            chunk=self.chunk,
        )

        if self.is_gemm_ss():
            return {
                self.A: make_swizzled_layout(self.A),
                self.B: make_swizzled_layout(self.B),
                self.C: mfma_emitter.make_mfma_store_layout(self.C),
            }
        elif self.is_gemm_sr():
            return {
                self.A: make_swizzled_layout(self.A),
                self.B: mfma_emitter.make_mfma_load_layout(self.B, matrix="B"),
                self.C: mfma_emitter.make_mfma_store_layout(self.C),
            }
        elif self.is_gemm_rs():
            return {
                self.A: mfma_emitter.make_mfma_load_layout(self.A, matrix="A"),
                self.B: make_swizzled_layout(self.B),
                self.C: mfma_emitter.make_mfma_store_layout(self.C),
            }
        elif self.is_gemm_rr():
            return {
                self.A: mfma_emitter.make_mfma_load_layout(self.A, matrix="A"),
                self.B: mfma_emitter.make_mfma_load_layout(self.B, matrix="B"),
                self.C: mfma_emitter.make_mfma_store_layout(self.C),
            }
        else:
            raise ValueError(
                f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")

    def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
        m_warp, n_warp = self.policy.compute_warp_partition(
            self.M, self.N, thread_nums, target, False)
        warp_row_tiles = int(self.M // m_warp)
        warp_col_tiles = int(self.N // n_warp)
        mfma_emitter = MatrixCoreIntrinEmitter(
            a_dtype=self.in_dtype,
            b_dtype=self.in_dtype,
            accum_dtype=self.accum_dtype,
            a_transposed=self.trans_A,
            b_transposed=self.trans_B,
            block_row_warps=m_warp,
            block_col_warps=n_warp,
            warp_row_tiles=warp_row_tiles,
            warp_col_tiles=warp_col_tiles,
            chunk=self.chunk,
            thread_var=thread_var,
        )

        in_dtype = self.in_dtype
        warp_rows = mfma_emitter.warp_rows
        warp_cols = mfma_emitter.warp_cols
        local_size_a = mfma_emitter.local_size_a
        local_size_b = mfma_emitter.local_size_b
        block_K = mfma_emitter.chunk
        micro_size_k = mfma_emitter.micro_size_k
        A_shared = self.A
        B_shared = self.B
        C_local = self.C

        assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"

        if self.is_gemm_ss():

            @T.prim_func
            def _gemm_ssr() -> None:
                """
                The inner macro that loads data from shared buffers A_shared and
                B_shared into local fragments, then issues Matrix Core mfma ops,
                accumulating into C_local.
                """
                A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
                B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Load A into fragment
                    mfma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Load B into fragment
                    mfma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mfma_emitter.mfma(A_local, B_local, C_local, ki)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            return _Simplify(_gemm_ssr, inline_let=True)
        elif self.is_gemm_sr():
            B_local = self.B

            @T.prim_func
            def _gemm_srr() -> None:
                """
                The inner macro that loads data from the shared buffer A_shared
                into a local fragment, then issues Matrix Core mfma ops against
                the already-fragment B operand, accumulating into C_local.
                """
                A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Load A into fragment
                    mfma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mfma_emitter.mfma(A_local, B_local, C_local, ki)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            return _Simplify(_gemm_srr, inline_let=True)
        elif self.is_gemm_rs():
            A_local = self.A

            @T.prim_func
            def _gemm_rsr() -> None:
                """
                The inner macro that loads data from the shared buffer B_shared
                into a local fragment, then issues Matrix Core mfma ops against
                the already-fragment A operand, accumulating into C_local.
                """
                B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Load B into fragment
                    mfma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mfma_emitter.mfma(A_local, B_local, C_local, ki)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            return _Simplify(_gemm_rsr, inline_let=True)
        elif self.is_gemm_rr():
            A_local = self.A
            B_local = self.B

            @T.prim_func
            def _gemm_rrr() -> None:
                """
                Both operands already live in fragments, so the inner macro
                only issues Matrix Core mfma ops, accumulating into C_local.
                """
                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Perform Matrix Multiplication
                    mfma_emitter.mfma(A_local, B_local, C_local, ki)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            return _Simplify(_gemm_rrr, inline_let=True)
        else:
            raise ValueError(
                f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")

    def is_gemm_ss(self) -> bool:
        return is_shared(self.A) and is_shared(self.B)

    def is_gemm_sr(self) -> bool:
        return is_shared(self.A) and is_fragment(self.B)

    def is_gemm_rs(self) -> bool:
        return is_fragment(self.A) and is_shared(self.B)

    def is_gemm_rr(self) -> bool:
        return is_fragment(self.A) and is_fragment(self.B)