Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
bb2f5e4f
Commit
bb2f5e4f
authored
Apr 22, 2026
by
wangziyang
Browse files
[Bugfix] A share to local padding
parent
41887aed
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
17 deletions
+46
-17
tilelang/intrinsics/mmac_layout.py
tilelang/intrinsics/mmac_layout.py
+14
-0
tilelang/intrinsics/mmac_macro_generator.py
tilelang/intrinsics/mmac_macro_generator.py
+30
-15
tilelang/tileop/gemm/gemm_mmac.py
tilelang/tileop/gemm/gemm_mmac.py
+2
-2
No files found.
tilelang/intrinsics/mmac_layout.py
View file @
bb2f5e4f
...
@@ -21,6 +21,20 @@ import tilelang.language as T
...
@@ -21,6 +21,20 @@ import tilelang.language as T
3 19 35 51
3 19 35 51
'''
'''
def
shared_to_local_layout_A_padding
(
row
,
col
,
idx
,
warp_rows
,
block_row_warps
,
tx
):
new_col
=
col
if
warp_rows
>
1
:
inter_idx_padding
=
2
else
:
inter_idx_padding
=
1
paddings
=
inter_idx_padding
*
block_row_warps
*
4
print
(
"paddings:"
,
paddings
)
new_row
=
row
+
paddings
*
((
tx
&
15
)
//
4
)
new_row
+=
(
idx
&
1
)
*
(
paddings
//
2
)
+
(
idx
//
2
)
*
16
*
2
*
block_row_warps
return
new_row
,
new_col
def
shared_16x16_to_local_64x4_layout_A
(
i
,
j
):
def
shared_16x16_to_local_64x4_layout_A
(
i
,
j
):
# i: row (0-15), j: column (0-15)
# i: row (0-15), j: column (0-15)
# thread_id: which thread handles this position
# thread_id: which thread handles this position
...
...
tilelang/intrinsics/mmac_macro_generator.py
View file @
bb2f5e4f
...
@@ -11,7 +11,7 @@ from tilelang.utils import is_fragment
...
@@ -11,7 +11,7 @@ from tilelang.utils import is_fragment
from
.mfma_layout
import
(
from
.mfma_layout
import
(
shared_16x4_to_local_64x1_layout_A
,
shared_16x4_to_local_64x1_layout_A
,
shared_4x16_to_local_64x1_layout_B
,
shared_4x16_to_local_64x1_layout_B
,
shared_16x16_to_local_64x4_layout_A
,
#
shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_B
,
shared_16x16_to_local_64x4_layout_B
,
shared_16x32_to_local_64x8_layout_A
,
shared_16x32_to_local_64x8_layout_A
,
shared_16x32_to_local_64x8_layout_B
,
shared_16x32_to_local_64x8_layout_B
,
...
@@ -19,7 +19,7 @@ from .mfma_layout import (
...
@@ -19,7 +19,7 @@ from .mfma_layout import (
shared_16x64_to_local_64x16_layout_B
,
shared_16x64_to_local_64x16_layout_B
,
thread_id_shared_access_64x1_to_16x4_layout_A
,
thread_id_shared_access_64x1_to_16x4_layout_A
,
thread_id_shared_access_64x1_to_4x16_layout_B
,
thread_id_shared_access_64x1_to_4x16_layout_B
,
thread_id_shared_access_64x4_to_16x16_layout_A
,
#
thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B
,
thread_id_shared_access_64x4_to_16x16_layout_B
,
thread_id_shared_access_64x8_to_16x32_layout_A
,
thread_id_shared_access_64x8_to_16x32_layout_A
,
thread_id_shared_access_64x8_to_16x32_layout_B
,
thread_id_shared_access_64x8_to_16x32_layout_B
,
...
@@ -27,10 +27,11 @@ from .mfma_layout import (
...
@@ -27,10 +27,11 @@ from .mfma_layout import (
thread_id_shared_access_64x16_to_16x64_layout_B
,
thread_id_shared_access_64x16_to_16x64_layout_B
,
)
)
# from .mmac_layout import (
from
.mmac_layout
import
(
# shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_A
,
# thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_A
,
# )
shared_to_local_layout_A_padding
)
lift
=
convert
lift
=
convert
...
@@ -251,6 +252,7 @@ class MatrixCoreIntrinEmitter:
...
@@ -251,6 +252,7 @@ class MatrixCoreIntrinEmitter:
)
)
return
lane_id
,
warp_n
,
warp_m
return
lane_id
,
warp_n
,
warp_m
def
ldmatrix_a
(
self
,
A_local_buf
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
,
rk
=
0
):
def
ldmatrix_a
(
self
,
A_local_buf
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
,
rk
=
0
):
# share mem a needs warp number
# share mem a needs warp number
warp_num
=
self
.
block_row_warps
warp_num
=
self
.
block_row_warps
...
@@ -272,6 +274,7 @@ class MatrixCoreIntrinEmitter:
...
@@ -272,6 +274,7 @@ class MatrixCoreIntrinEmitter:
A_buf
=
A_region
.
buffer
A_buf
=
A_region
.
buffer
A_base0
=
A_region
.
region
[
-
2
].
min
A_base0
=
A_region
.
region
[
-
2
].
min
A_base1
=
A_region
.
region
[
-
1
].
min
A_base1
=
A_region
.
region
[
-
1
].
min
print
(
"A_base0, A_base1:"
,
A_base0
,
A_base1
)
@
T
.
macro
@
T
.
macro
def
_warp_ldmatrix_a
(
def
_warp_ldmatrix_a
(
...
@@ -281,31 +284,43 @@ class MatrixCoreIntrinEmitter:
...
@@ -281,31 +284,43 @@ class MatrixCoreIntrinEmitter:
thread_binding
,
thread_binding
,
rk
=
0
,
rk
=
0
,
):
):
tx
,
_
,
warp_m
=
self
.
extract_thread_binding
(
thread_binding
)
# warp_n[0-256] -> {0,1,2,3}
tx
,
warp_n
,
warp_m
=
self
.
extract_thread_binding
(
thread_binding
)
# {0..3,16..19,32..35,48..51} -> 0
# {0..3,16..19,32..35,48..51} -> 0
# {4..7,20..23,36..39,52..55} -> 1
# {4..7,20..23,36..39,52..55} -> 1
# {8..11,24..27,40..43,56..59} -> 2
# {8..11,24..27,40..43,56..59} -> 2
# {12..15,28..31,44..47,60..63} -> 3
# {12..15,28..31,44..47,60..63} -> 3
warp_interval_idx
=
(
tx
&
15
)
>>
2
warp_interval_idx
=
(
tx
&
15
)
>>
2
warp_group_idx
=
(
tx
//
32
)
warp_group_idx
=
warp_n
warp_inter_stride
=
4
# warp_rows 轮次 warp需要拆分成多轮来访问完整的行块
# warp_rows 轮次 warp需要拆分成多轮来访问完整的行块
if
is_transposed
:
if
is_transposed
:
for
i
in
T
.
serial
(
warp_rows
):
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
# 每轮初始位置行偏移
# 每轮初始位置行偏移
row
+=
i
*
warp_row_init
# row += i * warp_row_init
# warp 组行间隔
# # warp 组行间隔
row
+=
warp_group_idx
*
4
# row += warp_group_idx * 4
# warp 内行间隔
# # warp 内行间隔
row
+=
warp_interval_idx
*
warp_row_interval
# row += warp_interval_idx * warp_row_interval
raise
NotImplementedError
(
"Transposed A with preshuffle is not implemented yet"
)
row
,
col
=
shared_to_local_layout_A_padding
(
row
,
col
,
i
,
warp_rows
,
self
.
block_row_warps
,
tx
)
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
)
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
else
:
else
:
for
i
in
T
.
serial
(
warp_rows
):
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
))
# # 每轮初始位置行偏移
# row += i * warp_row_init
# # warp 组行间隔
# row += warp_group_idx * 4
# # warp 内行间隔
# row += warp_interval_idx * warp_row_interval
row
,
col
=
shared_to_local_layout_A_padding
(
row
,
col
,
i
,
warp_rows
,
self
.
block_row_warps
,
tx
)
print
(
"row, col:"
,
row
,
col
)
l
,
r
=
(
warp_m
*
warp_inter_stride
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
))
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
return
_warp_ldmatrix_a
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_binding
,
rk
)
return
_warp_ldmatrix_a
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_binding
,
rk
)
...
...
tilelang/tileop/gemm/gemm_mmac.py
View file @
bb2f5e4f
...
@@ -32,8 +32,8 @@ class GemmMMAC(GemmBase):
...
@@ -32,8 +32,8 @@ class GemmMMAC(GemmBase):
if
self
.
is_gemm_ss
():
if
self
.
is_gemm_ss
():
return
{
return
{
self
.
A
:
make_
swizzled
_layout
(
self
.
A
),
self
.
A
:
make_
linear
_layout
(
self
.
A
),
self
.
B
:
make_
swizzled
_layout
(
self
.
B
),
self
.
B
:
make_
linear
_layout
(
self
.
B
),
self
.
C
:
mmac_emitter
.
make_mmac_store_layout
(
self
.
C
),
self
.
C
:
mmac_emitter
.
make_mmac_store_layout
(
self
.
C
),
}
}
elif
self
.
is_gemm_sr
():
elif
self
.
is_gemm_sr
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment