Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
a0ec0f57
Commit
a0ec0f57
authored
Apr 21, 2026
by
wangziyang
Browse files
add ldmatrix warp_interval_idx mapping
parent
f3f31091
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
37 deletions
+13
-37
tilelang/intrinsics/mmac_layout.py
tilelang/intrinsics/mmac_layout.py
+1
-1
tilelang/intrinsics/mmac_macro_generator.py
tilelang/intrinsics/mmac_macro_generator.py
+12
-36
No files found.
tilelang/intrinsics/mmac_layout.py
View file @
a0ec0f57
...
@@ -26,7 +26,7 @@ def shared_16x16_to_local_64x4_layout_A(i, j):
...
@@ -26,7 +26,7 @@ def shared_16x16_to_local_64x4_layout_A(i, j):
# thread_id: which thread handles this position
# thread_id: which thread handles this position
# local_id: which element within the thread's 4-element vector
# local_id: which element within the thread's 4-element vector
#
#
# Layout (per the diagram):
# Layout (per the diagram):
# j=0-3 j=4-7 j=8-11 j=12-15
# j=0-3 j=4-7 j=8-11 j=12-15
# T0-T3 T16-T19 T32-T35 T48-T51 (i=0)
# T0-T3 T16-T19 T32-T35 T48-T51 (i=0)
# T4-T7 T20-T23 T36-T39 T52-T55 (i=1)
# T4-T7 T20-T23 T36-T39 T52-T55 (i=1)
...
...
tilelang/intrinsics/mmac_macro_generator.py
View file @
a0ec0f57
...
@@ -102,14 +102,6 @@ class MatrixCoreIntrinEmitter:
...
@@ -102,14 +102,6 @@ class MatrixCoreIntrinEmitter:
self
.
warp_rows
=
warp_row_tiles
//
self
.
micro_size_x
self
.
warp_rows
=
warp_row_tiles
//
self
.
micro_size_x
self
.
warp_cols
=
warp_col_tiles
//
self
.
micro_size_y
self
.
warp_cols
=
warp_col_tiles
//
self
.
micro_size_y
# warp_rows: number of iterations per warp to traverse its assigned rows
# Each warp always processes 4 row groups (regardless of block_row_warps)
# block_row_warps controls the stride between row groups, not the count
# stride = block_row_warps * 4
# For example: block_row_warps=4, stride=16, warp processes rows {0-3, 16-19, 32-35, 48-51}
# # warp_cols: same logic for N dimension
# self.warp_rows = 4 # fixed: each warp always iterates 4 times
# self.warp_cols = 4 # fixed: each warp always iterates 4 times
self
.
reduce_k
=
reduce_k
self
.
reduce_k
=
reduce_k
self
.
threads
=
self
.
WARP_SIZE
*
(
block_row_warps
*
block_col_warps
)
*
reduce_k
self
.
threads
=
self
.
WARP_SIZE
*
(
block_row_warps
*
block_col_warps
)
*
reduce_k
self
.
num_elems_per_byte
=
num_elems_per_byte
self
.
num_elems_per_byte
=
num_elems_per_byte
...
@@ -260,8 +252,11 @@ class MatrixCoreIntrinEmitter:
...
@@ -260,8 +252,11 @@ class MatrixCoreIntrinEmitter:
return
lane_id
,
warp_n
,
warp_m
return
lane_id
,
warp_n
,
warp_m
def
ldmatrix_a
(
self
,
A_local_buf
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
,
rk
=
0
):
def
ldmatrix_a
(
self
,
A_local_buf
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
,
rk
=
0
):
# share mem a needs warp number
warp_num
=
self
.
block_row_warps
warp_row_tiles
=
self
.
warp_row_tiles
warp_row_tiles
=
self
.
warp_row_tiles
warp_rows
=
self
.
warp_rows
warp_rows
=
self
.
warp_rows
warp_row_interval
=
warp_num
*
1
*
4
if
warp_rows
==
1
else
warp_num
*
2
*
4
chunk
=
self
.
chunk
chunk
=
self
.
chunk
micro_size_x
=
self
.
micro_size_x
micro_size_x
=
self
.
micro_size_x
micro_size_k
=
self
.
micro_size_k
micro_size_k
=
self
.
micro_size_k
...
@@ -286,10 +281,17 @@ class MatrixCoreIntrinEmitter:
...
@@ -286,10 +281,17 @@ class MatrixCoreIntrinEmitter:
rk
=
0
,
rk
=
0
,
):
):
tx
,
_
,
warp_m
=
self
.
extract_thread_binding
(
thread_binding
)
tx
,
_
,
warp_m
=
self
.
extract_thread_binding
(
thread_binding
)
# {0..3,16..19,32..35,48..51} -> 0
# {4..7,20..23,36..39,52..55} -> 1
# {8..11,24..27,40..43,56..59} -> 2
# {12..15,28..31,44..47,60..63} -> 3
warp_interval_idx
=
(
tx
&
15
)
>>
2
# warp_rows rounds: the warp must be split into multiple rounds to cover the full set of row blocks
if
is_transposed
:
if
is_transposed
:
for
i
in
T
.
serial
(
warp_rows
):
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
row
+=
warp_interval_idx
*
warp_row_interval
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
)
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
else
:
else
:
...
@@ -299,30 +301,7 @@ class MatrixCoreIntrinEmitter:
...
@@ -299,30 +301,7 @@ class MatrixCoreIntrinEmitter:
l
,
r
=
(
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
))
l
,
r
=
(
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
))
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
return
_warp_ldmatrix_a
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_binding
,
rk
)
return
_warp_ldmatrix_a
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_binding
,
rk
)
# row_offset: base row offset for this warp
# Each warp processes 4 row groups (e.g., 0-3, 4-7, 8-11, 12-15 for 1 warp)
# warp_row_tiles = block_row_warps * 4 is the stride between row groups
# warp_m (0,1,2,3) determines the starting row-group offset (to be confirmed)
# row_offset = warp_m * warp_row_tiles
# if is_transposed:
# for i in T.serial(warp_rows):
# for local_id in T.vectorized(k_pack * local_size_a):
# row, col = T.meta_var(reverse_index_map(tx, local_id))
# # l: M dimension address = base row offset + row group stride * i + micro row
# # row = tx % 16 gives position within the 16-row block
# l, r = (rk * chunk + ki * (k_pack * micro_size_k), row_offset + i * warp_row_tiles + row)
# A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l, A_base1 + r + col]
# else:
# for i in T.serial(warp_rows):
# for local_id in T.vectorized(k_pack * local_size_a):
# row, col = T.meta_var(reverse_index_map(tx, local_id))
# # l: M dimension address = base row offset + row group stride * i + micro row
# # row = tx % 16 gives position within the 16-row block
# l, r = (row_offset + i * warp_row_tiles + row, rk * chunk + ki * (k_pack * micro_size_k))
# A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l, A_base1 + r + col]
# return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
def
ldmatrix_b
(
self
,
B_local_buf
,
B_shared_buf
:
Buffer
|
BufferRegion
,
ki
,
rk
=
0
):
def
ldmatrix_b
(
self
,
B_local_buf
,
B_shared_buf
:
Buffer
|
BufferRegion
,
ki
,
rk
=
0
):
warp_col_tiles
=
self
.
warp_col_tiles
warp_col_tiles
=
self
.
warp_col_tiles
...
@@ -351,9 +330,6 @@ class MatrixCoreIntrinEmitter:
...
@@ -351,9 +330,6 @@ class MatrixCoreIntrinEmitter:
rk
=
0
,
rk
=
0
,
):
):
tx
,
warp_n
,
_
=
self
.
extract_thread_binding
(
thread_binding
)
tx
,
warp_n
,
_
=
self
.
extract_thread_binding
(
thread_binding
)
# col_offset: base column offset for this warp (similar to row_offset in ldmatrix_a)
# warp_col_tiles = block_col_warps * 4 is the stride between column groups
col_offset
=
warp_n
*
warp_col_tiles
if
is_transposed
:
if
is_transposed
:
for
j
in
T
.
serial
(
warp_cols
):
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
...
@@ -370,7 +346,7 @@ class MatrixCoreIntrinEmitter:
...
@@ -370,7 +346,7 @@ class MatrixCoreIntrinEmitter:
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
((
tx
&
15
)
//
4
+
(
tx
&
3
)
*
4
+
(
tx
//
16
)
*
16
,
local_id
))
l
,
r
=
(
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
,
warp_n
*
warp_col_tiles
+
j
*
micro_size_y
,
)
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
B_base0
+
l
+
row
,
B_base1
+
r
+
col
]
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_buf
[
B_base0
+
l
+
row
,
B_base1
+
r
+
col
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment