Commit b8213492 authored by wangziyang's avatar wangziyang
Browse files

add B mmac layout

parent bb2f5e4f
......@@ -233,7 +233,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
# Transform B_local layout from shared memory thread-interleaved to local row-major
mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
# mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
mod = tilelang.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
......@@ -275,7 +275,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# as ptx async copy won't be recognized as a valid buffer load
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
# Inject ds_read for shared to register memory copy on DCU
mod = tilelang.transform.InjectDSRead()(mod)
# mod = tilelang.transform.InjectDSRead()(mod)
print("222222222")
print(mod)
......
......@@ -35,6 +35,21 @@ def shared_to_local_layout_A_padding(row, col, idx, warp_rows, block_row_warps,
new_row += (idx & 1) * (paddings // 2) + (idx // 2) * 16 * 2 * block_row_warps
return new_row, new_col
def shared_to_local_layout_B_padding(row, col, idx, warp_cols, block_col_warps, tx):
    """Remap a B-operand (row, col) shared-memory coordinate to a padded layout.

    Parameters
    ----------
    row, col : int
        Logical coordinates produced by the reverse index map.
    idx : int
        Iteration index over warp_cols; its low bit selects the padded half.
    warp_cols, block_col_warps : int
        Kept for signature symmetry with the A-variant; currently unused here.
    tx : int
        Thread index within the warp (0-63).

    Returns
    -------
    (new_row, new_col) : tuple of int
        Remapped coordinates into the padded shared-memory tile.
    """
    # Column: group columns by 4 within a 16-wide tile (stride 8), then add the
    # 16-column tile offset and the row quarter.
    new_col = ((col % 16 * 8) // 32 * 8) + (col // 16) * 4 + (row // 4)
    # Swap the two low "row" bits per group of 4 rows (idx1 <-> idx2).
    bit0 = (tx // 4) & 1
    bit1 = (tx // 2) & 1
    # BUG FIX: Python's `+` binds tighter than `&`, so the original
    # `(tx // 4) & (~3) + (bit0 * 2) + bit1` masked tx//4 against
    # (-4 + 2*bit0 + bit1) instead of adding the swapped bits to the
    # 4-aligned base. Parenthesize the mask explicitly.
    new_row = ((tx // 4) & ~3) + (bit0 * 2) + bit1
    # Padding stride between the two halves selected by idx's low bit.
    paddings = 32
    new_col += (idx & 1) * paddings
    return new_row, new_col
def shared_16x16_to_local_64x4_layout_A(i, j):
# i: row (0-15), j: column (0-15)
# thread_id: which thread handles this position
......@@ -67,5 +82,28 @@ def thread_id_shared_access_64x4_to_16x16_layout_A(tx, local_id):
j = (tx // 16) * 4 + local_id # 0-15: column position
return i, j
def shared_32x16_to_local_64x8_layout_B(i, j):
    """Map a shared-memory coordinate of operand B to (thread_id, local_id).

    Parameters
    ----------
    i : int
        Row index. NOTE(review): the original comment says rows 0-15 while the
        function name says 32x16 — presumably this is the transposed view;
        confirm against the callers.
    j : int
        Column index (0-31); columns wrap every 16 across the warp.

    Returns
    -------
    (thread_id, local_id) : tuple of int
        Owning thread (0-63) and element slot within that thread's vector.

    Layout (threads per 4-row group, one per column):
          0   1   2  ...  15
    row 0-3   T0  T1  T2 ... T15
    row 4-7   T16 T17 T18 ... T31
    row 8-11  T32 T33 T34 ... T47
    row 12-15 T48 T49 T50 ... T63
    """
    # BUG FIX: the original computed thread_id = 16 * (i // 4) only, ignoring
    # the column — every thread in a row group would alias, contradicting the
    # diagram above and making the map non-invertible. Add the column term.
    thread_id = (i // 4) * 16 + (j % 16)
    # Columns beyond 15 and the row-within-group select the element slot.
    local_id = (j // 16) * 4 + (i % 4)
    return thread_id, local_id
def thread_id_shared_access_64x8_to_32x16_layout_B(tx, local_id):
    """Inverse of the B-operand layout: (thread_id, local_id) -> (i, j).

    Parameters
    ----------
    tx : int
        Thread id within the warp (0-63); 16 threads per 4-row group.
    local_id : int
        Element index within the thread's vector (0-7).

    Returns
    -------
    (i, j) : tuple of int
        Row (groups of 4 per 16 threads) and column (0-31, wrapping every 16).
    """
    # BUG FIX: the original divided/modded by 15 (`tx // 15`, `tx % 15`),
    # which cannot invert a layout with 16 threads per row group — e.g.
    # tx=15 would land in the wrong group. Use 16, matching the forward map
    # thread_id = (i // 4) * 16 + (j % 16), local_id = (j // 16) * 4 + (i % 4).
    i = (tx // 16) * 4 + local_id % 4
    j = (tx % 16) + (local_id // 4) * 16
    return i, j
# Public aliases: the (m, n)/(n, k) orderings of operand A and the
# (n, m)/(k, n) orderings of operand B each share one physical thread
# layout, so both names bind to the same implementation.
shared_16x16_to_local_64x4_layout_m_n = shared_16x16_to_local_64x4_layout_A
shared_16x16_to_local_64x4_layout_n_k = shared_16x16_to_local_64x4_layout_A
shared_32x16_to_local_64x8_layout_n_m = shared_32x16_to_local_64x8_layout_B
shared_32x16_to_local_64x8_layout_k_n = shared_32x16_to_local_64x8_layout_B
\ No newline at end of file
......@@ -30,7 +30,10 @@ from .mfma_layout import (
from .mmac_layout import (
shared_16x16_to_local_64x4_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_A,
shared_to_local_layout_A_padding
shared_to_local_layout_A_padding,
shared_32x16_to_local_64x8_layout_B,
thread_id_shared_access_64x8_to_32x16_layout_B,
shared_to_local_layout_B_padding
)
lift = convert
......@@ -295,6 +298,7 @@ class MatrixCoreIntrinEmitter:
warp_inter_stride = 4
# warp_rows 轮次 warp需要拆分成多轮来访问完整的行块
if is_transposed:
raise NotImplementedError("Transposed A is not implemented yet")
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
......@@ -304,7 +308,6 @@ class MatrixCoreIntrinEmitter:
# row += warp_group_idx * 4
# # warp 内行间隔
# row += warp_interval_idx * warp_row_interval
raise NotImplementedError("Transposed A with preshuffle is not implemented yet")
row, col = shared_to_local_layout_A_padding(row, col, i, warp_rows, self.block_row_warps, tx)
l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
......@@ -330,6 +333,7 @@ class MatrixCoreIntrinEmitter:
warp_cols = self.warp_cols
chunk = self.chunk
micro_size_y = self.micro_size_y
micro_size_k = self.micro_size_k
local_size_b = self.local_size_b
k_pack = self.k_pack
......@@ -351,8 +355,10 @@ class MatrixCoreIntrinEmitter:
thread_binding,
rk=0,
):
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
warp_inter_stride = 32
if is_transposed:
raise NotImplementedError("Transposed B is not implemented yet")
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id))
......@@ -366,9 +372,11 @@ class MatrixCoreIntrinEmitter:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id))
row, col = shared_to_local_layout_B_padding(row, col, j, warp_cols, self.block_col_warps, tx)
l, r = (
rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y,
# warp_n * warp_col_tiles + j * micro_size_y,
warp_n * warp_inter_stride,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment