Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
8adfc117
Commit
8adfc117
authored
May 08, 2025
by
Yuxi Chi
Committed by
LeiWang1999
May 08, 2025
Browse files
[Bugfix] Fix get_swizzle_layout implementation. (#455)
* fix get_swizzle_layout implementation. * format.
parent
0aaef97d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
51 deletions
+33
-51
tilelang/intrinsics/mma_layout.py
tilelang/intrinsics/mma_layout.py
+33
-51
No files found.
tilelang/intrinsics/mma_layout.py
View file @
8adfc117
...
@@ -89,60 +89,42 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j):
...
@@ -89,60 +89,42 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j):
return
(
i
*
2
+
j
//
16
,
j
%
16
)
return
(
i
*
2
+
j
//
16
,
j
%
16
)
def
get_swizzle_layout
(
row_idx
,
col_idx
,
row_size
,
dtype
:
Union
[
DataType
,
str
]):
def
get_swizzle_layout
(
row_idx
,
col_idx
,
row_size
,
dtype
:
Union
[
DataType
,
str
]
,
swizzle_bytes
=
None
):
ana
=
arith
.
Analyzer
()
ana
=
arith
.
Analyzer
()
BANK_SIZE_BYTES
=
128
if
isinstance
(
dtype
,
str
):
if
isinstance
(
dtype
,
str
):
dtype
=
DataType
(
dtype
)
dtype
=
DataType
(
dtype
)
col_idx_outer
,
col_idx_inner
=
col_idx
//
(
BANK_SIZE_BYTES
//
dtype
.
bits
),
col_idx
%
(
row_bytes
=
dtype
.
bits
*
row_size
//
8
BANK_SIZE_BYTES
//
dtype
.
bits
)
assert
row_bytes
%
32
==
0
,
"Row size must be multiple of 32B."
# use transaction bits to support diverse dtype.
if
swizzle_bytes
is
None
:
# for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 32 bits = 512 bits
swizzle_bytes
=
min
(
128
,
row_bytes
)
# for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits
# 128B swizzle
coalescent_bits
=
dtype
.
bits
*
row_size
# Use 8 * 8 permuted layout
# permutation on 4 banks, each bank has 32 bits
# Every number below corresponds to 16B
bank_elems
=
BANK_SIZE_BYTES
//
dtype
.
bits
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
new_col_idx_outer
=
None
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
if
coalescent_bits
%
1024
==
0
:
# 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4
# Use 8 * 8 permuted layout
# 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
# 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
# Every row below corresponds to 32 banks
# 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# 64B swizzle
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
# Use 8 * 4 permuted layout
# 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4
# Every number below corresponds to 16B
# 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3
# 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3
# 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
# 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
# 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
# 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
# 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
# 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
row_idx_sub
=
row_idx
%
bank_elems
# 32B swizzle
new_col_idx_outer
=
col_idx_outer
^
row_idx_sub
# Use 8 * 2 permuted layout
else
:
# Every number below corresponds to 16B
assert
coalescent_bits
%
512
==
0
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# Use 8 * 4 permuted layout
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
elem_per_16B
=
128
//
dtype
.
bits
# Every row below corresponds to 16 banks
col_idx_16B
=
col_idx
//
elem_per_16B
# 0 1 2 3 ==> 0 1 2 3
col_idx_in_16B
=
col_idx
%
elem_per_16B
# 0 1 2 3 ==> 0 1 2 3
new_col_idx_16B
=
col_idx_16B
^
(
row_idx
%
(
swizzle_bytes
//
16
))
# 0 1 2 3 ==> 1 0 3 2
return
row_idx
,
ana
.
simplify
(
new_col_idx_16B
*
elem_per_16B
+
col_idx_in_16B
)
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 3 2 1 0
# 0 1 2 3 ==> 3 2 1 0
# View with 8 elements per row:
# 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3
# 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
# 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
# 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
row_idx_sub
=
row_idx
%
bank_elems
# Interleave elems per byte
interleave_elems
=
32
//
dtype
.
bits
new_col_idx_outer
=
col_idx_outer
^
(
row_idx_sub
//
interleave_elems
)
assert
(
new_col_idx_outer
is
not
None
),
f
"Unsupported dtype
{
dtype
}
with
{
coalescent_bits
}
bits"
return
row_idx
,
ana
.
simplify
(
new_col_idx_outer
*
bank_elems
+
col_idx_inner
)
def
make_mma_swizzle_layout
(
shared_buf
,
is_smooth
:
bool
=
False
):
def
make_mma_swizzle_layout
(
shared_buf
,
is_smooth
:
bool
=
False
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment