Commit 8adfc117 authored by Yuxi Chi's avatar Yuxi Chi Committed by LeiWang1999
Browse files

[Bugfix] Fix get_swizzle_layout implementation. (#455)

* fix get_swizzle_layout implementation.

* format.
parent 0aaef97d
...@@ -89,25 +89,17 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j): ...@@ -89,25 +89,17 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j):
return (i * 2 + j // 16, j % 16) return (i * 2 + j // 16, j % 16)
def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]): def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str], swizzle_bytes=None):
ana = arith.Analyzer() ana = arith.Analyzer()
BANK_SIZE_BYTES = 128
if isinstance(dtype, str): if isinstance(dtype, str):
dtype = DataType(dtype) dtype = DataType(dtype)
col_idx_outer, col_idx_inner = col_idx // (BANK_SIZE_BYTES // dtype.bits), col_idx % ( row_bytes = dtype.bits * row_size // 8
BANK_SIZE_BYTES // dtype.bits) assert row_bytes % 32 == 0, "Row size must be multiple of 32B."
# use transaction bits to support diverse dtype. if swizzle_bytes is None:
# for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 32 bits = 512 bits swizzle_bytes = min(128, row_bytes)
# for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits # 128B swizzle
coalescent_bits = dtype.bits * row_size
# permutation on 4 banks, each bank has 32 bits
bank_elems = BANK_SIZE_BYTES // dtype.bits
new_col_idx_outer = None
if coalescent_bits % 1024 == 0:
# Use 8 * 8 permuted layout # Use 8 * 8 permuted layout
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read # Every number below corresponds to 16B
# Every row below corresponds to 32 banks
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7 # 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6 # 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5 # 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
...@@ -116,33 +108,23 @@ def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]): ...@@ -116,33 +108,23 @@ def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
# 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2 # 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
# 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1 # 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
# 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0 # 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
row_idx_sub = row_idx % bank_elems # 64B swizzle
new_col_idx_outer = col_idx_outer ^ row_idx_sub
else:
assert coalescent_bits % 512 == 0
# Use 8 * 4 permuted layout # Use 8 * 4 permuted layout
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read # Every number below corresponds to 16B
# Every row below corresponds to 16 banks
# 0 1 2 3 ==> 0 1 2 3
# 0 1 2 3 ==> 0 1 2 3
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 3 2 1 0
# 0 1 2 3 ==> 3 2 1 0
# View with 8 elements per row:
# 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3 # 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3
# 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2 # 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
# 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1 # 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
# 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0 # 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
row_idx_sub = row_idx % bank_elems # 32B swizzle
# Interleave elems per byte # Use 8 * 2 permuted layout
interleave_elems = 32 // dtype.bits # Every number below corresponds to 16B
new_col_idx_outer = col_idx_outer ^ (row_idx_sub // interleave_elems) # 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
assert (new_col_idx_outer is not None), f"Unsupported dtype {dtype} with {coalescent_bits} bits" elem_per_16B = 128 // dtype.bits
return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner) col_idx_16B = col_idx // elem_per_16B
col_idx_in_16B = col_idx % elem_per_16B
new_col_idx_16B = col_idx_16B ^ (row_idx % (swizzle_bytes // 16))
return row_idx, ana.simplify(new_col_idx_16B * elem_per_16B + col_idx_in_16B)
def make_mma_swizzle_layout(shared_buf, is_smooth: bool = False): def make_mma_swizzle_layout(shared_buf, is_smooth: bool = False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment