Commit f3f31091 authored by wangziyang's avatar wangziyang
Browse files

add simple mmac_layout

parent 8443d88e
def thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id): from tvm import DataType
i = thread_id % 16 from tvm.runtime import convert
j = local_id + (thread_id // 16) * 4 import tilelang.language as T
return i, j
\ No newline at end of file
'''
256x16 A share mem layout
0 16 32 48
1 17 33 49
2 18 34 50
3 19 35 51
---------------------
4 20 36 52
5 21 37 53
16x16 A Reg layout
0 16 32 48
1 17 33 49
2 18 34 50
3 19 35 51
'''
def shared_16x16_to_local_64x4_layout_A(i, j):
# i: row (0-15), j: column (0-15)
# thread_id: which thread handles this position
# local_id: which element within the thread's 4-element vector
#
# Layout (根据你的图示):
# j=0-3 j=4-7 j=8-11 j=12-15
# T0-T3 T16-T19 T32-T35 T48-T51 (i=0)
# T4-T7 T20-T23 T36-T39 T52-T55 (i=1)
# T8-T11 T24-T27 T40-T43 T56-T59 (i=2)
# T12-T15 T28-T31 T44-T47 T60-T63 (i=3)
thread_id = i + 16 * (j // 4) # 0-63
local_id = j % 4 # 0-3: 4 consecutive elements per thread
return thread_id, local_id
def thread_id_shared_access_64x4_to_16x16_layout_A(tx, local_id):
# tx: thread id within warp (0-63)
# local_id: element index within thread's 4-element vector (0-3)
#
# 根据布局:
# tx=0-15 处理行 0-15 的同一位置,列块 0 (j=0-3)
# tx=16-31 处理行 0-15 的同一位置,列块 1 (j=4-7)
# tx=32-47 处理行 0-15 的同一位置,列块 2 (j=8-11)
# tx=48-63 处理行 0-15 的同一位置,列块 3 (j=12-15)
#
# 例如 tx=0, local_id=0 → (i=0, j=0)
# tx=0, local_id=1 → (i=0, j=1)
# tx=16, local_id=0 → (i=0, j=4)
i = tx % 16 # 0-15: row position within the 16-row block
j = (tx // 16) * 4 + local_id # 0-15: column position
return i, j
shared_16x16_to_local_64x4_layout_m_n = shared_16x16_to_local_64x4_layout_A
shared_16x16_to_local_64x4_layout_n_k = shared_16x16_to_local_64x4_layout_A
\ No newline at end of file
...@@ -27,6 +27,11 @@ from .mfma_layout import ( ...@@ -27,6 +27,11 @@ from .mfma_layout import (
thread_id_shared_access_64x16_to_16x64_layout_B, thread_id_shared_access_64x16_to_16x64_layout_B,
) )
# from .mmac_layout import (
# shared_16x16_to_local_64x4_layout_A,
# thread_id_shared_access_64x4_to_16x16_layout_A,
# )
lift = convert lift = convert
...@@ -97,6 +102,14 @@ class MatrixCoreIntrinEmitter: ...@@ -97,6 +102,14 @@ class MatrixCoreIntrinEmitter:
self.warp_rows = warp_row_tiles // self.micro_size_x self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y self.warp_cols = warp_col_tiles // self.micro_size_y
# warp_rows: number of iterations per warp to traverse its assigned rows
# Each warp always processes 4 row groups (regardless of block_row_warps)
# block_row_warps controls the stride between row groups, not the count
# stride = block_row_warps * 4
# For example: block_row_warps=4, stride=16, warp processes rows {0-3, 16-19, 32-35, 48-51}
# # warp_cols: same logic for N dimension
# self.warp_rows = 4 # fixed: each warp always iterates 4 times
# self.warp_cols = 4 # fixed: each warp always iterates 4 times
self.reduce_k = reduce_k self.reduce_k = reduce_k
self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
self.num_elems_per_byte = num_elems_per_byte self.num_elems_per_byte = num_elems_per_byte
...@@ -288,6 +301,29 @@ class MatrixCoreIntrinEmitter: ...@@ -288,6 +301,29 @@ class MatrixCoreIntrinEmitter:
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
# row_offset: base row offset for this warp
# Each warp processes 4 row groups (e.g., 0-3, 4-7, 8-11, 12-15 for 1 warp)
# warp_row_tiles = block_row_warps * 4 is the stride between row groups
# warp_m (0,1,2,3) determines the starting row group offset ?
# row_offset = warp_m * warp_row_tiles
# if is_transposed:
# for i in T.serial(warp_rows):
# for local_id in T.vectorized(k_pack * local_size_a):
# row, col = T.meta_var(reverse_index_map(tx, local_id))
# # l: M dimension address = base row offset + row group stride * i + micro row
# # row = tx % 16 gives position within the 16-row block
# l, r = (rk * chunk + ki * (k_pack * micro_size_k), row_offset + i * warp_row_tiles + row)
# A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l, A_base1 + r + col]
# else:
# for i in T.serial(warp_rows):
# for local_id in T.vectorized(k_pack * local_size_a):
# row, col = T.meta_var(reverse_index_map(tx, local_id))
# # l: M dimension address = base row offset + row group stride * i + micro row
# # row = tx % 16 gives position within the 16-row block
# l, r = (row_offset + i * warp_row_tiles + row, rk * chunk + ki * (k_pack * micro_size_k))
# A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l, A_base1 + r + col]
# return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0): def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0):
warp_col_tiles = self.warp_col_tiles warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols warp_cols = self.warp_cols
...@@ -315,6 +351,9 @@ class MatrixCoreIntrinEmitter: ...@@ -315,6 +351,9 @@ class MatrixCoreIntrinEmitter:
rk=0, rk=0,
): ):
tx, warp_n, _ = self.extract_thread_binding(thread_binding) tx, warp_n, _ = self.extract_thread_binding(thread_binding)
# col_offset: base column offset for this warp (similar to row_offset in ldmatrix_a)
# warp_col_tiles = block_col_warps * 4 is the stride between column groups
col_offset = warp_n * warp_col_tiles
if is_transposed: if is_transposed:
for j in T.serial(warp_cols): for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b): for local_id in T.vectorized(k_pack * local_size_b):
...@@ -331,7 +370,7 @@ class MatrixCoreIntrinEmitter: ...@@ -331,7 +370,7 @@ class MatrixCoreIntrinEmitter:
row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id)) row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id))
l, r = ( l, r = (
rk * chunk + ki * (k_pack * micro_size_k), rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y, warp_n * warp_col_tiles + j * micro_size_y,,
) )
B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment