Commit bb2f5e4f authored by wangziyang's avatar wangziyang
Browse files

[Bugfix] A share to local padding

parent 41887aed
...@@ -21,6 +21,20 @@ import tilelang.language as T ...@@ -21,6 +21,20 @@ import tilelang.language as T
3 19 35 51 3 19 35 51
''' '''
def shared_to_local_layout_A_padding(row, col, idx, warp_rows, block_row_warps, tx):
    """Map a (row, col) shared-memory coordinate of operand A to its padded layout.

    Only the row is remapped; the column passes through unchanged. The padded row
    interleaves the per-thread lane group and the per-round fragment index so that
    consecutive lanes land on padded rows.

    Args:
        row: base row index produced by the reverse index map.
        col: base column index (returned unchanged).
        idx: round index within ``warp_rows`` (0-based).
        warp_rows: number of rounds a warp needs to cover its row block.
        block_row_warps: number of warps along the block's row dimension.
        tx: lane/thread index within the warp.

    Returns:
        Tuple ``(new_row, new_col)`` with the padded row and the original column.
    """
    # Doubling the inter-index padding when a warp covers more than one row
    # round keeps each round's fragments on distinct padded rows.
    inter_idx_padding = 2 if warp_rows > 1 else 1
    # Row stride between successive lane groups ((tx & 15) // 4 selects the group).
    paddings = inter_idx_padding * block_row_warps * 4
    new_row = row + paddings * ((tx & 15) // 4)
    # Low bit of idx offsets by half a padding stride; the high part jumps by
    # whole 16x2 row tiles scaled by the number of row warps.
    new_row += (idx & 1) * (paddings // 2) + (idx // 2) * 16 * 2 * block_row_warps
    return new_row, col
def shared_16x16_to_local_64x4_layout_A(i, j): def shared_16x16_to_local_64x4_layout_A(i, j):
# i: row (0-15), j: column (0-15) # i: row (0-15), j: column (0-15)
# thread_id: which thread handles this position # thread_id: which thread handles this position
......
...@@ -11,7 +11,7 @@ from tilelang.utils import is_fragment ...@@ -11,7 +11,7 @@ from tilelang.utils import is_fragment
from .mfma_layout import ( from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A, shared_16x4_to_local_64x1_layout_A,
shared_4x16_to_local_64x1_layout_B, shared_4x16_to_local_64x1_layout_B,
shared_16x16_to_local_64x4_layout_A, # shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_B, shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_A, shared_16x32_to_local_64x8_layout_A,
shared_16x32_to_local_64x8_layout_B, shared_16x32_to_local_64x8_layout_B,
...@@ -19,7 +19,7 @@ from .mfma_layout import ( ...@@ -19,7 +19,7 @@ from .mfma_layout import (
shared_16x64_to_local_64x16_layout_B, shared_16x64_to_local_64x16_layout_B,
thread_id_shared_access_64x1_to_16x4_layout_A, thread_id_shared_access_64x1_to_16x4_layout_A,
thread_id_shared_access_64x1_to_4x16_layout_B, thread_id_shared_access_64x1_to_4x16_layout_B,
thread_id_shared_access_64x4_to_16x16_layout_A, # thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B, thread_id_shared_access_64x4_to_16x16_layout_B,
thread_id_shared_access_64x8_to_16x32_layout_A, thread_id_shared_access_64x8_to_16x32_layout_A,
thread_id_shared_access_64x8_to_16x32_layout_B, thread_id_shared_access_64x8_to_16x32_layout_B,
...@@ -27,10 +27,11 @@ from .mfma_layout import ( ...@@ -27,10 +27,11 @@ from .mfma_layout import (
thread_id_shared_access_64x16_to_16x64_layout_B, thread_id_shared_access_64x16_to_16x64_layout_B,
) )
# from .mmac_layout import ( from .mmac_layout import (
# shared_16x16_to_local_64x4_layout_A, shared_16x16_to_local_64x4_layout_A,
# thread_id_shared_access_64x4_to_16x16_layout_A, thread_id_shared_access_64x4_to_16x16_layout_A,
# ) shared_to_local_layout_A_padding
)
lift = convert lift = convert
...@@ -251,6 +252,7 @@ class MatrixCoreIntrinEmitter: ...@@ -251,6 +252,7 @@ class MatrixCoreIntrinEmitter:
) )
return lane_id, warp_n, warp_m return lane_id, warp_n, warp_m
def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0): def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):
# share mem a needs warp number # share mem a needs warp number
warp_num = self.block_row_warps warp_num = self.block_row_warps
...@@ -272,6 +274,7 @@ class MatrixCoreIntrinEmitter: ...@@ -272,6 +274,7 @@ class MatrixCoreIntrinEmitter:
A_buf = A_region.buffer A_buf = A_region.buffer
A_base0 = A_region.region[-2].min A_base0 = A_region.region[-2].min
A_base1 = A_region.region[-1].min A_base1 = A_region.region[-1].min
print("A_base0, A_base1:", A_base0, A_base1)
@T.macro @T.macro
def _warp_ldmatrix_a( def _warp_ldmatrix_a(
...@@ -281,31 +284,43 @@ class MatrixCoreIntrinEmitter: ...@@ -281,31 +284,43 @@ class MatrixCoreIntrinEmitter:
thread_binding, thread_binding,
rk=0, rk=0,
): ):
tx, _, warp_m = self.extract_thread_binding(thread_binding) # warp_n[0-256] -> {0,1,2,3}
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
# {0..3,16..19,32..35,48..51} -> 0 # {0..3,16..19,32..35,48..51} -> 0
# {4..7,20..23,36..39,52..55} -> 1 # {4..7,20..23,36..39,52..55} -> 1
# {8..11,24..27,40..43,56..59} -> 2 # {8..11,24..27,40..43,56..59} -> 2
# {12..15,28..31,44..47,60..63} -> 3 # {12..15,28..31,44..47,60..63} -> 3
warp_interval_idx = (tx & 15)>>2 warp_interval_idx = (tx & 15)>>2
warp_group_idx = (tx // 32) warp_group_idx = warp_n
warp_inter_stride = 4
# warp_rows 轮次 warp需要拆分成多轮来访问完整的行块 # warp_rows 轮次 warp需要拆分成多轮来访问完整的行块
if is_transposed: if is_transposed:
for i in T.serial(warp_rows): for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a): for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id)) row, col = T.meta_var(reverse_index_map(tx, local_id))
# 每轮初始位置行偏移 # 每轮初始位置行偏移
row += i * warp_row_init # row += i * warp_row_init
# warp 组行间隔 # # warp 组行间隔
row += warp_group_idx * 4 # row += warp_group_idx * 4
# warp 内行间隔 # # warp 内行间隔
row += warp_interval_idx * warp_row_interval # row += warp_interval_idx * warp_row_interval
raise NotImplementedError("Transposed A with preshuffle is not implemented yet")
row, col = shared_to_local_layout_A_padding(row, col, i, warp_rows, self.block_row_warps, tx)
l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x) l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
else: else:
for i in T.serial(warp_rows): for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a): for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id)) row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k)) # # 每轮初始位置行偏移
# row += i * warp_row_init
# # warp 组行间隔
# row += warp_group_idx * 4
# # warp 内行间隔
# row += warp_interval_idx * warp_row_interval
row, col = shared_to_local_layout_A_padding(row, col, i, warp_rows, self.block_row_warps, tx)
print("row, col:", row, col)
l, r = (warp_m * warp_inter_stride, rk * chunk + ki * (k_pack * micro_size_k))
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
......
...@@ -32,8 +32,8 @@ class GemmMMAC(GemmBase): ...@@ -32,8 +32,8 @@ class GemmMMAC(GemmBase):
if self.is_gemm_ss(): if self.is_gemm_ss():
return { return {
self.A: make_swizzled_layout(self.A), self.A: make_linear_layout(self.A),
self.B: make_swizzled_layout(self.B), self.B: make_linear_layout(self.B),
self.C: mmac_emitter.make_mmac_store_layout(self.C), self.C: mmac_emitter.make_mmac_store_layout(self.C),
} }
elif self.is_gemm_sr(): elif self.is_gemm_sr():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment