Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
f3f31091
"examples/rust/vscode:/vscode.git/clone" did not exist on "73c10ae90aa51b85eaba35cdcf60e58b272bec46"
Commit
f3f31091
authored
Apr 20, 2026
by
wangziyang
Browse files
add simple mmac_layout
parent
8443d88e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
97 additions
and
5 deletions
+97
-5
tilelang/intrinsics/mmac_layout.py
tilelang/intrinsics/mmac_layout.py
+57
-4
tilelang/intrinsics/mmac_macro_generator.py
tilelang/intrinsics/mmac_macro_generator.py
+40
-1
No files found.
tilelang/intrinsics/mmac_layout.py
View file @
f3f31091
def thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id):
    """Map a ``(thread_id, local_id)`` pair to an ``(i, j)`` index pair.

    Only ``%``, ``//``, ``*`` and ``+`` are applied to the arguments.

    Args:
        thread_id: Thread index (presumably 0-63, going by the function
            name -- confirm against the caller).
        local_id: Element index within the thread (presumably 0-3).

    Returns:
        Tuple ``(i, j)`` with ``i = thread_id % 16`` and
        ``j = local_id + (thread_id // 16) * 4``.
    """
    # Each run of 16 consecutive thread ids shares one 4-wide base for j.
    group_base = (thread_id // 16) * 4
    return thread_id % 16, local_id + group_base
\ No newline at end of file
from
tvm
import
DataType
from
tvm.runtime
import
convert
import
tilelang.language
as
T
'''
256x16 A share mem layout
0 16 32 48
1 17 33 49
2 18 34 50
3 19 35 51
---------------------
4 20 36 52
5 21 37 53
16x16 A Reg layout
0 16 32 48
1 17 33 49
2 18 34 50
3 19 35 51
'''
def shared_16x16_to_local_64x4_layout_A(i, j):
    """Map a tile coordinate ``(i, j)`` to ``(thread_id, local_id)``.

    Four consecutive ``j`` values (one column block) map to the same
    thread and differ only in ``local_id``. Moving to the next column
    block adds 16 to the thread id; moving to the next row adds 1.

    Resulting thread id per (row, column block), for i and j in 0-15:

                j=0-3   j=4-7   j=8-11  j=12-15
        i=0     T0      T16     T32     T48
        i=1     T1      T17     T33     T49
        i=2     T2      T18     T34     T50
        i=3     T3      T19     T35     T51
        ...
        i=15    T15     T31     T47     T63

    Args:
        i: Row index (0-15).
        j: Column index (0-15).

    Returns:
        Tuple ``(thread_id, local_id)`` with
        ``thread_id = i + 16 * (j // 4)`` (0-63 for the ranges above) and
        ``local_id = j % 4`` (0-3).
    """
    column_block = j // 4
    return i + 16 * column_block, j % 4
def thread_id_shared_access_64x4_to_16x16_layout_A(tx, local_id):
    """Map ``(tx, local_id)`` to the ``(i, j)`` coordinate it addresses.

    With ``tx`` in 0-63 and ``local_id`` in 0-3:

        tx =  0-15  -> rows 0-15, column block 0 (j = 0-3)
        tx = 16-31  -> rows 0-15, column block 1 (j = 4-7)
        tx = 32-47  -> rows 0-15, column block 2 (j = 8-11)
        tx = 48-63  -> rows 0-15, column block 3 (j = 12-15)

    Examples:
        tx=0,  local_id=0 -> (i=0, j=0)
        tx=0,  local_id=1 -> (i=0, j=1)
        tx=16, local_id=0 -> (i=0, j=4)

    Args:
        tx: Thread id (0-63).
        local_id: Element index within the thread's 4 elements (0-3).

    Returns:
        Tuple ``(i, j)`` with ``i = tx % 16`` and
        ``j = (tx // 16) * 4 + local_id``.
    """
    row = tx % 16
    # The group of 16 threads selects the 4-wide column block;
    # local_id selects the column inside that block.
    col = (tx // 16) * 4 + local_id
    return row, col
# Both names below are bound to the same function object as shared_16x16_to_local_64x4_layout_A.
shared_16x16_to_local_64x4_layout_m_n
=
shared_16x16_to_local_64x4_layout_A
shared_16x16_to_local_64x4_layout_n_k
=
shared_16x16_to_local_64x4_layout_A
\ No newline at end of file
tilelang/intrinsics/mmac_macro_generator.py
View file @
f3f31091
...
...
@@ -27,6 +27,11 @@ from .mfma_layout import (
thread_id_shared_access_64x16_to_16x64_layout_B
,
)
# from .mmac_layout import (
# shared_16x16_to_local_64x4_layout_A,
# thread_id_shared_access_64x4_to_16x16_layout_A,
# )
lift
=
convert
...
...
@@ -97,6 +102,14 @@ class MatrixCoreIntrinEmitter:
self
.
warp_rows
=
warp_row_tiles
//
self
.
micro_size_x
self
.
warp_cols
=
warp_col_tiles
//
self
.
micro_size_y
# warp_rows: number of iterations per warp to traverse its assigned rows
# Each warp always processes 4 row groups (regardless of block_row_warps)
# block_row_warps controls the stride between row groups, not the count
# stride = block_row_warps * 4
# For example: block_row_warps=4, stride=16, warp processes rows {0-3, 16-19, 32-35, 48-51}
# # warp_cols: same logic for N dimension
# self.warp_rows = 4 # fixed: each warp always iterates 4 times
# self.warp_cols = 4 # fixed: each warp always iterates 4 times
self
.
reduce_k
=
reduce_k
self
.
threads
=
self
.
WARP_SIZE
*
(
block_row_warps
*
block_col_warps
)
*
reduce_k
self
.
num_elems_per_byte
=
num_elems_per_byte
...
...
@@ -288,6 +301,29 @@ class MatrixCoreIntrinEmitter:
return
_warp_ldmatrix_a
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_binding
,
rk
)
# row_offset: base row offset for this warp
# Each warp processes 4 row groups (e.g., 0-3, 4-7, 8-11, 12-15 for 1 warp)
# warp_row_tiles = block_row_warps * 4 is the stride between row groups
# warp_m (0,1,2,3) determines the starting row group offset ?
# row_offset = warp_m * warp_row_tiles
# if is_transposed:
# for i in T.serial(warp_rows):
# for local_id in T.vectorized(k_pack * local_size_a):
# row, col = T.meta_var(reverse_index_map(tx, local_id))
# # l: M dimension address = base row offset + row group stride * i + micro row
# # row = tx % 16 gives position within the 16-row block
# l, r = (rk * chunk + ki * (k_pack * micro_size_k), row_offset + i * warp_row_tiles + row)
# A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l, A_base1 + r + col]
# else:
# for i in T.serial(warp_rows):
# for local_id in T.vectorized(k_pack * local_size_a):
# row, col = T.meta_var(reverse_index_map(tx, local_id))
# # l: M dimension address = base row offset + row group stride * i + micro row
# # row = tx % 16 gives position within the 16-row block
# l, r = (row_offset + i * warp_row_tiles + row, rk * chunk + ki * (k_pack * micro_size_k))
# A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l, A_base1 + r + col]
# return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
def
ldmatrix_b
(
self
,
B_local_buf
,
B_shared_buf
:
Buffer
|
BufferRegion
,
ki
,
rk
=
0
):
warp_col_tiles
=
self
.
warp_col_tiles
warp_cols
=
self
.
warp_cols
...
...
@@ -315,6 +351,9 @@ class MatrixCoreIntrinEmitter:
rk
=
0
,
):
tx
,
warp_n
,
_
=
self
.
extract_thread_binding
(
thread_binding
)
# col_offset: base column offset for this warp (similar to row_offset in ldmatrix_a)
# warp_col_tiles = block_col_warps * 4 is the stride between column groups
col_offset
=
warp_n
*
warp_col_tiles
if
is_transposed
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
...
...
@@ -331,7 +370,7 @@ class MatrixCoreIntrinEmitter:
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
,
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
B_base0
+
l
+
row
,
B_base1
+
r
+
col
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment