Commit 64179eaf authored by qisan

Merge remote ds_read: resolve conflicts in phase.py, mmac_macro_generator.py, gemm_mmac.py

parents dd91b1e0 ad92620d
......@@ -236,7 +236,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
# Transform B_local layout from shared memory thread-interleaved to local row-major
mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
# mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
mod = tilelang.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
......@@ -282,9 +282,13 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# Inject ds_read for shared to register memory copy on DCU
# mod = tilelang.transform.InjectDSRead()(mod)
print("222222222")
print(mod)
if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
mod = tilelang.transform.MakePackedAPI()(mod)
mod = tilelang.transform.Simplify()(mod)
mod = tilelang.transform.LowerDeviceKernelLaunch()(mod)
......
from tvm import DataType
from tvm.runtime import convert
# from tvm import DataType
# from tvm.runtime import convert
import tilelang.language as T
......@@ -21,6 +21,44 @@ import tilelang.language as T
3 19 35 51
'''
# Remap an A-fragment (row, col) for iteration idx into the padded
# shared-to-local layout used on DCU; tx is the lane id within the warp.
def shared_to_local_layout_A_padding(row, col, idx, warp_rows, block_row_warps, tx):
new_col = col
if warp_rows > 1:
inter_idx_padding = 2
else:
inter_idx_padding = 1
paddings = inter_idx_padding * block_row_warps * 4
# print("paddings:", paddings)
new_row = row + paddings * ((tx & 15) // 4)
new_row += (idx & 1) * (paddings // 2) + (idx // 2) * 16 * 2 * block_row_warps
return new_row, new_col
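# Hypothetical sanity check (illustrative values, not from the original commit),
# assuming the helper above: with warp_rows=1 and block_row_warps=1 we get
# paddings = 1 * 1 * 4 = 4, so thread tx=5 ((tx & 15) // 4 == 1) has its idx=0
# element pushed down by one padding group.
def _check_A_padding_example():
    assert shared_to_local_layout_A_padding(0, 0, 0, 1, 1, 5) == (4, 0)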
# Remap a B-fragment (row, col) for iteration idx into the padded
# shared-to-local layout; tx is the lane id within the warp.
def shared_to_local_layout_B_padding(row, col, idx, warp_cols, block_col_warps, tx):
# Column remap: each group of 4 threads (tx % 4) owns an 8-wide column slab;
# the source half (col // 16) selects a 4-wide sub-slab and row % 4 picks the
# element within it.
new_col = ((tx % 4) * 8) + (col // 16) * 4 + (row % 4)
# new_col = ((col % 16 * 8 ) // 32 * 8) + (col // 16) * 4 + (row // 4)
# Swap the two low bits of tx // 4: within each group of 4 rows, rows 1 and 2 trade places
row_tmp = tx // 4
bit0 = row_tmp & 1
bit1 = (row_tmp >> 1) & 1
new_row = ((row_tmp // 4) << 2) + (bit0 * 2) + bit1
# return new_row, new_col
if block_col_warps > 1:
inter_idx_padding = 2
else:
inter_idx_padding = 1
paddings = 32
new_col += idx * inter_idx_padding * paddings
return new_row, new_col
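# Hypothetical illustration (values chosen for clarity, assuming the helper
# above) of the low-bit swap in tx // 4: tx=4 gives row_tmp=1 -> new_row=2,
# while tx=8 gives row_tmp=2 -> new_row=1, i.e. rows 1 and 2 trade places
# inside every group of four rows.
def _check_B_padding_example():
    assert shared_to_local_layout_B_padding(0, 0, 1, 1, 1, 4)[0] == 2
    assert shared_to_local_layout_B_padding(0, 0, 1, 1, 1, 8)[0] == 1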
def shared_16x16_to_local_64x4_layout_A(i, j):
# i: row (0-15), j: column (0-15)
# thread_id: which thread handles this position
......@@ -53,5 +91,108 @@ def thread_id_shared_access_64x4_to_16x16_layout_A(tx, local_id):
j = (tx // 16) * 4 + local_id # 0-15: column position
return i, j
def shared_32x16_to_local_64x8_layout_B(i, j):
# i: row (0-15), j: col (0-31)
# Thread layout (pattern repeats for cols 16-31, since thread_id uses j % 16):
# 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# row=0-3 T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15
# row=4-7 T16 T17 T18 T19 T20 T21 T22 T23 T24 T25 T26 T27 T28 T29 T30 T31
# row=8-11 T32 T33 T34 T35 T36 T37 T38 T39 T40 T41 T42 T43 T44 T45 T46 T47
# row=12-15 T48 T49 T50 T51 T52 T53 T54 T55 T56 T57 T58 T59 T60 T61 T62 T63
thread_id = 16 * (i // 4) + j % 16
local_id = (j // 16) * 4 + (i % 4)
# print("T{:^2} V{:^2} ".format(thread_id,local_id),end=" ")
return thread_id, local_id
def thread_id_shared_access_64x8_to_32x16_layout_B(tx, local_id):
# tx: thread id within warp (0-63)
# local_id: element index within the thread's 8-element vector (0-7)
# returns i: row (0-15), j: col (0-31)
i = (tx // 16) * 4 + local_id % 4
j = (tx % 16) + (local_id // 4) * 16
return i, j
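# Hypothetical round-trip check (assuming the two mappings above): the
# inverse must reproduce every (i, j) of the 16-row x 32-col B tile.
def _check_B_layout_round_trip():
    for i in range(16):
        for j in range(32):
            tx, local_id = shared_32x16_to_local_64x8_layout_B(i, j)
            assert thread_id_shared_access_64x8_to_32x16_layout_B(tx, local_id) == (i, j)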
shared_16x16_to_local_64x4_layout_m_n = shared_16x16_to_local_64x4_layout_A
shared_16x16_to_local_64x4_layout_n_k = shared_16x16_to_local_64x4_layout_A
shared_32x16_to_local_64x8_layout_n_m = shared_32x16_to_local_64x8_layout_B
shared_32x16_to_local_64x8_layout_k_n = shared_32x16_to_local_64x8_layout_B
if __name__ == "__main__":
idx = 1
warp_cols = 1
block_col_warps = 1
row = 16
col = 32
# print(" ",end=" ")
# for i in range(col):
# print("{:^8}".format(i),end=" ")
# print()
# for i in range(row):
# print("{:^2}".format(i),end=" ")
# for j in range(col):
# thread_id,local_id = shared_32x16_to_local_64x8_layout_B(i,j)
# print("T{:^2} V{:^2} ".format(thread_id,local_id),end=" ")
# print()
# print(" ",end=" ")
# for i in range(col):
# print("{:^8}".format(i),end=" ")
# print()
# for j in range(row):
# local_idx = j % 4
# print("{:^2}".format(j),end=" ")
# for i in range(col):
# tx = (j // 4) * 16 + (i % 16)
# local_id = (i // 16) * 4 + local_idx
# row_idx,col_idx = thread_id_shared_access_64x8_to_32x16_layout_B(tx, local_id)
# print("T{:^2} V{:^2} ".format(tx,local_id),end=" ")
# print()
# print(" ",end=" ")
# for i in range(col):
# print("{:^8}".format(i),end=" ")
# print()
col_ = 16
group = 4
results = []
for k in range(group):
for j in range(col_):
tx = j + k * 16
for i in range(k * 4,(k+1) * 4):
new_row,new_col = shared_to_local_layout_B_padding(i,j,idx,warp_cols,block_col_warps,tx)
results.append((new_row,new_col,(tx,i-k*4)))
# print("({:^2},{:^2}) {:^2}".format(new_row,new_col,tx))
for i in range(k * 4,(k+1) * 4):
new_row,new_col = shared_to_local_layout_B_padding(i,j+16,idx,warp_cols,block_col_warps,tx)
results.append((new_row,new_col,(tx,i-k*4+4)))
# print("({:^2},{:^2}) {:^2}".format(new_row,new_col,tx))
# Convert results into a nested dict: grid[row][col] = (tx, local_id)
grid = {}
for r, c, t in results:
if r not in grid:
grid[r] = {}
grid[r][c] = t
print("layout B padding ...")
print(" ",end=" ")
col = 32
for i in range(col):
print("{:^8}".format(i),end=" ")
print()
# Determine the set of rows to traverse
rows = sorted(grid.keys())
# Walk each row in order, then the columns present in that row
for r in rows:
print("{:^2}".format(r),end=" ")
# Collect and sort the column indices present in this row
cols = sorted(grid[r].keys())
for c in cols:
tx, lid = grid[r][c]
print("T{:^2} V{:^2} ".format(tx, lid), end=" ")
print()
......@@ -30,6 +30,10 @@ from .mfma_layout import (
from .mmac_layout import (
shared_16x16_to_local_64x4_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_A,
shared_to_local_layout_A_padding,
shared_32x16_to_local_64x8_layout_B,
thread_id_shared_access_64x8_to_32x16_layout_B,
shared_to_local_layout_B_padding
)
lift = convert
......@@ -250,20 +254,6 @@ class MatrixCoreIntrinEmitter:
(thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
)
return lane_id, warp_n, warp_m
def map_64x16(self, row, col, idx, warp_rows, tx):
new_col = col
if warp_rows > 1:
inter_idx_padding = 2
else:
inter_idx_padding = 1
paddings = inter_idx_padding * self.block_row_warps * 4
print("paddings:", paddings)
new_row = row + paddings * ((tx & 15) // 4)
new_row += (idx & 1) * (paddings // 2) + (idx // 2) * 16 * 2 * self.block_row_warps
return new_row, new_col
def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0):
......@@ -305,8 +295,10 @@ class MatrixCoreIntrinEmitter:
# {12..15,28..31,44..47,60..63} -> 3
warp_interval_idx = (tx & 15)>>2
warp_group_idx = warp_n
warp_inter_stride = 4
# warp_rows passes: the warp is split into multiple rounds to cover the full row block
if is_transposed:
raise NotImplementedError("Transposed A is not implemented yet")
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
......@@ -316,8 +308,7 @@ class MatrixCoreIntrinEmitter:
# row += warp_group_idx * 4
# # row stride within the warp
# row += warp_interval_idx * warp_row_interval
raise NotImplementedError("Transposed A with preshuffle is not implemented yet")
row, col = self.map_64x16(row, col, i, warp_rows, tx)
row, col = shared_to_local_layout_A_padding(row, col, i, warp_rows, self.block_row_warps, tx)
l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
else:
......@@ -330,9 +321,9 @@ class MatrixCoreIntrinEmitter:
# row += warp_group_idx * 4
# # row stride within the warp
# row += warp_interval_idx * warp_row_interval
row, col = self.map_64x16(row, col, i, warp_rows, tx)
row, col = shared_to_local_layout_A_padding(row, col, i, warp_rows, self.block_row_warps, tx)
print("row, col:", row, col)
l, r = (warp_m * 4, rk * chunk + ki * (k_pack * micro_size_k))
l, r = (warp_m * warp_inter_stride, rk * chunk + ki * (k_pack * micro_size_k))
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col]
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
......@@ -342,6 +333,7 @@ class MatrixCoreIntrinEmitter:
warp_cols = self.warp_cols
chunk = self.chunk
micro_size_y = self.micro_size_y
micro_size_k = self.micro_size_k
local_size_b = self.local_size_b
k_pack = self.k_pack
......@@ -363,8 +355,10 @@ class MatrixCoreIntrinEmitter:
thread_binding,
rk=0,
):
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
warp_inter_stride = 32
if is_transposed:
raise NotImplementedError("Transposed B is not implemented yet")
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id))
......@@ -378,9 +372,11 @@ class MatrixCoreIntrinEmitter:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id))
row, col = shared_to_local_layout_B_padding(row, col, j, warp_cols, self.block_col_warps, tx)
l, r = (
rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y,
# warp_n * warp_col_tiles + j * micro_size_y,
warp_n * warp_inter_stride,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col]
......
......@@ -32,8 +32,6 @@ class GemmMMAC(GemmBase):
if self.is_gemm_ss():
return {
# self.A: make_swizzled_layout(self.A, allow_pad=False),
# self.B: make_swizzled_layout(self.B, allow_pad=False),
self.A: make_linear_layout(self.A),
self.B: make_linear_layout(self.B),
self.C: mmac_emitter.make_mmac_store_layout(self.C),
......