OpenDAS / tilelang

Commit bb62f6bf
Authored Dec 22, 2025 by qisan
Parent: 667632cc

    [Bugfix] Pass pre commit check

Showing 4 changed files with 114 additions and 138 deletions:
- examples/gemm/example_gemm_intrinsics_dcu.py (+11, -11)
- src/target/codegen_hip.cc (+1, -1)
- testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py (+19, -32)
- tilelang/intrinsics/mmac_macro_generator.py (+83, -94)
examples/gemm/example_gemm_intrinsics_dcu.py (+11, -11)

```diff
@@ -4,7 +4,8 @@ import tilelang
 import tilelang.language as T
 from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mmac_macro_generator import (
-    MatrixCoreIntrinEmitter,)
+    MatrixCoreIntrinEmitter,
+)
 from tilelang.transform import simplify_prim_func
 from tilelang import disable_cache
@@ -107,7 +108,6 @@ def tl_matmul(
             C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
             C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
@@ -115,10 +115,12 @@ def tl_matmul(
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
 
-            T.annotate_layout({
+            T.annotate_layout(
+                {
                     A_shared: make_swizzle_layout(A_shared),
                     B_shared: make_swizzle_layout(B_shared),
-            })
+                }
+            )
 
             # Improve L2 Cache
             T.use_swizzle(panel_size=10)
```
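The `make_swizzle_layout` helper passed to `T.annotate_layout` is defined earlier in this example file and does not appear in the diff context. In tilelang's public GEMM examples it is typically a thin wrapper over the imported `get_swizzle_layout`; a sketch of that usual shape, assuming the `T.Layout` API and `DataType` from `tvm`, not the exact code at this revision:

```python
from tvm import DataType

def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
    shape = shared_buf.shape
    # Swizzling is only applied when one row fills a full 512-bit shared-memory line.
    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)  # identity layout

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)
```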
```diff
@@ -126,7 +128,6 @@ def tl_matmul(
             T.clear(C_local)
 
             for ko in T.Pipelined((K // block_K), num_stages=stage):
                 # Load A into shared memory
                 for i, k in T.Parallel(block_M, block_K):
                     A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
@@ -136,7 +137,6 @@ def tl_matmul(
                     B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
 
                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load A into fragment
                     mmac_emitter.ldmatrix_a(A_local, A_shared, ki)
```
src/target/codegen_hip.cc (+1, -1)

```diff
@@ -978,7 +978,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
     // arg 11: C accumulator index
-    ICHECK(op->args.size() == 12U) << "Invalid number of arguments for tvm_mfma";
+    ICHECK(op->args.size() == 12U) << "Invalid number of arguments for tvm_mmac";
     std::string prefix = Downcast<StringImm>(op->args[0])->value;
     std::string A_layout = Downcast<StringImm>(op->args[1])->value;
     std::string B_layout = Downcast<StringImm>(op->args[2])->value;
```
testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py (+19, -32)

```diff
@@ -3,10 +3,12 @@ import tilelang.testing
 from tilelang import tvm as tvm
 from tvm import DataType
 import tilelang.language as T
 
 # from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
 from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mmac_macro_generator import (
-    MatrixCoreIntrinEmitter,)
+    MatrixCoreIntrinEmitter,
+)
 from tilelang.transform import simplify_prim_func
 
 tilelang.testing.set_random_seed(0)
@@ -111,7 +113,6 @@ def tl_matmul(
             C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
             C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
@@ -119,10 +120,12 @@ def tl_matmul(
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
 
-            T.annotate_layout({
+            T.annotate_layout(
+                {
                     A_shared: make_swizzle_layout(A_shared),
                     B_shared: make_swizzle_layout(B_shared),
-            })
+                }
+            )
 
             # Improve L2 Cache
             T.use_swizzle(panel_size=10)
@@ -130,7 +133,6 @@ def tl_matmul(
             T.clear(C_local)
 
             for ko in T.Pipelined((K // block_K), num_stages=0):
                 # Load A into shared memory
                 if a_transposed:
                     T.copy(A[ko * block_K, by * block_M], A_shared)
@@ -144,7 +146,6 @@ def tl_matmul(
                     T.copy(B[ko * block_K, bx * block_N], B_shared)
 
                 for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):
                     # Load A into fragment
                     mmac_emitter.ldmatrix_a(
                         A_local,
@@ -180,17 +181,8 @@ def tl_matmul(
     return main
 
-def assert_tl_matmul_correctness(M,
-                                 N,
-                                 K,
-                                 in_dtype,
-                                 out_dtype,
-                                 accum_dtype="float32",
-                                 a_transposed=False,
-                                 b_transposed=True,
-                                 k_pack=1):
-    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
-                       k_pack)
+def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32",
+                                 a_transposed=False, b_transposed=True, k_pack=1):
+    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack)
     print(matmul)
     kernel = tilelang.compile(matmul)
     src_code = kernel.get_kernel_source()
@@ -218,16 +210,13 @@ def assert_tl_matmul_correctness(M,
     if a_transposed and b_transposed:
         # Get Reference Result
-        ref_c = torch.matmul(A.T.to(torch.float32),
-                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
     elif a_transposed and not b_transposed:
         # Get Reference Result
-        ref_c = torch.matmul(A.T.to(torch.float32),
-                             B.to(torch.float32)).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(A.T.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
     elif not a_transposed and b_transposed:
         # Get Reference Result
-        ref_c = torch.matmul(A.to(torch.float32),
-                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
     else:
         # Get Reference Result
         ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
```
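For orientation, the reference path above boils down to a plain float32 torch matmul followed by a cast. A standalone sketch of the `b_transposed=True` branch, with illustrative shapes and dtypes (the test sweeps several (M, N, K) configurations):

```python
import torch

M, N, K = 128, 256, 256
A = torch.randint(-4, 4, (M, K), dtype=torch.int8)
B = torch.randint(-4, 4, (N, K), dtype=torch.int8)  # b_transposed=True stores B as (N, K)

# Accumulate in float32, then cast to the output dtype, as the test does.
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(torch.int32)
```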
```diff
@@ -245,10 +234,8 @@ def test_assert_tl_matmul():
     assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
     assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
     assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
-    assert_tl_matmul_correctness(
-        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
-    assert_tl_matmul_correctness(
-        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
+    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
+    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
     # assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16")
     # assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32")
     # assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2)
```
tilelang/intrinsics/mmac_macro_generator.py (+83, -94)

```diff
@@ -5,7 +5,8 @@ from tvm import DataType
 from tvm.tir import PrimExpr
 from tvm.runtime import convert
 from .utils import (
-    mfma_store_index_map,)
+    mfma_store_index_map,
+)
 
 lift = convert
@@ -77,7 +78,7 @@ class MatrixCoreIntrinEmitter:
         self.warp_rows = warp_row_tiles // self.micro_size_x
         self.warp_cols = warp_col_tiles // self.micro_size_y
         self.reduce_k = reduce_k
-        self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
+        self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
         self.num_elems_per_byte = num_elems_per_byte
 
     def _initialize_k_dim(self, a_dtype="float16"):
```
```diff
@@ -107,19 +108,9 @@ class MatrixCoreIntrinEmitter:
     def _initialize_mmac_prefix(self, k_dim=16):
         in_dtype, out_dtype = self.a_dtype, self.accum_dtype
         M_DIM, N_DIM = self.M_DIM, self.N_DIM
 
-        out_dtype_abbrv = {
-            "float16": "f16",
-            "float32": "f32",
-            "int8": "i8",
-            "int32": "i32"
-        }[out_dtype]
-
-        in_dtype_abbrv = {
-            "float16": "f16",
-            "float32": "f32",
-            "int8": "i8",
-            "bfloat16": "bf16"
-        }[in_dtype]
+        out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[out_dtype]
+        in_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "bfloat16": "bf16"}[in_dtype]
 
         self.mmac_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
```
```diff
@@ -167,41 +158,53 @@ class MatrixCoreIntrinEmitter:
             reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A
             if is_b:
                 index_map = shared_16x4_to_local_64x1_layout_A if transposed else shared_4x16_to_local_64x1_layout_B
-                reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B
+                reverse_index_map = (
+                    thread_id_shared_access_64x1_to_16x4_layout_A
+                    if transposed else thread_id_shared_access_64x1_to_4x16_layout_B
+                )
         elif k_dim == 16:
             index_map = shared_16x16_to_local_64x4_layout_B if transposed else shared_16x16_to_local_64x4_layout_A
-            reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A
+            reverse_index_map = (
+                thread_id_shared_access_64x4_to_16x16_layout_B
+                if transposed else thread_id_shared_access_64x4_to_16x16_layout_A
+            )
             if is_b:
                 index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B
-                reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B
+                reverse_index_map = (
+                    thread_id_shared_access_64x4_to_16x16_layout_A
+                    if transposed else thread_id_shared_access_64x4_to_16x16_layout_B
+                )
         elif k_dim == 32:
             index_map = shared_16x32_to_local_64x8_layout_B if transposed else shared_16x32_to_local_64x8_layout_A
-            reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A
+            reverse_index_map = (
+                thread_id_shared_access_64x8_to_16x32_layout_B
+                if transposed else thread_id_shared_access_64x8_to_16x32_layout_A
+            )
             if is_b:
                 index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B
-                reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
+                reverse_index_map = (
+                    thread_id_shared_access_64x8_to_16x32_layout_A
+                    if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
+                )
         elif k_dim == 64:
             index_map = shared_16x64_to_local_64x16_layout_B if transposed else shared_16x64_to_local_64x16_layout_A
-            reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A
+            reverse_index_map = (
+                thread_id_shared_access_64x16_to_16x64_layout_B
+                if transposed else thread_id_shared_access_64x16_to_16x64_layout_A
+            )
             if is_b:
                 index_map = shared_16x64_to_local_64x16_layout_A if transposed else shared_16x64_to_local_64x16_layout_B
-                reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B
+                reverse_index_map = (
+                    thread_id_shared_access_64x16_to_16x64_layout_A
+                    if transposed else thread_id_shared_access_64x16_to_16x64_layout_B
+                )
         else:
             raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
 
         return index_map, reverse_index_map
 
-    def extract_thread_binding(self,
-                               thread_id,
-                               is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
-        '''
+    def extract_thread_binding(
+        self, thread_id, is_m_first=None
+    ) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
+        """
         is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
             which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
             Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
-        '''
+        """
         WARP_SIZE = self.WARP_SIZE
         block_row_warps = self.block_row_warps
         block_col_warps = self.block_col_warps
```
```diff
@@ -211,16 +214,18 @@ class MatrixCoreIntrinEmitter:
             is_m_first = self.is_m_first
 
         if is_m_first:
-            lane_id, warp_n, warp_m = thread_id % WARP_SIZE, (thread_id // WARP_SIZE) % block_col_warps, (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps
+            lane_id, warp_n, warp_m = (
+                thread_id % WARP_SIZE,
+                (thread_id // WARP_SIZE) % block_col_warps,
+                (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
+            )
             return lane_id, warp_n, warp_m
         else:
-            lane_id, warp_m, warp_n = thread_id % WARP_SIZE, (thread_id // WARP_SIZE) % block_row_warps, (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps
+            lane_id, warp_m, warp_n = (
+                thread_id % WARP_SIZE,
+                (thread_id // WARP_SIZE) % block_row_warps,
+                (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
+            )
             return lane_id, warp_n, warp_m
 
     def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0):
```
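The decomposition in `extract_thread_binding` is plain integer arithmetic over a flat thread id. A standalone sketch of the `is_m_first=False` case, with illustrative warp counts (DCU wavefronts are 64 lanes wide, matching `self.WARP_SIZE`):

```python
WARP_SIZE, block_row_warps, block_col_warps = 64, 2, 2  # illustrative values

def extract(thread_id):
    # Mirrors the is_m_first=False branch: lane, then warp_m, then warp_n.
    lane_id = thread_id % WARP_SIZE
    warp_m = (thread_id // WARP_SIZE) % block_row_warps
    warp_n = (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps
    return lane_id, warp_n, warp_m

assert extract(70) == (6, 0, 1)   # lane 6 of the second wavefront: warp_m=1, warp_n=0
assert extract(130) == (2, 1, 0)  # lane 2 of the third wavefront: warp_m=0, warp_n=1
```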
```diff
@@ -249,18 +254,14 @@ class MatrixCoreIntrinEmitter:
                 for i in T.serial(warp_rows):
                     for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
-                        l, r = (rk * chunk + ki * (k_pack * micro_size_k),
-                                warp_m * warp_row_tiles + i * micro_size_x)
-                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
-                                                                                         r + col]
+                        l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row, r + col]
             else:
                 for i in T.serial(warp_rows):
                     for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
-                        l, r = (warp_m * warp_row_tiles + i * micro_size_x,
-                                rk * chunk + ki * (k_pack * micro_size_k))
-                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
-                                                                                         r + col]
+                        l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k))
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row, r + col]
 
         return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
```
```diff
@@ -289,28 +290,22 @@ class MatrixCoreIntrinEmitter:
             if is_transposed:
                 for j in T.serial(warp_cols):
                     for local_id in T.vectorized(k_pack * local_size_b):
-                        row, col = T.meta_var(
-                            reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16,
-                                              local_id))
+                        row, col = T.meta_var(
+                            reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id))
                         l, r = (
                             warp_n * warp_col_tiles + j * micro_size_y,
                             rk * chunk + ki * (k_pack * micro_size_k),
                         )
-                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
-                                                                                         r + col]
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, r + col]
             else:
                 for j in T.serial(warp_cols):
                     for local_id in T.vectorized(k_pack * local_size_b):
-                        row, col = T.meta_var(
-                            reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16,
-                                              local_id))
+                        row, col = T.meta_var(
+                            reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id))
                         l, r = (
                             rk * chunk + ki * (k_pack * micro_size_k),
                             warp_n * warp_col_tiles + j * micro_size_y,
                         )
-                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
-                                                                                         r + col]
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, r + col]
 
         return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
```
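The lane remapping `(tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16` used here permutes lanes within each 16-lane group (swapping its 4x4 quad coordinates). It is arithmetically identical to the shift form `((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4)` that the preshuffle emitter further below uses, which a quick check confirms:

```python
for tx in range(64):  # one 64-lane wavefront
    div_form = (tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16
    shift_form = ((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4)
    assert div_form == shift_form  # // 4 == >> 2, * 4 == << 2, etc. for non-negative ints
```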
```diff
@@ -374,14 +369,13 @@ class MatrixCoreIntrinEmitter:
                 for local_id in T.vectorized(local_size_out):
                     row, col = T.meta_var(mfma_store_index_map(tx, local_id))
                     if C_buf_dims == 2:
-                        C_buf[(warp_m * warp_rows + i) * M_DIM + row,
-                              (warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[
-                                  j * (warp_rows * local_size_out) + i * local_size_out + local_id]
+                        C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * N_DIM + col] = (
+                            C_local_buf[j * (warp_rows * local_size_out) + i * local_size_out + local_id]
+                        )
                     else:
-                        C_buf[warp_n * warp_cols + j, warp_m * warp_rows + i, row,
-                              col] = C_local_buf[j * warp_rows * local_size_out +
-                                                 i * local_size_out + local_id]
+                        C_buf[warp_n * warp_cols + j, warp_m * warp_rows + i, row, col] = C_local_buf[
+                            j * warp_rows * local_size_out + i * local_size_out + local_id
+                        ]
 
     @T.macro
     def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
@@ -389,18 +383,18 @@ class MatrixCoreIntrinEmitter:
             for i, j in T.grid(warp_rows, warp_cols):
                 for local_id in T.vectorized(local_size_out):
                     row, col = T.meta_var(mfma_store_index_map(tx, local_id))
-                    C_buf[(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
-                          (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM +
-                          col] = C_local_buf[i * warp_cols * local_size_out +
-                                             j * local_size_out + local_id]
-
-        return _warp_stmatrix_global(C_local_buf, C_buf,
-                                     thread_binding) if is_global else _warp_stmatrix_shared(
-                                         C_local_buf, C_buf, thread_binding)
+                    C_buf[
+                        (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
+                        (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col,
+                    ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]
+
+        return (
+            _warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
+            if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
+        )
 
 
 class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
 
     def __init__(self,
                  a_dtype: str = "float16",
```
```diff
@@ -420,7 +414,6 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
         a_preshuffle: bool | None = False,
         b_preshuffle: bool | None = False,
     ):
         self.a_dtype = a_dtype
         self.b_dtype = b_dtype
         self.accum_dtype = accum_dtype
@@ -444,7 +437,7 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
         self.warp_rows = warp_row_tiles // self.micro_size_x
         self.warp_cols = warp_col_tiles // self.micro_size_y
         self.reduce_k = reduce_k
-        self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
+        self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
         self.num_elems_per_byte = num_elems_per_byte
 
     def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool):
@@ -513,19 +506,19 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
                         l, r = (
                             rk * (chunk // micro_size_k) + ki,
                             warp_m * warp_rows + i,
                         )
-                        A_local_buf[i * k_pack * local_size_a +
-                                    local_id] = A_shared_buf[l, r, row, col]
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col]
             else:
                 for i in T.serial(warp_rows):
                     for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
                         l, r = (warp_m * warp_rows + i, rk * (chunk // micro_size_k) + ki)
-                        A_local_buf[i * k_pack * local_size_a +
-                                    local_id] = A_shared_buf[l, r, row, col]
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col]
 
-        return _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding,
-                                       rk) if is_global else _warp_ldmatrix_a_shared(
-                                           A_local_buf, A_buf, ki, thread_binding, rk)
+        return (
+            _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, rk)
+            if is_global else _warp_ldmatrix_a_shared(A_local_buf, A_buf, ki, thread_binding, rk)
+        )
 
     def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None):
         warp_cols = self.warp_cols
```
```diff
@@ -582,28 +575,24 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
             if is_transposed:
                 for j in T.serial(warp_cols):
                     for local_id in T.vectorized(k_pack * local_size_b):
-                        row, col = T.meta_var(
-                            reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4),
-                                              local_id))
+                        row, col = T.meta_var(
+                            reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4), local_id))
                         l, r = (
                             warp_n * warp_cols + j,
                             rk * (chunk // micro_size_k) + ki,
                         )
-                        B_local_buf[j * k_pack * local_size_b +
-                                    local_id] = B_shared_buf[l, r, row, col]
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
             else:
                 for j in T.serial(warp_cols):
                     for local_id in T.vectorized(k_pack * local_size_b):
-                        row, col = T.meta_var(
-                            reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4),
-                                              local_id))
+                        row, col = T.meta_var(
+                            reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4), local_id))
                         l, r = (
                             rk * (chunk // micro_size_k) + ki,
                             warp_n * warp_cols + j,
                         )
-                        B_local_buf[j * k_pack * local_size_b +
-                                    local_id] = B_shared_buf[l, r, row, col]
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
 
-        return _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding,
-                                       rk) if is_global else _warp_ldmatrix_b_shared(
-                                           B_local_buf, B_buf, ki, thread_binding, rk)
+        return (
+            _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, rk)
+            if is_global else _warp_ldmatrix_b_shared(B_local_buf, B_buf, ki, thread_binding, rk)
+        )
```