Commit ad92620d authored by wangziyang's avatar wangziyang
Browse files

Print the B warp layout and the B shared-to-local (s_to_l) padding layout

parent b8213492
from tvm import DataType # from tvm import DataType
from tvm.runtime import convert # from tvm.runtime import convert
import tilelang.language as T import tilelang.language as T
...@@ -29,7 +29,7 @@ def shared_to_local_layout_A_padding(row, col, idx, warp_rows, block_row_warps, ...@@ -29,7 +29,7 @@ def shared_to_local_layout_A_padding(row, col, idx, warp_rows, block_row_warps,
inter_idx_padding = 1 inter_idx_padding = 1
paddings = inter_idx_padding * block_row_warps * 4 paddings = inter_idx_padding * block_row_warps * 4
print("paddings:", paddings) # print("paddings:", paddings)
new_row = row + paddings * ((tx & 15) // 4) new_row = row + paddings * ((tx & 15) // 4)
new_row += (idx & 1) * (paddings // 2) + (idx // 2) * 16 * 2 * block_row_warps new_row += (idx & 1) * (paddings // 2) + (idx // 2) * 16 * 2 * block_row_warps
...@@ -39,15 +39,24 @@ def shared_to_local_layout_B_padding(row, col, idx, warp_cols, block_col_warps, ...@@ -39,15 +39,24 @@ def shared_to_local_layout_B_padding(row, col, idx, warp_cols, block_col_warps,
# ??? # ???
# T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T0 T1 ... # T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15 T0 T1 ...
# T0 T1 T2 # T0 T1 T2
new_col = ((col % 16 * 8 ) // 32 * 8) + (col // 16) * 4 + (row // 4) new_col = ((tx % 4) * 8) + (col // 16) * 4 + (row % 4)
# new_col = ((col % 16 * 8 ) // 32 * 8) + (col // 16) * 4 + (row // 4)
# swap per 4 row idx1 <-> idx2 # swap per 4 row idx1 <-> idx2
bit0 = (tx // 4) & 1 row_tmp = tx // 4
bit1 = (tx // 2) & 1 bit0 = row_tmp & 1
new_row = (tx // 4) & (~3) + (bit0 * 2) + bit1 bit1 = (row_tmp >> 1) & 1
new_row = ((row_tmp // 4) << 2) + (bit0 * 2) + bit1
# return new_row, new_col
if block_col_warps > 1:
inter_idx_padding = 2
else:
inter_idx_padding = 1
paddings = 32 paddings = 32
new_col += (idx & 1) * paddings new_col += idx * inter_idx_padding * paddings
return new_row, new_col return new_row, new_col
def shared_16x16_to_local_64x4_layout_A(i, j): def shared_16x16_to_local_64x4_layout_A(i, j):
...@@ -90,8 +99,9 @@ def shared_32x16_to_local_64x8_layout_B(i, j): ...@@ -90,8 +99,9 @@ def shared_32x16_to_local_64x8_layout_B(i, j):
# row=4-7 T16 T17 T18 T19 T20 T21 T22 T23 T24 T25 T26 T27 T28 T29 T30 T31 # row=4-7 T16 T17 T18 T19 T20 T21 T22 T23 T24 T25 T26 T27 T28 T29 T30 T31
# row=8-11 T32 T33 T34 T35 T36 T37 T38 T39 T40 T41 T42 T43 T44 T45 T46 T47 # row=8-11 T32 T33 T34 T35 T36 T37 T38 T39 T40 T41 T42 T43 T44 T45 T46 T47
# row=12-15 T48 T49 T50 T51 T52 T53 T54 T55 T56 T57 T58 T59 T60 T61 T62 T63 # row=12-15 T48 T49 T50 T51 T52 T53 T54 T55 T56 T57 T58 T59 T60 T61 T62 T63
thread_id = 16 * (i // 4) thread_id = 16 * (i // 4) + j % 16
local_id = (j // 16) * 4 + (i % 4) local_id = (j // 16) * 4 + (i % 4)
# print("T{:^2} V{:^2} ".format(thread_id,local_id),end=" ")
return thread_id, local_id return thread_id, local_id
...@@ -106,4 +116,83 @@ def thread_id_shared_access_64x8_to_32x16_layout_B(tx, local_id): ...@@ -106,4 +116,83 @@ def thread_id_shared_access_64x8_to_32x16_layout_B(tx, local_id):
shared_16x16_to_local_64x4_layout_m_n = shared_16x16_to_local_64x4_layout_A shared_16x16_to_local_64x4_layout_m_n = shared_16x16_to_local_64x4_layout_A
shared_16x16_to_local_64x4_layout_n_k = shared_16x16_to_local_64x4_layout_A shared_16x16_to_local_64x4_layout_n_k = shared_16x16_to_local_64x4_layout_A
shared_32x16_to_local_64x8_layout_n_m = shared_32x16_to_local_64x8_layout_B shared_32x16_to_local_64x8_layout_n_m = shared_32x16_to_local_64x8_layout_B
shared_32x16_to_local_64x8_layout_k_n = shared_32x16_to_local_64x8_layout_B shared_32x16_to_local_64x8_layout_k_n = shared_32x16_to_local_64x8_layout_B
\ No newline at end of file
if __name__ == "__main__":
    # Debug driver: tabulate where shared_to_local_layout_B_padding places
    # every (thread, local value) pair and print the result as a grid.
    #
    # Each printed cell is "T<thread> V<local index>"; cells are ordered by the
    # (new_row, new_col) coordinates returned by the layout function.
    idx = 1
    warp_cols = 1
    block_col_warps = 1

    group = 4             # number of 4-row groups swept below
    lanes_per_group = 16  # thread ids covered by each group
    header_cols = 32      # number of column labels printed in the header

    # (column offset, local-index offset) for the two 16-column halves that
    # each thread is evaluated on: columns j and j + 16.
    halves = ((0, 0), (16, 4))

    results = []
    for k in range(group):
        for j in range(lanes_per_group):
            tx = j + k * 16
            for col_offset, local_offset in halves:
                # Rows k*4 .. k*4+3 belong to group k; the row's position
                # inside the group is the thread's local value index.
                for i in range(k * 4, (k + 1) * 4):
                    new_row, new_col = shared_to_local_layout_B_padding(
                        i, j + col_offset, idx, warp_cols, block_col_warps, tx)
                    results.append(
                        (new_row, new_col, (tx, i - k * 4 + local_offset)))

    # Re-index as grid[row][col] -> (thread, local index). A later entry with
    # the same coordinates overwrites an earlier one.
    grid = {}
    for r, c, owner in results:
        grid.setdefault(r, {})[c] = owner

    print("layout B padding ...")
    print(" ", end=" ")
    # Header labels are positional (0 .. header_cols - 1); they are not the
    # new_col values themselves.
    for i in range(header_cols):
        print("{:^8}".format(i), end=" ")
    print()

    # Walk rows, then columns, both in ascending coordinate order.
    for r in sorted(grid):
        print("{:^2}".format(r), end=" ")
        for c in sorted(grid[r]):
            thread, local_idx = grid[r][c]
            print("T{:^2} V{:^2} ".format(thread, local_idx), end=" ")
        print()
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment