Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
8adfc117
Commit
8adfc117
authored
May 08, 2025
by
Yuxi Chi
Committed by
LeiWang1999
May 08, 2025
Browse files
[Bugfix] Fix get_swizzle_layout implementation. (#455)
* fix get_swizzle_layout implementation. * format.
parent
0aaef97d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
51 deletions
+33
-51
tilelang/intrinsics/mma_layout.py
tilelang/intrinsics/mma_layout.py
+33
-51
No files found.
tilelang/intrinsics/mma_layout.py
View file @
8adfc117
...
...
@@ -89,60 +89,42 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j):
return
(
i
*
2
+
j
//
16
,
j
%
16
)
def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str], swizzle_bytes=None):
    """Map a (row, column) element index to its XOR-swizzled position.

    The column index is split into 16-byte chunks. The chunk index is XOR-ed
    with ``row_idx % (swizzle_bytes // 16)`` while the offset inside the chunk
    is left untouched, so the row index itself never changes.

    Args:
        row_idx: Row index of the element (returned unchanged).
        col_idx: Column index of the element, counted in elements.
        row_size: Number of elements in one row.
        dtype: Element type, either a ``DataType`` or a string accepted by the
            ``DataType`` constructor. Only its ``bits`` attribute is read.
        swizzle_bytes: Width in bytes of the span inside which chunks are
            permuted. When ``None`` the widest of 128, 64 and 32 bytes that
            divides the row evenly is used. An explicitly supplied value is
            not validated here; it should be a power-of-two multiple of 16
            that divides the row width, otherwise the XOR may yield a column
            outside the row.

    Returns:
        A ``(row_idx, new_col_idx)`` tuple, where ``new_col_idx`` has been
        passed through ``arith.Analyzer().simplify``.

    Raises:
        AssertionError: If the row width in bytes is not a multiple of 32.
    """
    ana = arith.Analyzer()
    if isinstance(dtype, str):
        dtype = DataType(dtype)

    row_bytes = dtype.bits * row_size // 8
    assert row_bytes % 32 == 0, "Row size must be multiple of 32B."

    if swizzle_bytes is None:
        # Choose the widest span that tiles the row exactly. Taking
        # min(128, row_bytes) is not enough: a 96B row would get a 6-chunk span
        # and a 192B row a 128B span, and in both cases the XOR below can
        # produce a chunk index past the end of the row.
        swizzle_bytes = next((span for span in (128, 64, 32) if row_bytes % span == 0), 32)

    # 128B swizzle
    #   Use 8 * 8 permuted layout
    #   Every number below corresponds to 16B
    #   0 1 2 3 4 5 6 7    ==>    0 1 2 3 4 5 6 7
    #   0 1 2 3 4 5 6 7    ==>    1 0 3 2 5 4 7 6
    #   0 1 2 3 4 5 6 7    ==>    2 3 0 1 6 7 4 5
    #   0 1 2 3 4 5 6 7    ==>    3 2 1 0 7 6 5 4
    #   0 1 2 3 4 5 6 7    ==>    4 5 6 7 0 1 2 3
    #   0 1 2 3 4 5 6 7    ==>    5 4 7 6 1 0 3 2
    #   0 1 2 3 4 5 6 7    ==>    6 7 4 5 2 3 0 1
    #   0 1 2 3 4 5 6 7    ==>    7 6 5 4 3 2 1 0
    # 64B swizzle
    #   Use 8 * 4 permuted layout
    #   Every number below corresponds to 16B
    #   0 1 2 3 0 1 2 3    ==>    0 1 2 3 0 1 2 3
    #   0 1 2 3 0 1 2 3    ==>    1 0 3 2 1 0 3 2
    #   0 1 2 3 0 1 2 3    ==>    2 3 0 1 2 3 0 1
    #   0 1 2 3 0 1 2 3    ==>    3 2 1 0 3 2 1 0
    # 32B swizzle
    #   Use 8 * 2 permuted layout
    #   Every number below corresponds to 16B
    #   0 1 2 3 4 5 6 7    ==>    0 1 2 3 4 5 6 7
    #   0 1 2 3 4 5 6 7    ==>    1 0 3 2 5 4 7 6
    elem_per_16B = 128 // dtype.bits  # 16 bytes == 128 bits
    col_idx_16B = col_idx // elem_per_16B
    col_idx_in_16B = col_idx % elem_per_16B
    # Only the chunk index is permuted; the position inside a chunk is kept.
    new_col_idx_16B = col_idx_16B ^ (row_idx % (swizzle_bytes // 16))
    return row_idx, ana.simplify(new_col_idx_16B * elem_per_16B + col_idx_in_16B)
def
make_mma_swizzle_layout
(
shared_buf
,
is_smooth
:
bool
=
False
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment