Commit 8adfc117 authored by Yuxi Chi's avatar Yuxi Chi Committed by LeiWang1999
Browse files

[Bugfix] Fix get_swizzle_layout implementation. (#455)

* fix get_swizzle_layout implementation.

* format.
parent 0aaef97d
...@@ -89,60 +89,42 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j): ...@@ -89,60 +89,42 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j):
return (i * 2 + j // 16, j % 16) return (i * 2 + j // 16, j % 16)
def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]): def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str], swizzle_bytes=None):
ana = arith.Analyzer() ana = arith.Analyzer()
BANK_SIZE_BYTES = 128
if isinstance(dtype, str): if isinstance(dtype, str):
dtype = DataType(dtype) dtype = DataType(dtype)
col_idx_outer, col_idx_inner = col_idx // (BANK_SIZE_BYTES // dtype.bits), col_idx % ( row_bytes = dtype.bits * row_size // 8
BANK_SIZE_BYTES // dtype.bits) assert row_bytes % 32 == 0, "Row size must be multiple of 32B."
# use transaction bits to support diverse dtype. if swizzle_bytes is None:
# for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 32 bits = 512 bits swizzle_bytes = min(128, row_bytes)
# for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits # 128B swizzle
coalescent_bits = dtype.bits * row_size # Use 8 * 8 permuted layout
# permutation on 4 banks, each bank has 32 bits # Every number below corresponds to 16B
bank_elems = BANK_SIZE_BYTES // dtype.bits # 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
new_col_idx_outer = None # 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
if coalescent_bits % 1024 == 0: # 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4
# Use 8 * 8 permuted layout # 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read # 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
# Every row below corresponds to 32 banks # 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7 # 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6 # 64B swizzle
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5 # Use 8 * 4 permuted layout
# 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4 # Every number below corresponds to 16B
# 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3 # 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3
# 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2 # 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
# 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1 # 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
# 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0 # 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
row_idx_sub = row_idx % bank_elems # 32B swizzle
new_col_idx_outer = col_idx_outer ^ row_idx_sub # Use 8 * 2 permuted layout
else: # Every number below corresponds to 16B
assert coalescent_bits % 512 == 0 # 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# Use 8 * 4 permuted layout # 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read elem_per_16B = 128 // dtype.bits
# Every row below corresponds to 16 banks col_idx_16B = col_idx // elem_per_16B
# 0 1 2 3 ==> 0 1 2 3 col_idx_in_16B = col_idx % elem_per_16B
# 0 1 2 3 ==> 0 1 2 3 new_col_idx_16B = col_idx_16B ^ (row_idx % (swizzle_bytes // 16))
# 0 1 2 3 ==> 1 0 3 2 return row_idx, ana.simplify(new_col_idx_16B * elem_per_16B + col_idx_in_16B)
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 3 2 1 0
# 0 1 2 3 ==> 3 2 1 0
# View with 8 elements per row:
# 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3
# 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
# 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
# 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
row_idx_sub = row_idx % bank_elems
# Interleave elems per byte
interleave_elems = 32 // dtype.bits
new_col_idx_outer = col_idx_outer ^ (row_idx_sub // interleave_elems)
assert (new_col_idx_outer is not None), f"Unsupported dtype {dtype} with {coalescent_bits} bits"
return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner)
def make_mma_swizzle_layout(shared_buf, is_smooth: bool = False): def make_mma_swizzle_layout(shared_buf, is_smooth: bool = False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment