OpenDAS / tilelang · Commit 64179eaf

Authored Apr 27, 2026 by qisan

Merge remote ds_read: resolve conflicts in phase.py, mmac_macro_generator.py, gemm_mmac.py

Parents: dd91b1e0, ad92620d

Showing 4 changed files with 166 additions and 27 deletions:
    tilelang/engine/phase.py                      +6    -2
    tilelang/intrinsics/mmac_layout.py            +144  -3
    tilelang/intrinsics/mmac_macro_generator.py   +16   -20
    tilelang/tileop/gemm/gemm_mmac.py             +0    -2
tilelang/engine/phase.py (view file @ 64179eaf)
@@ -236,7 +236,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tir.transform.Simplify()(mod)
     mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
     # Transform B_local layout from shared memory thread-interleaved to local row-major
-    mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
+    # mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
     mod = tilelang.transform.StorageRewrite()(mod)
     mod = tir.transform.UnrollLoop()(mod)
     mod = tir.transform.RenormalizeSplitPattern()(mod)
@@ -282,9 +282,13 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     # Inject ds_read for shared to register memory copy on DCU
     # mod = tilelang.transform.InjectDSRead()(mod)
+    # mod = tilelang.transform.InjectDSRead()(mod)
+    print("222222222")
+    print(mod)
     if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
         mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
     mod = tilelang.transform.MakePackedAPI()(mod)
     mod = tilelang.transform.Simplify()(mod)
     mod = tilelang.transform.LowerDeviceKernelLaunch()(mod)
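For context on the pattern in both hunks: each `mod = SomePass()(mod)` line instantiates a TVM pass object and immediately applies it to the IRModule. A minimal sketch of the same chaining using only upstream tvm.tir passes (the tilelang.transform passes follow the same callable convention; this is illustrative, not the project's actual pipeline):

    import tvm
    from tvm import tir

    def run_tail_passes(mod: tvm.IRModule) -> tvm.IRModule:
        # Sequential composes passes and applies them in order, which is
        # equivalent to repeatedly rebinding mod = Pass()(mod).
        pipeline = tvm.transform.Sequential([
            tir.transform.Simplify(),
            tir.transform.UnrollLoop(),
            tir.transform.RenormalizeSplitPattern(),
        ])
        return pipeline(mod)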
tilelang/intrinsics/mmac_layout.py (view file @ 64179eaf)
-from tvm import DataType
+# from tvm import DataType
-from tvm.runtime import convert
+# from tvm.runtime import convert
 import tilelang.language as T
@@ -21,6 +21,44 @@ import tilelang.language as T
     3 19 35 51
 '''

+def shared_to_local_layout_A_padding(row, col, idx, warp_rows, block_row_warps, tx):
+    new_col = col
+    if warp_rows > 1:
+        inter_idx_padding = 2
+    else:
+        inter_idx_padding = 1
+    paddings = inter_idx_padding * block_row_warps * 4
+    # print("paddings:", paddings)
+    new_row = row + paddings * ((tx & 15) // 4)
+    new_row += (idx & 1) * (paddings // 2) + (idx // 2) * 16 * 2 * block_row_warps
+    return new_row, new_col
+
+def shared_to_local_layout_B_padding(row, col, idx, warp_cols, block_col_warps, tx):
+    # ???
+    # T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T0 T1 ...
+    # T0 T1 T2
+    new_col = ((tx % 4) * 8) + (col // 16) * 4 + (row % 4)
+    # new_col = ((col % 16 * 8) // 32 * 8) + (col // 16) * 4 + (row // 4)
+    # swap per 4 rows: idx1 <-> idx2
+    row_tmp = tx // 4
+    bit0 = row_tmp & 1
+    bit1 = (row_tmp >> 1) & 1
+    new_row = ((row_tmp // 4) << 2) + (bit0 * 2) + bit1
+    # return new_row, new_col
+    if block_col_warps > 1:
+        inter_idx_padding = 2
+    else:
+        inter_idx_padding = 1
+    paddings = 32
+    new_col += idx * inter_idx_padding * paddings
+    return new_row, new_col
+
 def shared_16x16_to_local_64x4_layout_A(i, j):
     # i: row (0-15), j: column (0-15)
     # thread_id: which thread handles this position
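To make the padding arithmetic above concrete, here is a small worked example (not part of the commit) for a single-warp configuration; the values follow directly from the function bodies:

    # A side: warp_rows == 1 and block_row_warps == 1 give paddings = 1 * 1 * 4 = 4.
    # For tx = 5, (tx & 15) // 4 == 1, so the row is pushed down by one padding step.
    shared_to_local_layout_A_padding(row=0, col=0, idx=0, warp_rows=1, block_row_warps=1, tx=5)
    # -> (4, 0)

    # B side: new_col = (5 % 4) * 8 = 8; row_tmp = 5 // 4 = 1, whose low bits swap,
    # so new_row = 2; idx = 0 adds no column padding.
    shared_to_local_layout_B_padding(row=0, col=0, idx=0, warp_cols=1, block_col_warps=1, tx=5)
    # -> (2, 8)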
@@ -53,5 +91,108 @@ def thread_id_shared_access_64x4_to_16x16_layout_A(tx, local_id):
     j = (tx // 16) * 4 + local_id  # 0-15: column position
     return i, j

+def shared_32x16_to_local_64x8_layout_B(i, j):
+    # i: row (0-15)  j: col (0-31)
+    # Layout:
+    #              0    1    2    3    4    5    6    7    8    9   10   11   12   13   14   15
+    # row=0-3     T0   T1   T2   T3   T4   T5   T6   T7   T8   T9  T10  T11  T12  T13  T14  T15
+    # row=4-7    T16  T17  T18  T19  T20  T21  T22  T23  T24  T25  T26  T27  T28  T29  T30  T31
+    # row=8-11   T32  T33  T34  T35  T36  T37  T38  T39  T40  T41  T42  T43  T44  T45  T46  T47
+    # row=12-15  T48  T49  T50  T51  T52  T53  T54  T55  T56  T57  T58  T59  T60  T61  T62  T63
+    thread_id = 16 * (i // 4) + j % 16
+    local_id = (j // 16) * 4 + (i % 4)
+    # print("T{:^2} V{:^2} ".format(thread_id, local_id), end=" ")
+    return thread_id, local_id
+
+def thread_id_shared_access_64x8_to_32x16_layout_B(tx, local_id):
+    # tx: thread id within warp (0-63)
+    # local_id: element index within thread's 8-element vector (0-7)
+    # i: row (0-15)  j: col (0-31)
+    i = (tx // 16) * 4 + local_id % 4   # tx // 16 recovers the 4-row group
+    j = (tx % 16) + (local_id // 4) * 16
+    return i, j
+
 shared_16x16_to_local_64x4_layout_m_n = shared_16x16_to_local_64x4_layout_A
-shared_16x16_to_local_64x4_layout_n_k = shared_16x16_to_local_64x4_layout_A
\ No newline at end of file
+shared_16x16_to_local_64x4_layout_n_k = shared_16x16_to_local_64x4_layout_A
+shared_32x16_to_local_64x8_layout_n_m = shared_32x16_to_local_64x8_layout_B
+shared_32x16_to_local_64x8_layout_k_n = shared_32x16_to_local_64x8_layout_B
+
+if __name__ == "__main__":
+    idx = 1
+    warp_cols = 1
+    block_col_warps = 1
+    row = 16
+    col = 32
+    tx_f = 0
+    # print(" ", end=" ")
+    # for i in range(col):
+    #     print("{:^8}".format(i), end=" ")
+    # print()
+    # for i in range(row):
+    #     print("{:^2}".format(i), end=" ")
+    #     for j in range(col):
+    #         thread_id, local_id = shared_32x16_to_local_64x8_layout_B(i, j)
+    #         print("T{:^2} V{:^2} ".format(thread_id, local_id), end=" ")
+    #     print()
+    # print(" ", end=" ")
+    # for i in range(col):
+    #     print("{:^8}".format(i), end=" ")
+    # print()
+    # for j in range(row):
+    #     local_idx = j % 4
+    #     print("{:^2}".format(j), end=" ")
+    #     for i in range(col):
+    #         tx = (j // 4) * 16 + (i % 16)
+    #         local_id = (i // 16) * 4 + local_idx
+    #         row_idx, col_idx = thread_id_shared_access_64x8_to_32x16_layout_B(tx, local_id)
+    #         print("T{:^2} V{:^2} ".format(tx, local_id), end=" ")
+    #     print()
+    # print(" ", end=" ")
+    # for i in range(col):
+    #     print("{:^8}".format(i), end=" ")
+    # print()
+    tx = 16
+    # row_ = 4
+    col_ = 16
+    group = 4
+    results = []
+    for k in range(group):
+        for j in range(col_):
+            tx = j + k * 16
+            for i in range(k * 4, (k + 1) * 4):
+                new_row, new_col = shared_to_local_layout_B_padding(i, j, idx, warp_cols, block_col_warps, tx)
+                results.append((new_row, new_col, (tx, i - k * 4)))
+                # print("({:^2},{:^2}) {:^2}".format(new_row, new_col, tx))
+            for i in range(k * 4, (k + 1) * 4):
+                new_row, new_col = shared_to_local_layout_B_padding(i, j + 16, idx, warp_cols, block_col_warps, tx)
+                results.append((new_row, new_col, (tx, i - k * 4 + 4)))
+                # print("({:^2},{:^2}) {:^2}".format(new_row, new_col, tx))
+    # convert the results into a nested dict: grid[row][col] = (tx, local_id)
+    grid = {}
+    for r, c, t in results:
+        if r not in grid:
+            grid[r] = {}
+        grid[r][c] = t
+    print("layout B padding ...")
+    print(" ", end=" ")
+    col = 32
+    for i in range(col):
+        print("{:^8}".format(i), end=" ")
+    print()
+    # determine the row traversal range
+    rows = sorted(grid.keys())
+    # traverse by row range -> column range
+    for r in rows:
+        print("{:^2}".format(r), end=" ")
+        # collect and sort the column indices of the current row
+        cols = sorted(grid[r].keys())
+        for c in cols:
+            tx, id = grid[r][c]
+            print("T{:^2} V{:^2} ".format(tx, id), end=" ")
+        print()
\ No newline at end of file
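Because shared_32x16_to_local_64x8_layout_B and thread_id_shared_access_64x8_to_32x16_layout_B are intended to be mutual inverses, a round-trip check complements the visual dump in the __main__ block above. A minimal sketch (not part of the commit; it assumes the two functions are importable from this module):

    from tilelang.intrinsics.mmac_layout import (
        shared_32x16_to_local_64x8_layout_B,
        thread_id_shared_access_64x8_to_32x16_layout_B,
    )

    for i in range(16):        # tile rows
        for j in range(32):    # tile columns
            tx, local_id = shared_32x16_to_local_64x8_layout_B(i, j)
            assert 0 <= tx < 64 and 0 <= local_id < 8
            # the reverse map must recover the original (row, col) pair
            assert thread_id_shared_access_64x8_to_32x16_layout_B(tx, local_id) == (i, j)
    print("forward/reverse B-layout maps are consistent over the 16x32 tile")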
tilelang/intrinsics/mmac_macro_generator.py (view file @ 64179eaf)
@@ -30,6 +30,10 @@ from .mfma_layout import (
 from .mmac_layout import (
     shared_16x16_to_local_64x4_layout_A,
     thread_id_shared_access_64x4_to_16x16_layout_A,
+    shared_to_local_layout_A_padding,
+    shared_32x16_to_local_64x8_layout_B,
+    thread_id_shared_access_64x8_to_32x16_layout_B,
+    shared_to_local_layout_B_padding,
 )

 lift = convert
@@ -250,20 +254,6 @@ class MatrixCoreIntrinEmitter:
                 (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
             )
         return lane_id, warp_n, warp_m

-    def map_64x16(self, row, col, idx, warp_rows, tx):
-        new_col = col
-        if warp_rows > 1:
-            inter_idx_padding = 2
-        else:
-            inter_idx_padding = 1
-        paddings = inter_idx_padding * self.block_row_warps * 4
-        print("paddings:", paddings)
-        new_row = row + paddings * ((tx & 15) // 4)
-        new_row += (idx & 1) * (paddings // 2) + (idx // 2) * 16 * 2 * self.block_row_warps
-        return new_row, new_col
-
     def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):
@@ -305,8 +295,10 @@ class MatrixCoreIntrinEmitter:
             # {12..15,28..31,44..47,60..63} -> 3
             warp_interval_idx = (tx & 15) >> 2
             warp_group_idx = warp_n
+            warp_inter_stride = 4
             # warp_rows rounds: the warp is split across multiple rounds to cover the full row block
             if is_transposed:
+                raise NotImplementedError("Transposed A is not implemented yet")
                 for i in T.serial(warp_rows):
                     for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
@@ -316,8 +308,7 @@ class MatrixCoreIntrinEmitter:
                         # row += warp_group_idx * 4
                         # # row spacing within the warp
                         # row += warp_interval_idx * warp_row_interval
-                        raise NotImplementedError("Transposed A with preshuffle is not implemented yet")
-                        row, col = self.map_64x16(row, col, i, warp_rows, tx)
+                        row, col = shared_to_local_layout_A_padding(row, col, i, warp_rows, self.block_row_warps, tx)
                         l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
                         A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
             else:
@@ -330,9 +321,9 @@ class MatrixCoreIntrinEmitter:
                         # row += warp_group_idx * 4
                         # # row spacing within the warp
                         # row += warp_interval_idx * warp_row_interval
-                        row, col = self.map_64x16(row, col, i, warp_rows, tx)
+                        row, col = shared_to_local_layout_A_padding(row, col, i, warp_rows, self.block_row_warps, tx)
                         print("row, col:", row, col)
-                        l, r = (warp_m * 4, rk * chunk + ki * (k_pack * micro_size_k))
+                        l, r = (warp_m * warp_inter_stride, rk * chunk + ki * (k_pack * micro_size_k))
                         A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]

         return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
@@ -342,6 +333,7 @@ class MatrixCoreIntrinEmitter:
         warp_cols = self.warp_cols
         chunk = self.chunk
         micro_size_y = self.micro_size_y
         micro_size_k = self.micro_size_k
         local_size_b = self.local_size_b
         k_pack = self.k_pack
@@ -363,8 +355,10 @@ class MatrixCoreIntrinEmitter:
             thread_binding,
             rk=0,
         ):
-            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+            warp_inter_stride = 32
             if is_transposed:
+                raise NotImplementedError("Transposed B is not implemented yet")
                 for j in T.serial(warp_cols):
                     for local_id in T.vectorized(k_pack * local_size_b):
                         row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id))
@@ -378,9 +372,11 @@ class MatrixCoreIntrinEmitter:
                 for j in T.serial(warp_cols):
                     for local_id in T.vectorized(k_pack * local_size_b):
                         row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id))
+                        row, col = shared_to_local_layout_B_padding(row, col, j, warp_cols, self.block_col_warps, tx)
                         l, r = (
                             rk * chunk + ki * (k_pack * micro_size_k),
-                            warp_n * warp_col_tiles + j * micro_size_y,
+                            # warp_n * warp_col_tiles + j * micro_size_y,
+                            warp_n * warp_inter_stride,
                         )
                         B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]
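A note on the lane expression fed to reverse_index_map in the last two hunks: (tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16 swaps the quad index and the in-quad index within each 16-lane group, i.e. it transposes the 4x4 lane grid. A quick illustrative check (not in the commit):

    remap = lambda tx: (tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16
    print([remap(t) for t in range(16)])
    # [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]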
tilelang/tileop/gemm/gemm_mmac.py (view file @ 64179eaf)
@@ -32,8 +32,6 @@ class GemmMMAC(GemmBase):
         if self.is_gemm_ss():
             return {
-                # self.A: make_swizzled_layout(self.A, allow_pad=False),
-                # self.B: make_swizzled_layout(self.B, allow_pad=False),
                 self.A: make_linear_layout(self.A),
                 self.B: make_linear_layout(self.B),
                 self.C: mmac_emitter.make_mmac_store_layout(self.C),
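The hunk above deletes the commented-out swizzled layouts, leaving plain linear layouts for A and B. For background, the two differ only in the shared-memory index function; the sketch below is a generic illustration of that difference (linear_index and xor_swizzled_index are hypothetical names, not the tilelang API):

    def linear_index(i, j, stride):
        # row-major: element (i, j) sits at offset i * stride + j
        return i * stride + j

    def xor_swizzled_index(i, j, stride, vec=8):
        # XOR swizzle: permute vec-wide column groups within each row so that
        # successive rows land on different shared-memory banks
        # (assumes stride is a power-of-two multiple of vec)
        phase = i % (stride // vec)
        return i * stride + (((j // vec) ^ phase) * vec) + (j % vec)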