Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
48c9a352
Unverified
Commit
48c9a352
authored
Sep 23, 2025
by
Jiaxing Ding
Committed by
GitHub
Sep 23, 2025
Browse files
[AMD] refactor MatrixCoreIntrinEmitter (#860)
parent
b12a63cf
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
269 additions
and
116 deletions
+269
-116
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
+4
-0
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
+39
-72
tilelang/intrinsics/mfma_macro_generator.py
tilelang/intrinsics/mfma_macro_generator.py
+226
-44
No files found.
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
View file @
48c9a352
...
...
@@ -234,6 +234,10 @@ def test_assert_tl_matmul():
assert_tl_matmul_correctness
(
128
,
128
,
128
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
)
assert_tl_matmul_correctness
(
128
,
256
,
256
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
)
assert_tl_matmul_correctness
(
128
,
256
,
256
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
,
k_pack
=
2
)
assert_tl_matmul_correctness
(
128
,
256
,
256
,
"int8"
,
"int32"
,
b_transposed
=
False
,
accum_dtype
=
"int32"
)
assert_tl_matmul_correctness
(
128
,
256
,
256
,
"int8"
,
"int32"
,
b_transposed
=
False
,
accum_dtype
=
"int32"
,
k_pack
=
2
)
if
__name__
==
"__main__"
:
...
...
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
View file @
48c9a352
...
...
@@ -3,8 +3,7 @@ import tilelang.testing
from
tilelang
import
tvm
as
tvm
import
tilelang.language
as
T
from
tilelang.intrinsics
import
make_mfma_swizzle_layout
as
make_swizzle_layout
from
tilelang.intrinsics.mfma_macro_generator
import
(
MatrixCoreIntrinEmitter
,)
from
tilelang.intrinsics.mfma_macro_generator
import
MatrixCorePreshuffleIntrinEmitter
from
tilelang.transform
import
simplify_prim_func
tilelang
.
testing
.
set_random_seed
(
0
)
...
...
@@ -22,16 +21,8 @@ def tl_matmul(
b_transposed
=
True
,
k_pack
=
1
,
b_preshuffle
=
False
,
b_g2l_load
=
False
,
):
assert
in_dtype
in
[
"float16"
,
"int8"
,
],
"Currently only float16 and int8 are supported"
assert
out_dtype
in
[
"float16"
,
"float32"
,
"int32"
,
],
"Currently only float16, float32 and int32 are supported"
micro_size_x
=
micro_size_y
=
micro_size_k
=
16
...
...
@@ -47,15 +38,14 @@ def tl_matmul(
if
b_preshuffle
:
block_row_warps
=
1
block_col_warps
=
4
warp_row_tiles
=
128
warp_col_tiles
=
32
warp_row_tiles
=
64
warp_col_tiles
=
16
chunk
=
3
2
*
k_pack
chunk
=
2
56
*
k_pack
pack_size_k
=
micro_size_k
*
k_pack
shared_scope
=
"shared"
cache_write_shared
=
False
block_M
=
block_row_warps
*
warp_row_tiles
block_N
=
block_col_warps
*
warp_col_tiles
...
...
@@ -68,6 +58,7 @@ def tl_matmul(
pack_size_k
,
micro_size_y
)
else
:
B_shape
=
(
N
,
K
)
if
b_transposed
else
(
K
,
N
)
A_shared_shape
=
(
block_K
,
block_M
)
if
a_transposed
else
(
block_M
,
block_K
)
if
b_preshuffle
:
B_shared_shape
=
(
block_N
//
micro_size_y
,
block_K
//
pack_size_k
,
micro_size_y
,
...
...
@@ -76,12 +67,6 @@ def tl_matmul(
micro_size_y
)
else
:
B_shared_shape
=
(
block_N
,
block_K
)
if
b_transposed
else
(
block_K
,
block_N
)
C_shared_shape
=
(
block_M
//
micro_size_x
,
block_N
//
micro_size_y
,
micro_size_x
,
micro_size_y
,
)
warp_size
=
64
threads
=
warp_size
*
(
block_row_warps
*
block_col_warps
)
...
...
@@ -92,7 +77,7 @@ def tl_matmul(
warp_cols
=
warp_col_tiles
//
micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mfma_emitter
=
MatrixCoreIntrinEmitter
(
mfma_emitter
=
MatrixCore
Preshuffle
IntrinEmitter
(
a_dtype
=
in_dtype
,
b_dtype
=
in_dtype
,
accum_dtype
=
accum_dtype
,
...
...
@@ -117,7 +102,6 @@ def tl_matmul(
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
,
scope
=
shared_scope
)
A_local
=
T
.
alloc_local
((
warp_rows
*
local_size_a
),
in_dtype
)
B_local
=
T
.
alloc_local
((
warp_cols
*
local_size_b
),
in_dtype
)
C_local
=
T
.
alloc_local
((
warp_rows
*
warp_cols
*
local_size_c
),
accum_dtype
)
...
...
@@ -126,12 +110,15 @@ def tl_matmul(
A_shared
:
make_swizzle_layout
(
A_shared
),
})
num_ko
=
K
//
block_K
num_ki
=
block_K
//
(
k_pack
*
micro_size_k
)
# Improve L2 Cache
T
.
use_swizzle
(
panel_size
=
10
)
T
.
clear
(
C_local
)
for
ko
in
T
.
Pipelined
(
(
K
//
block_K
)
,
num_stages
=
0
):
for
ko
in
T
.
Pipelined
(
num_ko
,
num_stages
=
0
):
# Load A into shared memory
if
a_transposed
:
...
...
@@ -140,7 +127,7 @@ def tl_matmul(
T
.
copy
(
A
[
by
*
block_M
,
ko
*
block_K
],
A_shared
)
# Load B into shared memory
if
b_
preshuffl
e
:
if
b_
g2l_load
is
Fals
e
:
if
b_transposed
:
for
j
,
k
,
jj
,
kk
in
T
.
Parallel
(
block_N
//
micro_size_y
,
block_K
//
pack_size_k
,
micro_size_y
,
...
...
@@ -153,22 +140,21 @@ def tl_matmul(
micro_size_y
):
B_shared
[
k
,
j
,
kk
,
jj
]
=
B
[
ko
*
block_K
//
pack_size_k
+
k
,
bx
*
block_N
//
micro_size_y
+
j
,
kk
,
jj
]
else
:
if
b_transposed
:
T
.
copy
(
B
[
bx
*
block_N
,
ko
*
block_K
],
B_shared
)
else
:
T
.
copy
(
B
[
ko
*
block_K
,
bx
*
block_N
],
B_shared
)
for
ki
in
T
.
serial
(
0
,
(
block_K
//
(
k_pack
*
micro_size_k
))
):
for
ki
in
T
.
serial
(
0
,
num_ki
):
# Load A
into fragment
# Load A
S2L
mfma_emitter
.
ldmatrix_a
(
A_local
,
A_shared
,
ki
,
)
# Load B into fragment
if
b_g2l_load
:
# Load B G2L
mfma_emitter
.
ldmatrix_b
(
B_local
,
B
,
ki
+
ko
*
num_ki
,
pid_m
=
by
,
pid_n
=
bx
)
else
:
# Load B S2L
mfma_emitter
.
ldmatrix_b
(
B_local
,
B_shared
,
...
...
@@ -179,21 +165,6 @@ def tl_matmul(
mfma_emitter
.
mfma
(
A_local
,
B_local
,
C_local
)
# Perform STMatrix
if
cache_write_shared
:
mfma_emitter
.
stmatrix
(
C_local
,
C_shared
,
)
# Store shared into global
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
C
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
C_shared
[
i
//
micro_size_x
,
j
//
micro_size_y
,
i
%
micro_size_x
,
j
%
micro_size_y
,
]
else
:
mfma_emitter
.
stmatrix
(
C_local
,
C
,
...
...
@@ -232,9 +203,10 @@ def assert_tl_matmul_correctness(M,
a_transposed
=
False
,
b_transposed
=
True
,
k_pack
=
1
,
b_preshuffle
=
False
):
b_preshuffle
=
False
,
b_g2l_load
=
False
):
matmul
=
tl_matmul
(
M
,
N
,
K
,
in_dtype
,
out_dtype
,
accum_dtype
,
a_transposed
,
b_transposed
,
k_pack
,
b_preshuffle
)
k_pack
,
b_preshuffle
,
b_g2l_load
)
print
(
matmul
)
kernel
=
tilelang
.
compile
(
matmul
)
src_code
=
kernel
.
get_kernel_source
()
...
...
@@ -285,30 +257,25 @@ def assert_tl_matmul_correctness(M,
print
(
C
)
print
(
ref_c
)
torch
.
testing
.
assert_close
(
C
,
ref_c
,
rtol
=
1e-2
,
atol
=
1e-2
)
@
tilelang
.
testing
.
requires_rocm
def
test_assert_tl_matmul
():
assert_tl_matmul_correctness
(
128
,
128
,
128
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
)
assert_tl_matmul_correctness
(
128
,
256
,
256
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
)
assert_tl_matmul_correctness
(
128
,
256
,
256
,
"int8"
,
"int32"
,
b_transposed
=
False
,
accum_dtype
=
"int32"
)
assert_tl_matmul_correctness
(
128
,
256
,
256
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
,
k_pack
=
2
)
assert_tl_matmul_correctness
(
128
,
128
,
128
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
,
b_preshuffle
=
True
)
256
,
256
,
256
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
,
b_preshuffle
=
True
)
assert_tl_matmul_correctness
(
128
,
256
,
256
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
,
b_preshuffle
=
True
)
256
,
256
,
256
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
,
b_preshuffle
=
True
)
assert_tl_matmul_correctness
(
128
,
256
,
256
,
"int8"
,
"int32"
,
b_transposed
=
False
,
accum_dtype
=
"int32"
,
b_preshuffle
=
True
)
256
,
256
,
256
,
"int8"
,
"int32"
,
b_transposed
=
False
,
accum_dtype
=
"int32"
,
b_preshuffle
=
True
)
assert_tl_matmul_correctness
(
128
,
256
,
256
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
,
k_pack
=
2
,
b_preshuffle
=
True
)
256
,
256
,
512
,
"int8"
,
"int32"
,
accum_dtype
=
"int32"
,
k_pack
=
2
,
b_preshuffle
=
True
)
assert_tl_matmul_correctness
(
128
,
256
,
256
,
512
,
"int8"
,
"int32"
,
b_transposed
=
False
,
...
...
tilelang/intrinsics/mfma_macro_generator.py
View file @
48c9a352
...
...
@@ -293,32 +293,6 @@ class MatrixCoreIntrinEmitter(object):
rk
=
0
,
):
tx
,
warp_n
,
_
=
self
.
extract_thread_binding
(
thread_binding
)
# 4 dim
if
self
.
b_preshuffle
:
if
is_transposed
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
warp_n
*
warp_cols
+
j
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_shared_buf
[
l
,
r
,
row
,
col
]
else
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
warp_n
*
warp_cols
+
j
,
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_shared_buf
[
l
,
r
,
row
,
col
]
else
:
if
is_transposed
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
...
...
@@ -327,8 +301,9 @@ class MatrixCoreIntrinEmitter(object):
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_shared_buf
[
l
+
row
,
r
+
col
]
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_shared_buf
[
l
+
row
,
r
+
col
]
else
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
...
...
@@ -337,8 +312,8 @@ class MatrixCoreIntrinEmitter(object):
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_shared_buf
[
l
+
row
,
r
+
col
]
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_shared_buf
[
l
+
row
,
r
+
col
]
return
_warp_ldmatrix_b
(
B_local_buf
,
B_shared_buf
,
ki
,
thread_binding
,
rk
)
...
...
@@ -425,3 +400,210 @@ class MatrixCoreIntrinEmitter(object):
return
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_binding
)
if
is_global
else
_warp_stmatrix_shared
(
C_local_buf
,
C_buf
,
thread_binding
)
class
MatrixCorePreshuffleIntrinEmitter
(
MatrixCoreIntrinEmitter
):
def
__init__
(
self
,
a_dtype
:
str
=
"float16"
,
b_dtype
:
str
=
"float16"
,
accum_dtype
:
str
=
"float16"
,
a_transposed
:
bool
=
False
,
b_transposed
:
bool
=
False
,
block_row_warps
:
int
=
2
,
block_col_warps
:
int
=
2
,
warp_row_tiles
:
int
=
8
,
warp_col_tiles
:
int
=
8
,
chunk
:
int
=
16
,
reduce_k
:
int
=
1
,
num_elems_per_byte
:
int
=
1
,
k_pack
:
Optional
[
int
]
=
None
,
is_m_first
:
Optional
[
bool
]
=
False
,
a_preshuffle
:
Optional
[
bool
]
=
False
,
b_preshuffle
:
Optional
[
bool
]
=
False
,
):
self
.
a_dtype
=
a_dtype
self
.
b_dtype
=
b_dtype
self
.
accum_dtype
=
accum_dtype
self
.
a_transposed
=
a_transposed
self
.
b_transposed
=
b_transposed
# Hint Information
self
.
block_row_warps
=
block_row_warps
self
.
block_col_warps
=
block_col_warps
self
.
warp_row_tiles
=
warp_row_tiles
self
.
warp_col_tiles
=
warp_col_tiles
self
.
chunk
=
chunk
self
.
_initialize_k_dim
(
a_dtype
)
self
.
_initialize_abbrev
(
a_dtype
,
b_dtype
,
accum_dtype
)
self
.
_initialize_local_size
(
self
.
M_DIM
,
self
.
N_DIM
,
self
.
k_dim
,
self
.
WARP_SIZE
)
self
.
_initialize_mfma_prefix
(
self
.
k_dim
)
self
.
_initialize_micro_size
(
self
.
M_DIM
,
self
.
N_DIM
,
self
.
k_dim
)
self
.
_initialize_k_pack
(
k_pack
)
self
.
_initialize_is_m_first
(
is_m_first
)
self
.
_initialize_preshuffle
(
a_preshuffle
,
b_preshuffle
)
self
.
warp_rows
=
warp_row_tiles
//
self
.
micro_size_x
self
.
warp_cols
=
warp_col_tiles
//
self
.
micro_size_y
self
.
reduce_k
=
reduce_k
self
.
threads
=
(
self
.
WARP_SIZE
*
(
block_row_warps
*
block_col_warps
)
*
reduce_k
)
self
.
num_elems_per_byte
=
num_elems_per_byte
def
_initialize_preshuffle
(
self
,
a_preshuffle
:
bool
,
b_preshuffle
:
bool
):
if
a_preshuffle
is
not
None
:
self
.
a_preshuffle
=
a_preshuffle
if
b_preshuffle
is
not
None
:
self
.
b_preshuffle
=
b_preshuffle
def
ldmatrix_a
(
self
,
A_local_buf
,
A_buf
,
ki
,
rk
=
0
,
pid_m
=
None
,
pid_n
=
None
):
warp_rows
=
self
.
warp_rows
chunk
=
self
.
chunk
micro_size_k
=
self
.
micro_size_k
local_size_a
=
self
.
local_size_a
k_pack
=
self
.
k_pack
is_transposed
=
self
.
a_transposed
current_frame
=
T
.
KernelLaunchFrame
.
Current
()
thread_binding
=
current_frame
.
get_thread_binding
()
_
,
reverse_index_map
=
self
.
get_ldmatrix_index_map
(
is_b
=
False
)
is_global
=
pid_m
is
not
None
and
pid_n
is
not
None
# no preshuffle, use the default implementation
if
self
.
a_preshuffle
is
False
:
return
super
().
ldmatrix_a
(
A_local_buf
,
A_buf
,
ki
,
rk
)
def
_warp_ldmatrix_a_global
(
A_local_buf
,
A_buf
,
ki
,
thread_binding
,
rk
=
0
,
):
tx
,
_
,
warp_m
=
self
.
extract_thread_binding
(
thread_binding
)
if
is_transposed
:
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
(
pid_m
*
self
.
block_row_warps
+
warp_m
)
*
warp_rows
+
i
,
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
l
,
r
,
row
,
col
]
else
:
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
(
pid_m
*
self
.
block_row_warps
+
warp_m
)
*
warp_rows
+
i
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
l
,
r
,
row
,
col
]
@
T
.
macro
def
_warp_ldmatrix_a_shared
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_binding
,
rk
=
0
,
):
tx
,
_
,
warp_m
=
self
.
extract_thread_binding
(
thread_binding
)
if
is_transposed
:
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
warp_m
*
warp_rows
+
i
,
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_shared_buf
[
l
,
r
,
row
,
col
]
else
:
print
(
self
.
a_preshuffle
)
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
warp_m
*
warp_rows
+
i
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_shared_buf
[
l
,
r
,
row
,
col
]
return
_warp_ldmatrix_a_global
(
A_local_buf
,
A_buf
,
ki
,
thread_binding
,
rk
)
if
is_global
else
_warp_ldmatrix_a_shared
(
A_local_buf
,
A_buf
,
ki
,
thread_binding
,
rk
)
def
ldmatrix_b
(
self
,
B_local_buf
,
B_buf
,
ki
,
rk
=
0
,
pid_m
=
None
,
pid_n
=
None
):
warp_cols
=
self
.
warp_cols
chunk
=
self
.
chunk
micro_size_k
=
self
.
micro_size_k
local_size_b
=
self
.
local_size_b
k_pack
=
self
.
k_pack
is_transposed
=
self
.
b_transposed
current_frame
=
T
.
KernelLaunchFrame
.
Current
()
thread_binding
=
current_frame
.
get_thread_binding
()
_
,
reverse_index_map
=
self
.
get_ldmatrix_index_map
(
is_b
=
True
)
is_global
=
pid_m
is
not
None
and
pid_n
is
not
None
if
self
.
b_preshuffle
is
False
:
return
super
().
ldmatrix_b
(
B_local_buf
,
B_buf
,
ki
,
rk
,
pid_m
,
pid_n
)
@
T
.
macro
def
_warp_ldmatrix_b_global
(
B_local_buf
,
B_buf
,
ki
,
thread_binding
,
rk
=
0
,
):
tx
,
warp_n
,
_
=
self
.
extract_thread_binding
(
thread_binding
)
if
is_transposed
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
(
pid_n
*
self
.
block_col_warps
+
warp_n
)
*
warp_cols
+
j
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
l
,
r
,
row
,
col
]
else
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
(
pid_n
*
self
.
block_col_warps
+
warp_n
)
*
warp_cols
+
j
,
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
l
,
r
,
row
,
col
]
@
T
.
macro
def
_warp_ldmatrix_b_shared
(
B_local_buf
,
B_shared_buf
,
ki
,
thread_binding
,
rk
=
0
,
):
tx
,
warp_n
,
_
=
self
.
extract_thread_binding
(
thread_binding
)
if
is_transposed
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
warp_n
*
warp_cols
+
j
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_shared_buf
[
l
,
r
,
row
,
col
]
else
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
warp_n
*
warp_cols
+
j
,
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_shared_buf
[
l
,
r
,
row
,
col
]
return
_warp_ldmatrix_b_global
(
B_local_buf
,
B_buf
,
ki
,
thread_binding
,
rk
)
if
is_global
else
_warp_ldmatrix_b_shared
(
B_local_buf
,
B_buf
,
ki
,
thread_binding
,
rk
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment