Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
ad92620d
Commit
ad92620d
authored
Apr 27, 2026
by
wangziyang
Browse files
print B warp layout & B s_to_l padding
parent
b8213492
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
100 additions
and
11 deletions
+100
-11
tilelang/intrinsics/mmac_layout.py
tilelang/intrinsics/mmac_layout.py
+100
-11
No files found.
tilelang/intrinsics/mmac_layout.py
View file @
ad92620d
from
tvm
import
DataType
#
from tvm import DataType
from
tvm.runtime
import
convert
#
from tvm.runtime import convert
import
tilelang.language
as
T
import
tilelang.language
as
T
...
@@ -29,7 +29,7 @@ def shared_to_local_layout_A_padding(row, col, idx, warp_rows, block_row_warps,
...
@@ -29,7 +29,7 @@ def shared_to_local_layout_A_padding(row, col, idx, warp_rows, block_row_warps,
inter_idx_padding
=
1
inter_idx_padding
=
1
paddings
=
inter_idx_padding
*
block_row_warps
*
4
paddings
=
inter_idx_padding
*
block_row_warps
*
4
print
(
"paddings:"
,
paddings
)
#
print("paddings:", paddings)
new_row
=
row
+
paddings
*
((
tx
&
15
)
//
4
)
new_row
=
row
+
paddings
*
((
tx
&
15
)
//
4
)
new_row
+=
(
idx
&
1
)
*
(
paddings
//
2
)
+
(
idx
//
2
)
*
16
*
2
*
block_row_warps
new_row
+=
(
idx
&
1
)
*
(
paddings
//
2
)
+
(
idx
//
2
)
*
16
*
2
*
block_row_warps
...
@@ -39,14 +39,23 @@ def shared_to_local_layout_B_padding(row, col, idx, warp_cols, block_col_warps,
...
@@ -39,14 +39,23 @@ def shared_to_local_layout_B_padding(row, col, idx, warp_cols, block_col_warps,
# ???
# ???
# T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T0 T1 ...
# T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T0 T1 ...
# T0 T1 T2
# T0 T1 T2
new_col
=
((
col
%
16
*
8
)
//
32
*
8
)
+
(
col
//
16
)
*
4
+
(
row
//
4
)
new_col
=
((
tx
%
4
)
*
8
)
+
(
col
//
16
)
*
4
+
(
row
%
4
)
# new_col = ((col % 16 * 8 ) // 32 * 8) + (col // 16) * 4 + (row // 4)
# swap per 4 row idx1 <-> idx2
# swap per 4 row idx1 <-> idx2
bit0
=
(
tx
//
4
)
&
1
row_tmp
=
tx
//
4
bit1
=
(
tx
//
2
)
&
1
bit0
=
row_tmp
&
1
new_row
=
(
tx
//
4
)
&
(
~
3
)
+
(
bit0
*
2
)
+
bit1
bit1
=
(
row_tmp
>>
1
)
&
1
new_row
=
((
row_tmp
//
4
)
<<
2
)
+
(
bit0
*
2
)
+
bit1
# return new_row, new_col
if
block_col_warps
>
1
:
inter_idx_padding
=
2
else
:
inter_idx_padding
=
1
paddings
=
32
paddings
=
32
new_col
+=
(
idx
&
1
)
*
paddings
new_col
+=
idx
*
inter_idx_padding
*
paddings
return
new_row
,
new_col
return
new_row
,
new_col
...
@@ -90,8 +99,9 @@ def shared_32x16_to_local_64x8_layout_B(i, j):
...
@@ -90,8 +99,9 @@ def shared_32x16_to_local_64x8_layout_B(i, j):
# row=4-7 T16 T17 T18 T19 T20 T21 T22 T23 T24 T25 T26 T27 T28 T29 T30 T31
# row=4-7 T16 T17 T18 T19 T20 T21 T22 T23 T24 T25 T26 T27 T28 T29 T30 T31
# row=8-11 T32 T33 T34 T35 T36 T37 T38 T39 T40 T41 T42 T43 T44 T45 T46 T47
# row=8-11 T32 T33 T34 T35 T36 T37 T38 T39 T40 T41 T42 T43 T44 T45 T46 T47
# row=12-15 T48 T49 T50 T51 T52 T53 T54 T55 T56 T57 T58 T59 T60 T61 T62 T63
# row=12-15 T48 T49 T50 T51 T52 T53 T54 T55 T56 T57 T58 T59 T60 T61 T62 T63
thread_id
=
16
*
(
i
//
4
)
thread_id
=
16
*
(
i
//
4
)
+
j
%
16
local_id
=
(
j
//
16
)
*
4
+
(
i
%
4
)
local_id
=
(
j
//
16
)
*
4
+
(
i
%
4
)
# print("T{:^2} V{:^2} ".format(thread_id,local_id),end=" ")
return
thread_id
,
local_id
return
thread_id
,
local_id
...
@@ -107,3 +117,82 @@ shared_16x16_to_local_64x4_layout_m_n = shared_16x16_to_local_64x4_layout_A
...
@@ -107,3 +117,82 @@ shared_16x16_to_local_64x4_layout_m_n = shared_16x16_to_local_64x4_layout_A
shared_16x16_to_local_64x4_layout_n_k
=
shared_16x16_to_local_64x4_layout_A
shared_16x16_to_local_64x4_layout_n_k
=
shared_16x16_to_local_64x4_layout_A
shared_32x16_to_local_64x8_layout_n_m
=
shared_32x16_to_local_64x8_layout_B
shared_32x16_to_local_64x8_layout_n_m
=
shared_32x16_to_local_64x8_layout_B
shared_32x16_to_local_64x8_layout_k_n
=
shared_32x16_to_local_64x8_layout_B
shared_32x16_to_local_64x8_layout_k_n
=
shared_32x16_to_local_64x8_layout_B
if
__name__
==
"__main__"
:
idx
=
1
warp_cols
=
1
block_col_warps
=
1
row
=
16
col
=
32
tx_f
=
0
# print(" ",end=" ")
# for i in range(col):
# print("{:^8}".format(i),end=" ")
# print()
# for i in range(row):
# print("{:^2}".format(i),end=" ")
# for j in range(col):
# thread_id,local_id = shared_32x16_to_local_64x8_layout_B(i,j)
# print("T{:^2} V{:^2} ".format(thread_id,local_id),end=" ")
# print()
# print(" ",end=" ")
# for i in range(col):
# print("{:^8}".format(i),end=" ")
# print()
# for j in range(row):
# local_idx = j % 4
# print("{:^2}".format(j),end=" ")
# for i in range(col):
# tx = (j // 4) * 16 + (i % 16)
# local_id = (i // 16) * 4 + local_idx
# row_idx,col_idx = thread_id_shared_access_64x8_to_32x16_layout_B(tx, local_id)
# print("T{:^2} V{:^2} ".format(tx,local_id),end=" ")
# print()
# print(" ",end=" ")
# for i in range(col):
# print("{:^8}".format(i),end=" ")
# print()
tx
=
16
# row_ = 4
col_
=
16
group
=
4
results
=
[]
for
k
in
range
(
group
):
for
j
in
range
(
col_
):
tx
=
j
+
k
*
16
for
i
in
range
(
k
*
4
,(
k
+
1
)
*
4
):
new_row
,
new_col
=
shared_to_local_layout_B_padding
(
i
,
j
,
idx
,
warp_cols
,
block_col_warps
,
tx
)
results
.
append
((
new_row
,
new_col
,(
tx
,
i
-
k
*
4
)))
# print("({:^2},{:^2}) {:^2}".format(new_row,new_col,tx))
for
i
in
range
(
k
*
4
,(
k
+
1
)
*
4
):
new_row
,
new_col
=
shared_to_local_layout_B_padding
(
i
,
j
+
16
,
idx
,
warp_cols
,
block_col_warps
,
tx
)
results
.
append
((
new_row
,
new_col
,(
tx
,
i
-
k
*
4
+
4
)))
# print("({:^2},{:^2}) {:^2}".format(new_row,new_col,tx))
# 转换为嵌套字典
grid
=
{}
for
r
,
c
,
t
in
results
:
if
r
not
in
grid
:
grid
[
r
]
=
{}
grid
[
r
][
c
]
=
t
print
(
"layout B padding ..."
)
print
(
" "
,
end
=
" "
)
col
=
32
for
i
in
range
(
col
):
print
(
"{:^8}"
.
format
(
i
),
end
=
" "
)
print
()
# 确定行遍历范围
rows
=
sorted
(
grid
.
keys
())
# 按照 行范围 -> 列范围 遍历
for
r
in
rows
:
print
(
"{:^2}"
.
format
(
r
),
end
=
" "
)
#获取当前行下列号并排序
cols
=
sorted
(
grid
[
r
].
keys
())
for
c
in
cols
:
tx
,
id
=
grid
[
r
][
c
]
print
(
"T{:^2} V{:^2} "
.
format
(
tx
,
id
),
end
=
" "
)
print
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment