Commit a0ec0f57 authored by wangziyang's avatar wangziyang
Browse files

add ldmatrix warp_interval_idx mapping

parent f3f31091
...@@ -26,7 +26,7 @@ def shared_16x16_to_local_64x4_layout_A(i, j): ...@@ -26,7 +26,7 @@ def shared_16x16_to_local_64x4_layout_A(i, j):
# thread_id: which thread handles this position # thread_id: which thread handles this position
# local_id: which element within the thread's 4-element vector # local_id: which element within the thread's 4-element vector
# #
# Layout (根据你的图示): # Layout :
# j=0-3 j=4-7 j=8-11 j=12-15 # j=0-3 j=4-7 j=8-11 j=12-15
# T0-T3 T16-T19 T32-T35 T48-T51 (i=0) # T0-T3 T16-T19 T32-T35 T48-T51 (i=0)
# T4-T7 T20-T23 T36-T39 T52-T55 (i=1) # T4-T7 T20-T23 T36-T39 T52-T55 (i=1)
......
...@@ -102,14 +102,6 @@ class MatrixCoreIntrinEmitter: ...@@ -102,14 +102,6 @@ class MatrixCoreIntrinEmitter:
self.warp_rows = warp_row_tiles // self.micro_size_x self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y self.warp_cols = warp_col_tiles // self.micro_size_y
# warp_rows: number of iterations per warp to traverse its assigned rows
# Each warp always processes 4 row groups (regardless of block_row_warps)
# block_row_warps controls the stride between row groups, not the count
# stride = block_row_warps * 4
# For example: block_row_warps=4, stride=16, warp processes rows {0-3, 16-19, 32-35, 48-51}
# # warp_cols: same logic for N dimension
# self.warp_rows = 4 # fixed: each warp always iterates 4 times
# self.warp_cols = 4 # fixed: each warp always iterates 4 times
self.reduce_k = reduce_k self.reduce_k = reduce_k
self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
self.num_elems_per_byte = num_elems_per_byte self.num_elems_per_byte = num_elems_per_byte
...@@ -260,8 +252,11 @@ class MatrixCoreIntrinEmitter: ...@@ -260,8 +252,11 @@ class MatrixCoreIntrinEmitter:
return lane_id, warp_n, warp_m return lane_id, warp_n, warp_m
def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0): def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):
# shared mem A needs the warp count
warp_num = self.block_row_warps
warp_row_tiles = self.warp_row_tiles warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows warp_rows = self.warp_rows
warp_row_interval = warp_num * 1 * 4 if warp_rows == 1 else warp_num * 2 * 4
chunk = self.chunk chunk = self.chunk
micro_size_x = self.micro_size_x micro_size_x = self.micro_size_x
micro_size_k = self.micro_size_k micro_size_k = self.micro_size_k
...@@ -286,10 +281,17 @@ class MatrixCoreIntrinEmitter: ...@@ -286,10 +281,17 @@ class MatrixCoreIntrinEmitter:
rk=0, rk=0,
): ):
tx, _, warp_m = self.extract_thread_binding(thread_binding) tx, _, warp_m = self.extract_thread_binding(thread_binding)
# {0..3,16..19,32..35,48..51} -> 0
# {4..7,20..23,36..39,52..55} -> 1
# {8..11,24..27,40..43,56..59} -> 2
# {12..15,28..31,44..47,60..63} -> 3
warp_interval_idx = (tx & 15)>>2
# warp_rows rounds: the warp must split its accesses into multiple rounds to cover the full row block
if is_transposed: if is_transposed:
for i in T.serial(warp_rows): for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a): for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id)) row, col = T.meta_var(reverse_index_map(tx, local_id))
row += warp_interval_idx * warp_row_interval
l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x) l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
else: else:
...@@ -299,30 +301,7 @@ class MatrixCoreIntrinEmitter: ...@@ -299,30 +301,7 @@ class MatrixCoreIntrinEmitter:
l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k)) l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k))
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
# row_offset: base row offset for this warp
# Each warp processes 4 row groups (e.g., 0-3, 4-7, 8-11, 12-15 for 1 warp)
# warp_row_tiles = block_row_warps * 4 is the stride between row groups
# warp_m (0,1,2,3) determines the starting row group offset ?
# row_offset = warp_m * warp_row_tiles
# if is_transposed:
# for i in T.serial(warp_rows):
# for local_id in T.vectorized(k_pack * local_size_a):
# row, col = T.meta_var(reverse_index_map(tx, local_id))
# # l: M dimension address = base row offset + row group stride * i + micro row
# # row = tx % 16 gives position within the 16-row block
# l, r = (rk * chunk + ki * (k_pack * micro_size_k), row_offset + i * warp_row_tiles + row)
# A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l, A_base1 + r + col]
# else:
# for i in T.serial(warp_rows):
# for local_id in T.vectorized(k_pack * local_size_a):
# row, col = T.meta_var(reverse_index_map(tx, local_id))
# # l: M dimension address = base row offset + row group stride * i + micro row
# # row = tx % 16 gives position within the 16-row block
# l, r = (row_offset + i * warp_row_tiles + row, rk * chunk + ki * (k_pack * micro_size_k))
# A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l, A_base1 + r + col]
# return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0): def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0):
warp_col_tiles = self.warp_col_tiles warp_col_tiles = self.warp_col_tiles
...@@ -351,9 +330,6 @@ class MatrixCoreIntrinEmitter: ...@@ -351,9 +330,6 @@ class MatrixCoreIntrinEmitter:
rk=0, rk=0,
): ):
tx, warp_n, _ = self.extract_thread_binding(thread_binding) tx, warp_n, _ = self.extract_thread_binding(thread_binding)
# col_offset: base column offset for this warp (similar to row_offset in ldmatrix_a)
# warp_col_tiles = block_col_warps * 4 is the stride between column groups
col_offset = warp_n * warp_col_tiles
if is_transposed: if is_transposed:
for j in T.serial(warp_cols): for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b): for local_id in T.vectorized(k_pack * local_size_b):
...@@ -370,7 +346,7 @@ class MatrixCoreIntrinEmitter: ...@@ -370,7 +346,7 @@ class MatrixCoreIntrinEmitter:
row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id)) row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id))
l, r = ( l, r = (
rk * chunk + ki * (k_pack * micro_size_k), rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y,, warp_n * warp_col_tiles + j * micro_size_y,
) )
B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment