Commit bb2f5e4f authored by wangziyang's avatar wangziyang
Browse files

[Bugfix] A share to local padding

parent 41887aed
......@@ -21,6 +21,20 @@ import tilelang.language as T
3 19 35 51
'''
def shared_to_local_layout_A_padding(row, col, idx, warp_rows, block_row_warps, tx):
    """Remap a shared-memory (row, col) coordinate of the A operand into a
    row-padded layout for the shared->local load path.

    Only the row index is remapped; the column is returned unchanged.
    NOTE(review): the padding constants (4-lane grouping via ``tx & 15``,
    the ``16 * 2`` row-block stride) presumably match the MFMA 16x16 tile
    shape used by the surrounding emitter — confirm against the intrinsic
    layout definitions.

    Args:
        row: original row index within the shared-memory tile.
        col: original column index (passed through unmodified).
        idx: round index when a warp reads its rows in multiple passes.
        warp_rows: number of row rounds per warp; >1 doubles the padding stride.
        block_row_warps: number of warps along the block's row dimension.
        tx: thread (lane) index within the warp.

    Returns:
        tuple[int, int]: ``(new_row, new_col)`` — the padded row and the
        unchanged column.
    """
    new_col = col
    # A warp covering multiple row rounds needs twice the inter-round padding.
    inter_idx_padding = 2 if warp_rows > 1 else 1
    paddings = inter_idx_padding * block_row_warps * 4
    # (tx & 15) // 4 picks the 4-lane group within the low 16 lanes;
    # each group is offset by one full padding stride.
    new_row = row + paddings * ((tx & 15) // 4)
    # Odd rounds shift by half a padding stride; every two rounds advance
    # by a whole 16*2*block_row_warps row block.
    new_row += (idx & 1) * (paddings // 2) + (idx // 2) * 16 * 2 * block_row_warps
    return new_row, new_col
def shared_16x16_to_local_64x4_layout_A(i, j):
# i: row (0-15), j: column (0-15)
# thread_id: which thread handles this position
......
......@@ -11,7 +11,7 @@ from tilelang.utils import is_fragment
from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_4x16_to_local_64x1_layout_B,
shared_16x16_to_local_64x4_layout_A,
# shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_A,
shared_16x32_to_local_64x8_layout_B,
......@@ -19,7 +19,7 @@ from .mfma_layout import (
shared_16x64_to_local_64x16_layout_B,
thread_id_shared_access_64x1_to_16x4_layout_A,
thread_id_shared_access_64x1_to_4x16_layout_B,
thread_id_shared_access_64x4_to_16x16_layout_A,
# thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B,
thread_id_shared_access_64x8_to_16x32_layout_A,
thread_id_shared_access_64x8_to_16x32_layout_B,
......@@ -27,10 +27,11 @@ from .mfma_layout import (
thread_id_shared_access_64x16_to_16x64_layout_B,
)
# from .mmac_layout import (
# shared_16x16_to_local_64x4_layout_A,
# thread_id_shared_access_64x4_to_16x16_layout_A,
# )
from .mmac_layout import (
shared_16x16_to_local_64x4_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_A,
shared_to_local_layout_A_padding
)
lift = convert
......@@ -250,7 +251,8 @@ class MatrixCoreIntrinEmitter:
(thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
)
return lane_id, warp_n, warp_m
def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):
# share mem a needs warp number
warp_num = self.block_row_warps
......@@ -272,6 +274,7 @@ class MatrixCoreIntrinEmitter:
A_buf = A_region.buffer
A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min
print("A_base0, A_base1:", A_base0, A_base1)
@T.macro
def _warp_ldmatrix_a(
......@@ -281,31 +284,43 @@ class MatrixCoreIntrinEmitter:
thread_binding,
rk=0,
):
tx, _, warp_m = self.extract_thread_binding(thread_binding)
# warp_n[0-256] -> {0,1,2,3}
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
# {0..3,16..19,32..35,48..51} -> 0
# {4..7,20..23,36..39,52..55} -> 1
# {8..11,24..27,40..43,56..59} -> 2
# {12..15,28..31,44..47,60..63} -> 3
warp_interval_idx = (tx & 15)>>2
warp_group_idx = (tx // 32)
warp_group_idx = warp_n
warp_inter_stride = 4
# warp_rows 轮次 warp需要拆分成多轮来访问完整的行块
if is_transposed:
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
# 每轮初始位置行偏移
row += i * warp_row_init
# warp 组行间隔
row += warp_group_idx * 4
# warp 内行间隔
row += warp_interval_idx * warp_row_interval
# row += i * warp_row_init
# # warp 组行间隔
# row += warp_group_idx * 4
# # warp 内行间隔
# row += warp_interval_idx * warp_row_interval
raise NotImplementedError("Transposed A with preshuffle is not implemented yet")
row, col = shared_to_local_layout_A_padding(row, col, i, warp_rows, self.block_row_warps, tx)
l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
else:
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k))
# # 每轮初始位置行偏移
# row += i * warp_row_init
# # warp 组行间隔
# row += warp_group_idx * 4
# # warp 内行间隔
# row += warp_interval_idx * warp_row_interval
row, col = shared_to_local_layout_A_padding(row, col, i, warp_rows, self.block_row_warps, tx)
print("row, col:", row, col)
l, r = (warp_m * warp_inter_stride, rk * chunk + ki * (k_pack * micro_size_k))
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
......
......@@ -32,8 +32,8 @@ class GemmMMAC(GemmBase):
if self.is_gemm_ss():
return {
self.A: make_swizzled_layout(self.A),
self.B: make_swizzled_layout(self.B),
self.A: make_linear_layout(self.A),
self.B: make_linear_layout(self.B),
self.C: mmac_emitter.make_mmac_store_layout(self.C),
}
elif self.is_gemm_sr():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment