Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
8adfc117
Commit
8adfc117
authored
May 08, 2025
by
Yuxi Chi
Committed by
LeiWang1999
May 08, 2025
Browse files
[Bugfix] Fix get_swizzle_layout implementation. (#455)
* fix get_swizzle_layout implementation. * format.
parent
0aaef97d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
51 deletions
+33
-51
tilelang/intrinsics/mma_layout.py
tilelang/intrinsics/mma_layout.py
+33
-51
No files found.
tilelang/intrinsics/mma_layout.py
View file @
8adfc117
...
@@ -89,25 +89,17 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j):
...
@@ -89,25 +89,17 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j):
return
(
i
*
2
+
j
//
16
,
j
%
16
)
return
(
i
*
2
+
j
//
16
,
j
%
16
)
def
get_swizzle_layout
(
row_idx
,
col_idx
,
row_size
,
dtype
:
Union
[
DataType
,
str
]):
def
get_swizzle_layout
(
row_idx
,
col_idx
,
row_size
,
dtype
:
Union
[
DataType
,
str
]
,
swizzle_bytes
=
None
):
ana
=
arith
.
Analyzer
()
ana
=
arith
.
Analyzer
()
BANK_SIZE_BYTES
=
128
if
isinstance
(
dtype
,
str
):
if
isinstance
(
dtype
,
str
):
dtype
=
DataType
(
dtype
)
dtype
=
DataType
(
dtype
)
col_idx_outer
,
col_idx_inner
=
col_idx
//
(
BANK_SIZE_BYTES
//
dtype
.
bits
),
col_idx
%
(
row_bytes
=
dtype
.
bits
*
row_size
//
8
BANK_SIZE_BYTES
//
dtype
.
bits
)
assert
row_bytes
%
32
==
0
,
"Row size must be multiple of 32B."
# use transaction bits to support diverse dtype.
if
swizzle_bytes
is
None
:
# for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 32 bits = 512 bits
swizzle_bytes
=
min
(
128
,
row_bytes
)
# for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits
# 128B swizzle
coalescent_bits
=
dtype
.
bits
*
row_size
# permutation on 4 banks, each bank has 32 bits
bank_elems
=
BANK_SIZE_BYTES
//
dtype
.
bits
new_col_idx_outer
=
None
if
coalescent_bits
%
1024
==
0
:
# Use 8 * 8 permuted layout
# Use 8 * 8 permuted layout
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
# Every number below corresponds to 16B
# Every row below corresponds to 32 banks
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
...
@@ -116,33 +108,23 @@ def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
...
@@ -116,33 +108,23 @@ def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
# 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
# 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
# 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
# 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
# 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
# 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
row_idx_sub
=
row_idx
%
bank_elems
# 64B swizzle
new_col_idx_outer
=
col_idx_outer
^
row_idx_sub
else
:
assert
coalescent_bits
%
512
==
0
# Use 8 * 4 permuted layout
# Use 8 * 4 permuted layout
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
# Every number below corresponds to 16B
# Every row below corresponds to 16 banks
# 0 1 2 3 ==> 0 1 2 3
# 0 1 2 3 ==> 0 1 2 3
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 3 2 1 0
# 0 1 2 3 ==> 3 2 1 0
# View with 8 elements per row:
# 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3
# 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3
# 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
# 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
# 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
# 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
# 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
# 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
row_idx_sub
=
row_idx
%
bank_elems
# 32B swizzle
# Interleave elems per byte
# Use 8 * 2 permuted layout
interleave_elems
=
32
//
dtype
.
bits
# Every number below corresponds to 16B
new_col_idx_outer
=
col_idx_outer
^
(
row_idx_sub
//
interleave_elems
)
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
assert
(
new_col_idx_outer
is
not
None
),
f
"Unsupported dtype
{
dtype
}
with
{
coalescent_bits
}
bits"
elem_per_16B
=
128
//
dtype
.
bits
return
row_idx
,
ana
.
simplify
(
new_col_idx_outer
*
bank_elems
+
col_idx_inner
)
col_idx_16B
=
col_idx
//
elem_per_16B
col_idx_in_16B
=
col_idx
%
elem_per_16B
new_col_idx_16B
=
col_idx_16B
^
(
row_idx
%
(
swizzle_bytes
//
16
))
return
row_idx
,
ana
.
simplify
(
new_col_idx_16B
*
elem_per_16B
+
col_idx_in_16B
)
def
make_mma_swizzle_layout
(
shared_buf
,
is_smooth
:
bool
=
False
):
def
make_mma_swizzle_layout
(
shared_buf
,
is_smooth
:
bool
=
False
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment