Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
8adfc117
Commit
8adfc117
authored
May 08, 2025
by
Yuxi Chi
Committed by
LeiWang1999
May 08, 2025
Browse files
[Bugfix] Fix get_swizzle_layout implementation. (#455)
* fix get_swizzle_layout implementation. * format.
parent
0aaef97d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
51 deletions
+33
-51
tilelang/intrinsics/mma_layout.py
tilelang/intrinsics/mma_layout.py
+33
-51
No files found.
tilelang/intrinsics/mma_layout.py
View file @
8adfc117
...
@@ -89,25 +89,17 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j):
...
@@ -89,25 +89,17 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j):
return
(
i
*
2
+
j
//
16
,
j
%
16
)
return
(
i
*
2
+
j
//
16
,
j
%
16
)
def
get_swizzle_layout
(
row_idx
,
col_idx
,
row_size
,
dtype
:
Union
[
DataType
,
str
]):
def
get_swizzle_layout
(
row_idx
,
col_idx
,
row_size
,
dtype
:
Union
[
DataType
,
str
]
,
swizzle_bytes
=
None
):
ana
=
arith
.
Analyzer
()
ana
=
arith
.
Analyzer
()
BANK_SIZE_BYTES
=
128
if
isinstance
(
dtype
,
str
):
if
isinstance
(
dtype
,
str
):
dtype
=
DataType
(
dtype
)
dtype
=
DataType
(
dtype
)
col_idx_outer
,
col_idx_inner
=
col_idx
//
(
BANK_SIZE_BYTES
//
dtype
.
bits
),
col_idx
%
(
row_bytes
=
dtype
.
bits
*
row_size
//
8
BANK_SIZE_BYTES
//
dtype
.
bits
)
assert
row_bytes
%
32
==
0
,
"Row size must be multiple of 32B."
# use transaction bits to support diverse dtype.
if
swizzle_bytes
is
None
:
# for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 32 bits = 512 bits
swizzle_bytes
=
min
(
128
,
row_bytes
)
# for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits
# 128B swizzle
coalescent_bits
=
dtype
.
bits
*
row_size
# permutation on 4 banks, each bank has 32 bits
bank_elems
=
BANK_SIZE_BYTES
//
dtype
.
bits
new_col_idx_outer
=
None
if
coalescent_bits
%
1024
==
0
:
# Use 8 * 8 permuted layout
# Use 8 * 8 permuted layout
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
# Every number below corresponds to 16B
# Every row below corresponds to 32 banks
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
...
@@ -116,33 +108,23 @@ def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
...
@@ -116,33 +108,23 @@ def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
# 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
# 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
# 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
# 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
# 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
# 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
row_idx_sub
=
row_idx
%
bank_elems
# 64B swizzle
new_col_idx_outer
=
col_idx_outer
^
row_idx_sub
else
:
assert
coalescent_bits
%
512
==
0
# Use 8 * 4 permuted layout
# Use 8 * 4 permuted layout
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
# Every number below corresponds to 16B
# Every row below corresponds to 16 banks
# 0 1 2 3 ==> 0 1 2 3
# 0 1 2 3 ==> 0 1 2 3
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 3 2 1 0
# 0 1 2 3 ==> 3 2 1 0
# View with 8 elements per row:
# 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3
# 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3
# 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
# 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
# 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
# 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
# 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
# 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
row_idx_sub
=
row_idx
%
bank_elems
# 32B swizzle
# Interleave elems per byte
# Use 8 * 2 permuted layout
interleave_elems
=
32
//
dtype
.
bits
# Every number below corresponds to 16B
new_col_idx_outer
=
col_idx_outer
^
(
row_idx_sub
//
interleave_elems
)
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
assert
(
new_col_idx_outer
is
not
None
),
f
"Unsupported dtype
{
dtype
}
with
{
coalescent_bits
}
bits"
elem_per_16B
=
128
//
dtype
.
bits
return
row_idx
,
ana
.
simplify
(
new_col_idx_outer
*
bank_elems
+
col_idx_inner
)
col_idx_16B
=
col_idx
//
elem_per_16B
col_idx_in_16B
=
col_idx
%
elem_per_16B
new_col_idx_16B
=
col_idx_16B
^
(
row_idx
%
(
swizzle_bytes
//
16
))
return
row_idx
,
ana
.
simplify
(
new_col_idx_16B
*
elem_per_16B
+
col_idx_in_16B
)
def
make_mma_swizzle_layout
(
shared_buf
,
is_smooth
:
bool
=
False
):
def
make_mma_swizzle_layout
(
shared_buf
,
is_smooth
:
bool
=
False
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment