OpenDAS / tilelang — Commits

Commit 143b5222 (unverified)
Authored Sep 12, 2025 by Jiaxing Ding; committed via GitHub on Sep 12, 2025.

[AMD] support preshuffle weight mfma (#806)

Co-authored-by: Jiaxing Ding <jiaxing.ding@bytedance.com>

Parent: 409ab83d
Changes: 2 changed files, with 371 additions and 19 deletions (+371 −19)

  testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py   +321 −0
  tilelang/intrinsics/mfma_macro_generator.py                  +50 −19
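For context on the feature name: the new test stores the B operand in a blocked ("preshuffled") layout of 16 × (32 · k_pack) tiles, i.e. shape (N // 16, K // (32 · k_pack), 16, 32 · k_pack) for the transposed int8 case, presumably so that each warp's B fragment is read from a contiguous tile. Below is a minimal sketch of that blocking; the helper name here is hypothetical illustration only, and the test's own shuffle_weight (further down) is the authoritative version.

    import torch

    def preshuffle_b_transposed(b: torch.Tensor, k_pack: int = 1) -> torch.Tensor:
        # Hypothetical illustration: block a transposed (N, K) weight into
        # (N // 16, K // (32 * k_pack), 16, 32 * k_pack) tiles, the B_shape the
        # new test kernel expects when b_transposed=True and b_preshuffle=True.
        n, k = b.shape
        bn, bk = 16, 32 * k_pack
        return b.view(n // bn, bn, k // bk, bk).permute(0, 2, 1, 3).contiguous()

    w = torch.randint(-128, 127, (128, 128), dtype=torch.int8)
    print(preshuffle_b_transposed(w).shape)  # torch.Size([8, 4, 16, 32])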
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (new file, 0 → 100644)
import torch
import tilelang.testing
from tilelang import tvm as tvm
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mfma_macro_generator import (
    MatrixCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(0)
@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    a_transposed=False,
    b_transposed=True,
    k_pack=1,
    b_preshuffle=False,
):
    assert in_dtype in [
        "float16",
        "int8",
    ], "Currently only float16 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    if in_dtype in {"float8_e4m3fnuz", "int8"}:
        micro_size_k = 32

    block_row_warps = 2
    block_col_warps = 2
    warp_row_tiles = 32
    warp_col_tiles = 32

    # for preshuffle_b, warp_layout = {1, 4}
    if b_preshuffle:
        block_row_warps = 1
        block_col_warps = 4
        warp_row_tiles = 128
        warp_col_tiles = 32

    chunk = 32 * k_pack
    pack_size_k = micro_size_k * k_pack

    shared_scope = "shared"
    cache_write_shared = False

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    A_shape = (K, M) if a_transposed else (M, K)
    if b_preshuffle:
        B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y,
                   pack_size_k) if b_transposed else (K // pack_size_k, N // micro_size_y,
                                                      pack_size_k, micro_size_y)
    else:
        B_shape = (N, K) if b_transposed else (K, N)

    A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
    if b_preshuffle:
        B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
                          pack_size_k) if b_transposed else (block_K // pack_size_k,
                                                             block_N // micro_size_y,
                                                             pack_size_k, micro_size_y)
    else:
        B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)

    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )

    warp_size = 64
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size
    local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MMA Wrapper to Auto Generate Code for MMA
    mfma_emitter = MatrixCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=a_transposed,
        b_transposed=b_transposed,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
        k_pack=k_pack,
        b_preshuffle=b_preshuffle,
    )

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
            })

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=0):

                # Load A into shared memory
                if a_transposed:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, ko * block_K], A_shared)

                # Load B into shared memory
                if b_preshuffle:
                    if b_transposed:
                        for j, k, jj, kk in T.Parallel(block_N // micro_size_y,
                                                       block_K // pack_size_k, micro_size_y,
                                                       pack_size_k):
                            B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j,
                                                       ko * block_K // pack_size_k + k, jj, kk]
                    else:
                        for k, j, kk, jj in T.Parallel(block_K // pack_size_k,
                                                       block_N // micro_size_y, pack_size_k,
                                                       micro_size_y):
                            B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k,
                                                       bx * block_N // micro_size_y + j, kk, jj]
                else:
                    if b_transposed:
                        T.copy(B[bx * block_N, ko * block_K], B_shared)
                    else:
                        T.copy(B[ko * block_K, bx * block_N], B_shared)

                for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):

                    # Load A into fragment
                    mfma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Load B into fragment
                    mfma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mfma_emitter.mfma(A_local, B_local, C_local)

            # Perform STMatrix
            if cache_write_shared:
                mfma_emitter.stmatrix(
                    C_local,
                    C_shared,
                )

                # Store shared into global
                for i, j in T.Parallel(block_M, block_N):
                    C[by * block_M + i, bx * block_N + j] = C_shared[
                        i // micro_size_x,
                        j // micro_size_y,
                        i % micro_size_x,
                        j % micro_size_y,
                    ]
            else:
                mfma_emitter.stmatrix(
                    C_local,
                    C,
                    pid_m=by,
                    pid_n=bx,
                )

    return main
def shuffle_weight(
    x: torch.Tensor,
    layout=(16, 32),
    k_pack=1,
    is_transpose=False,
) -> torch.Tensor:
    IN, IK = layout
    BK = IK * k_pack
    BN = IN
    N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
    assert N % BN == 0
    assert K % BK == 0

    x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
    x = x.permute(0, 2, 1, 3)
    return x.contiguous()
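As a sanity check on the layout, shuffle_weight with is_transpose=True is just a 4-D blocking followed by a swap of the two inner block axes, so element (n, k) of the original weight lands at block index (n // 16, k // (32 · k_pack)) with in-tile offset (n % 16, k % (32 · k_pack)). A small self-contained check (CPU tensors; the indices chosen are arbitrary illustration):

    import torch

    B = torch.randint(-128, 127, (128, 256), dtype=torch.int8)   # (N, K), transposed layout
    Bp = shuffle_weight(B, k_pack=1, is_transpose=True)           # shape (8, 8, 16, 32)

    n, k = 37, 101
    # The preshuffled tensor holds the same element at the blocked coordinates.
    assert Bp[n // 16, k // 32, n % 16, k % 32] == B[n, k]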
def assert_tl_matmul_correctness(M,
                                 N,
                                 K,
                                 in_dtype,
                                 out_dtype,
                                 accum_dtype="float32",
                                 a_transposed=False,
                                 b_transposed=True,
                                 k_pack=1,
                                 b_preshuffle=False):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
                       k_pack, b_preshuffle)
    print(matmul)
    kernel = tilelang.compile(matmul)
    src_code = kernel.get_kernel_source()
    # src_code is the generated kernel source
    assert src_code is not None
    A_shape = (K, M) if a_transposed else (M, K)
    B_shape = (N, K) if b_transposed else (K, N)
    if in_dtype == "int8":
        A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
        B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
    else:
        A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
        B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    B_preshuffle = B
    if b_preshuffle:
        B_preshuffle = shuffle_weight(B_preshuffle, k_pack=k_pack, is_transpose=b_transposed)
        kernel(A, B_preshuffle, C)
    else:
        kernel(A, B, C)

    print(kernel.get_kernel_source())

    profiler = kernel.get_profiler()

    latency = profiler.do_bench()

    # Ensure that the latency is not None
    assert latency is not None

    if a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32),
                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32),
                             B.to(torch.float32)).to(getattr(torch, out_dtype))
    elif not a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32),
                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    else:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32),
                             B.to(torch.float32)).to(getattr(torch, out_dtype))
    print(C)
    print(ref_c)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
    assert_tl_matmul_correctness(
        128, 128, 128, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
    assert_tl_matmul_correctness(
        128,
        256,
        256,
        "int8",
        "int32",
        b_transposed=False,
        accum_dtype="int32",
        k_pack=2,
        b_preshuffle=True)


if __name__ == "__main__":
    tilelang.testing.main()
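For reference, a stripped-down driver exercising the preshuffle path outside the test harness might look as follows. This is a sketch only, assuming a ROCm-capable GPU visible to torch as "cuda"; it simply mirrors assert_tl_matmul_correctness above for the (128, 128, 128) int8 case.

    if torch.cuda.is_available():
        func = tl_matmul(128, 128, 128, "int8", "int32", "int32", b_preshuffle=True)
        kernel = tilelang.compile(func)

        A = torch.randint(-128, 127, (128, 128), device="cuda", dtype=torch.int8)
        B = torch.randint(-128, 127, (128, 128), device="cuda", dtype=torch.int8)
        C = torch.zeros(128, 128, device="cuda", dtype=torch.int32)

        # The kernel consumes the blocked weight; keep the flat B for the reference matmul.
        kernel(A, shuffle_weight(B, k_pack=1, is_transpose=True), C)
        ref = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(torch.int32)
        torch.testing.assert_close(C, ref, rtol=1e-2, atol=1e-2)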
tilelang/intrinsics/mfma_macro_generator.py (+50 −19)

@@ -53,6 +53,7 @@ class MatrixCoreIntrinEmitter(object):
        num_elems_per_byte: int = 1,
        k_pack: Optional[int] = None,
        is_m_first: Optional[bool] = False,
+       b_preshuffle: Optional[bool] = False,
    ):
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype

@@ -72,6 +73,7 @@ class MatrixCoreIntrinEmitter(object):
        self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
        self._initialize_k_pack(k_pack)
        self._initialize_is_m_first(is_m_first)
+       self._initialize_b_preshuffle(b_preshuffle)
        self.warp_rows = warp_row_tiles // self.micro_size_x
        self.warp_cols = warp_col_tiles // self.micro_size_y

@@ -141,6 +143,10 @@ class MatrixCoreIntrinEmitter(object):
        if is_m_first is not None:
            self.is_m_first = is_m_first

+   def _initialize_b_preshuffle(self, b_preshuffle: Optional[bool] = False):
+       if b_preshuffle is not None:
+           self.b_preshuffle = b_preshuffle
+
    def get_ldmatrix_index_map(self, is_b=False):
        from .mfma_layout import (
            shared_16x4_to_local_64x1_layout_A,

@@ -288,26 +294,51 @@ class MatrixCoreIntrinEmitter(object):
        ):
            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
-           if is_transposed:
-               for j in T.serial(warp_cols):
-                   for local_id in T.vectorized(k_pack * local_size_b):
-                       row, col = T.meta_var(reverse_index_map(tx, local_id))
-                       l, r = (
-                           warp_n * warp_col_tiles + j * micro_size_y,
-                           rk * chunk + ki * (k_pack * micro_size_k),
-                       )
-                       B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
-                                                                                        r + col]
-           else:
-               for j in T.serial(warp_cols):
-                   for local_id in T.vectorized(k_pack * local_size_b):
-                       row, col = T.meta_var(reverse_index_map(tx, local_id))
-                       l, r = (
-                           rk * chunk + ki * (k_pack * micro_size_k),
-                           warp_n * warp_col_tiles + j * micro_size_y,
-                       )
-                       B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
-                                                                                        r + col]
+           # 4 dim
+           if self.b_preshuffle:
+               if is_transposed:
+                   for j in T.serial(warp_cols):
+                       for local_id in T.vectorized(k_pack * local_size_b):
+                           row, col = T.meta_var(reverse_index_map(tx, local_id))
+                           l, r = (
+                               warp_n * warp_cols + j,
+                               rk * (chunk // micro_size_k) + ki,
+                           )
+                           B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
+               else:
+                   for j in T.serial(warp_cols):
+                       for local_id in T.vectorized(k_pack * local_size_b):
+                           row, col = T.meta_var(reverse_index_map(tx, local_id))
+                           l, r = (
+                               rk * (chunk // micro_size_k) + ki,
+                               warp_n * warp_cols + j,
+                           )
+                           B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
+           else:
+               if is_transposed:
+                   for j in T.serial(warp_cols):
+                       for local_id in T.vectorized(k_pack * local_size_b):
+                           row, col = T.meta_var(reverse_index_map(tx, local_id))
+                           l, r = (
+                               warp_n * warp_col_tiles + j * micro_size_y,
+                               rk * chunk + ki * (k_pack * micro_size_k),
+                           )
+                           B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
+                                                                                            r + col]
+               else:
+                   for j in T.serial(warp_cols):
+                       for local_id in T.vectorized(k_pack * local_size_b):
+                           row, col = T.meta_var(reverse_index_map(tx, local_id))
+                           l, r = (
+                               rk * chunk + ki * (k_pack * micro_size_k),
+                               warp_n * warp_col_tiles + j * micro_size_y,
+                           )
+                           B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
+                                                                                            r + col]

        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)