Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
b8213492
"container/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "2ee29443b6ec02d875d1e5bdf1c47667123f4be4"
Commit
b8213492
authored
Apr 24, 2026
by
wangziyang
Browse files
add B mmac layout
parent
bb2f5e4f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
53 additions
and
7 deletions
+53
-7
tilelang/engine/phase.py
tilelang/engine/phase.py
+2
-2
tilelang/intrinsics/mmac_layout.py
tilelang/intrinsics/mmac_layout.py
+39
-1
tilelang/intrinsics/mmac_macro_generator.py
tilelang/intrinsics/mmac_macro_generator.py
+12
-4
No files found.
tilelang/engine/phase.py
View file @
b8213492
...
@@ -233,7 +233,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
...
@@ -233,7 +233,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod
=
tir
.
transform
.
Simplify
()(
mod
)
mod
=
tir
.
transform
.
Simplify
()(
mod
)
mod
=
tilelang
.
transform
.
VectorizeLoop
(
enable_vectorize
=
allow_vectorize
(
pass_ctx
=
pass_ctx
))(
mod
)
mod
=
tilelang
.
transform
.
VectorizeLoop
(
enable_vectorize
=
allow_vectorize
(
pass_ctx
=
pass_ctx
))(
mod
)
# Transform B_local layout from shared memory thread-interleaved to local row-major
# Transform B_local layout from shared memory thread-interleaved to local row-major
mod
=
tilelang
.
transform
.
InjectBLocalLayoutTransform
()(
mod
)
#
mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
mod
=
tilelang
.
transform
.
StorageRewrite
()(
mod
)
mod
=
tilelang
.
transform
.
StorageRewrite
()(
mod
)
mod
=
tir
.
transform
.
UnrollLoop
()(
mod
)
mod
=
tir
.
transform
.
UnrollLoop
()(
mod
)
mod
=
tir
.
transform
.
RenormalizeSplitPattern
()(
mod
)
mod
=
tir
.
transform
.
RenormalizeSplitPattern
()(
mod
)
...
@@ -275,7 +275,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
...
@@ -275,7 +275,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# as ptx async copy won't be recognized as a valid buffer load
# as ptx async copy won't be recognized as a valid buffer load
mod
=
tilelang
.
transform
.
InjectPTXAsyncCopy
()(
mod
)
mod
=
tilelang
.
transform
.
InjectPTXAsyncCopy
()(
mod
)
# Inject ds_read for shared to register memory copy on DCU
# Inject ds_read for shared to register memory copy on DCU
mod
=
tilelang
.
transform
.
InjectDSRead
()(
mod
)
#
mod = tilelang.transform.InjectDSRead()(mod)
print
(
"222222222"
)
print
(
"222222222"
)
print
(
mod
)
print
(
mod
)
...
...
tilelang/intrinsics/mmac_layout.py
View file @
b8213492
...
@@ -35,6 +35,21 @@ def shared_to_local_layout_A_padding(row, col, idx, warp_rows, block_row_warps,
...
@@ -35,6 +35,21 @@ def shared_to_local_layout_A_padding(row, col, idx, warp_rows, block_row_warps,
new_row
+=
(
idx
&
1
)
*
(
paddings
//
2
)
+
(
idx
//
2
)
*
16
*
2
*
block_row_warps
new_row
+=
(
idx
&
1
)
*
(
paddings
//
2
)
+
(
idx
//
2
)
*
16
*
2
*
block_row_warps
return
new_row
,
new_col
return
new_row
,
new_col
def shared_to_local_layout_B_padding(row, col, idx, warp_cols, block_col_warps, tx):
    """Remap a (row, col) B-fragment coordinate from the shared-memory
    thread-interleaved layout to the padded local (register) layout.

    Parameters
    ----------
    row, col : int
        Coordinates produced by the shared->local index map.
    idx : int
        Index of the warp-level sub-tile currently being loaded.
    warp_cols, block_col_warps : int
        Tiling parameters; currently unused here, kept for signature
        symmetry with ``shared_to_local_layout_A_padding``.
    tx : int
        Thread index within the warp (0-63 — assumed from the sibling
        maps in this file; confirm against the emitter).

    Returns
    -------
    tuple[int, int]
        ``(new_row, new_col)`` in the padded local layout.
    """
    # Column remap for the thread-interleaving pattern:
    # T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T0 T1 ...
    new_col = ((col % 16 * 8) // 32 * 8) + (col // 16) * 4 + (row // 4)
    # Swap the two low bits of the per-4-row group index (idx1 <-> idx2).
    bit0 = (tx // 4) & 1
    bit1 = (tx // 2) & 1
    # BUG FIX: '&' binds looser than '+' in Python, so the original
    # '(tx // 4) & (~3) + (bit0 * 2) + bit1' actually parsed as
    # '(tx // 4) & (~3 + bit0 * 2 + bit1)'.  Parenthesize the alignment
    # so the aligned base really gets the swapped low bits added back.
    new_row = ((tx // 4) & ~3) + (bit0 * 2) + bit1
    paddings = 32
    # Odd sub-tiles land in the padded second half of the row.
    new_col += (idx & 1) * paddings
    return new_row, new_col
def
shared_16x16_to_local_64x4_layout_A
(
i
,
j
):
def
shared_16x16_to_local_64x4_layout_A
(
i
,
j
):
# i: row (0-15), j: column (0-15)
# i: row (0-15), j: column (0-15)
# thread_id: which thread handles this position
# thread_id: which thread handles this position
...
@@ -67,5 +82,28 @@ def thread_id_shared_access_64x4_to_16x16_layout_A(tx, local_id):
...
@@ -67,5 +82,28 @@ def thread_id_shared_access_64x4_to_16x16_layout_A(tx, local_id):
j
=
(
tx
//
16
)
*
4
+
local_id
# 0-15: column position
j
=
(
tx
//
16
)
*
4
+
local_id
# 0-15: column position
return
i
,
j
return
i
,
j
def shared_32x16_to_local_64x8_layout_B(i, j):
    """Map a shared-memory B coordinate to its owning (thread_id, local_id).

    i : row (0-15), j : col (0-31).

    Thread layout (16 threads per 4-row band, one per column):
    #             0   1   2  ...  15    (j % 16)
    # row 0-3    T0  T1  T2  ... T15
    # row 4-7    T16 T17 T18 ... T31
    # row 8-11   T32 T33 T34 ... T47
    # row 12-15  T48 T49 T50 ... T63
    """
    # BUG FIX: the original computed 'thread_id = 16 * (i // 4)' only,
    # dropping the column term — every column of a 4-row band mapped to
    # the same thread, contradicting both the layout table above and the
    # inverse map thread_id_shared_access_64x8_to_32x16_layout_B.
    thread_id = 16 * (i // 4) + (j % 16)
    # local_id: which of the thread's 8 elements holds this position
    # (4 per 16-column half).
    local_id = (j // 16) * 4 + (i % 4)
    return thread_id, local_id
def thread_id_shared_access_64x8_to_32x16_layout_B(tx, local_id):
    """Inverse map: which shared-memory B coordinate a thread's element holds.

    tx : thread id within the warp (0-63).
    local_id : element index within the thread's 8-element vector (0-7).

    Returns (i, j) with i : row (0-15), j : col (0-31).
    """
    # BUG FIX: the original used 'tx // 15' and 'tx % 15'.  16 threads
    # cover the 16 columns of each 4-row band (cf. the A-side map, which
    # uses tx // 16), and with 15 the documented ranges break — e.g.
    # tx=60 yielded a row >= 16, outside the stated 0-15 range.
    i = (tx // 16) * 4 + local_id % 4
    j = (tx % 16) + (local_id // 4) * 16
    return i, j
# Aliases: the role-named layouts (m_n / n_k for the A side, n_m / k_n
# for the B side) all resolve to the two concrete index maps defined
# above; callers pick by role, the mapping function is shared.
shared_16x16_to_local_64x4_layout_m_n = shared_16x16_to_local_64x4_layout_A
shared_16x16_to_local_64x4_layout_n_k = shared_16x16_to_local_64x4_layout_A
shared_32x16_to_local_64x8_layout_n_m = shared_32x16_to_local_64x8_layout_B
shared_32x16_to_local_64x8_layout_k_n = shared_32x16_to_local_64x8_layout_B
tilelang/intrinsics/mmac_macro_generator.py
View file @
b8213492
...
@@ -30,7 +30,10 @@ from .mfma_layout import (
...
@@ -30,7 +30,10 @@ from .mfma_layout import (
from
.mmac_layout
import
(
from
.mmac_layout
import
(
shared_16x16_to_local_64x4_layout_A
,
shared_16x16_to_local_64x4_layout_A
,
thread_id_shared_access_64x4_to_16x16_layout_A
,
thread_id_shared_access_64x4_to_16x16_layout_A
,
shared_to_local_layout_A_padding
shared_to_local_layout_A_padding
,
shared_32x16_to_local_64x8_layout_B
,
thread_id_shared_access_64x8_to_32x16_layout_B
,
shared_to_local_layout_B_padding
)
)
lift
=
convert
lift
=
convert
...
@@ -295,6 +298,7 @@ class MatrixCoreIntrinEmitter:
...
@@ -295,6 +298,7 @@ class MatrixCoreIntrinEmitter:
warp_inter_stride
=
4
warp_inter_stride
=
4
# warp_rows 轮次 warp需要拆分成多轮来访问完整的行块
# warp_rows 轮次 warp需要拆分成多轮来访问完整的行块
if
is_transposed
:
if
is_transposed
:
raise
NotImplementedError
(
"Transposed A is not implemented yet"
)
for
i
in
T
.
serial
(
warp_rows
):
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
...
@@ -304,7 +308,6 @@ class MatrixCoreIntrinEmitter:
...
@@ -304,7 +308,6 @@ class MatrixCoreIntrinEmitter:
# row += warp_group_idx * 4
# row += warp_group_idx * 4
# # warp 内行间隔
# # warp 内行间隔
# row += warp_interval_idx * warp_row_interval
# row += warp_interval_idx * warp_row_interval
raise
NotImplementedError
(
"Transposed A with preshuffle is not implemented yet"
)
row
,
col
=
shared_to_local_layout_A_padding
(
row
,
col
,
i
,
warp_rows
,
self
.
block_row_warps
,
tx
)
row
,
col
=
shared_to_local_layout_A_padding
(
row
,
col
,
i
,
warp_rows
,
self
.
block_row_warps
,
tx
)
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
)
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
...
@@ -330,6 +333,7 @@ class MatrixCoreIntrinEmitter:
...
@@ -330,6 +333,7 @@ class MatrixCoreIntrinEmitter:
warp_cols
=
self
.
warp_cols
warp_cols
=
self
.
warp_cols
chunk
=
self
.
chunk
chunk
=
self
.
chunk
micro_size_y
=
self
.
micro_size_y
micro_size_y
=
self
.
micro_size_y
micro_size_k
=
self
.
micro_size_k
micro_size_k
=
self
.
micro_size_k
local_size_b
=
self
.
local_size_b
local_size_b
=
self
.
local_size_b
k_pack
=
self
.
k_pack
k_pack
=
self
.
k_pack
...
@@ -351,8 +355,10 @@ class MatrixCoreIntrinEmitter:
...
@@ -351,8 +355,10 @@ class MatrixCoreIntrinEmitter:
thread_binding
,
thread_binding
,
rk
=
0
,
rk
=
0
,
):
):
tx
,
warp_n
,
_
=
self
.
extract_thread_binding
(
thread_binding
)
tx
,
warp_n
,
warp_m
=
self
.
extract_thread_binding
(
thread_binding
)
warp_inter_stride
=
32
if
is_transposed
:
if
is_transposed
:
raise
NotImplementedError
(
"Transposed B is not implemented yet"
)
for
j
in
T
.
serial
(
warp_cols
):
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
...
@@ -366,9 +372,11 @@ class MatrixCoreIntrinEmitter:
...
@@ -366,9 +372,11 @@ class MatrixCoreIntrinEmitter:
for
j
in
T
.
serial
(
warp_cols
):
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
row
,
col
=
shared_to_local_layout_B_padding
(
row
,
col
,
j
,
warp_cols
,
self
.
block_col_warps
,
tx
)
l
,
r
=
(
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
# warp_n * warp_col_tiles + j * micro_size_y,
warp_n
*
warp_inter_stride
,
)
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
B_base0
+
l
+
row
,
B_base1
+
r
+
col
]
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
B_base0
+
l
+
row
,
B_base1
+
r
+
col
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment